diff --git a/.gitignore b/.gitignore index 3bea14c..12d734c 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,6 @@ GoNavi-Wails.exe .ace-tool/ .claude/ tmpclaude-* + +CLAUDE.md +**/CLAUDE.md diff --git a/cmd/optional-driver-agent/provider_mongodb_v1.go b/cmd/optional-driver-agent/provider_mongodb_v1.go new file mode 100644 index 0000000..9a81ce9 --- /dev/null +++ b/cmd/optional-driver-agent/provider_mongodb_v1.go @@ -0,0 +1,12 @@ +//go:build gonavi_mongodb_driver_v1 + +package main + +import "GoNavi-Wails/internal/db" + +func init() { + agentDriverType = "mongodb" + agentDatabaseFactory = func() db.Database { + return &db.MongoDBV1{} + } +} diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index ba93b57..be49c41 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,8 +1,8 @@ import React, { useState, useEffect } from 'react'; import { Layout, Button, ConfigProvider, theme, Dropdown, MenuProps, message, Modal, Spin, Slider, Progress, Switch, Input, InputNumber, Select } from 'antd'; import zhCN from 'antd/locale/zh_CN'; -import { PlusOutlined, ConsoleSqlOutlined, UploadOutlined, DownloadOutlined, CloudDownloadOutlined, BugOutlined, ToolOutlined, GlobalOutlined, InfoCircleOutlined, GithubOutlined, SkinOutlined, CheckOutlined, MinusOutlined, BorderOutlined, CloseOutlined, SettingOutlined } from '@ant-design/icons'; -import { BrowserOpenURL, Environment, EventsOn, Quit, WindowFullscreen, WindowIsFullscreen, WindowIsMaximised, WindowMaximise, WindowMinimise, WindowToggleMaximise } from '../wailsjs/runtime'; +import { PlusOutlined, ConsoleSqlOutlined, UploadOutlined, DownloadOutlined, CloudDownloadOutlined, BugOutlined, ToolOutlined, GlobalOutlined, InfoCircleOutlined, GithubOutlined, SkinOutlined, CheckOutlined, MinusOutlined, BorderOutlined, CloseOutlined, SettingOutlined, LinkOutlined } from '@ant-design/icons'; +import { BrowserOpenURL, Environment, EventsOn, Quit, WindowFullscreen, WindowGetSize, WindowIsFullscreen, WindowIsMaximised, WindowMaximise, WindowMinimise, WindowSetSize, WindowToggleMaximise } from '../wailsjs/runtime'; import Sidebar from './components/Sidebar'; import TabManager from './components/TabManager'; import ConnectionModal from './components/ConnectionModal'; @@ -12,6 +12,17 @@ import LogPanel from './components/LogPanel'; import { useStore } from './store'; import { SavedConnection } from './types'; import { blurToFilter, normalizeBlurForPlatform, normalizeOpacityForPlatform, isWindowsPlatform } from './utils/appearance'; +import { + SHORTCUT_ACTION_META, + SHORTCUT_ACTION_ORDER, + ShortcutAction, + eventToShortcut, + getShortcutDisplay, + hasModifierKey, + isEditableElement, + isShortcutMatch, + normalizeShortcutCombo, +} from './utils/shortcuts'; import { ConfigureGlobalProxy, SetWindowTranslucency } from '../wailsjs/go/app/App'; import './App.css'; @@ -53,6 +64,9 @@ function App() { const setStartupFullscreen = useStore(state => state.setStartupFullscreen); const globalProxy = useStore(state => state.globalProxy); const setGlobalProxy = useStore(state => state.setGlobalProxy); + const shortcutOptions = useStore(state => state.shortcutOptions); + const updateShortcut = useStore(state => state.updateShortcut); + const resetShortcutOptions = useStore(state => state.resetShortcutOptions); const darkMode = themeMode === 'dark'; const effectiveUiScale = Math.min(MAX_UI_SCALE, Math.max(MIN_UI_SCALE, Number(uiScale) || DEFAULT_UI_SCALE)); const effectiveFontSize = Math.min(MAX_FONT_SIZE, Math.max(MIN_FONT_SIZE, Math.round(Number(fontSize) || DEFAULT_FONT_SIZE))); @@ -260,6 +274,80 @@ function App() { }; }, []); + useEffect(() => { + if (!isWindowsPlatform()) { + return; + } + + let cancelled = false; + let inFlight = false; + let lastRatio = Number(window.devicePixelRatio) || 1; + let lastFixAt = 0; + + const wait = (ms: number) => new Promise((resolve) => window.setTimeout(resolve, ms)); + + const fixWindowScaleIfNeeded = async () => { + if (cancelled || inFlight) return; + const now = Date.now(); + if (now - lastFixAt < 700) return; + inFlight = true; + try { + const [isFullscreen, isMaximised] = await Promise.all([ + WindowIsFullscreen().catch(() => false), + WindowIsMaximised().catch(() => false), + ]); + + // 避免在全屏/最大化状态下强制改尺寸;这两种状态通常能自行保持 DPI 同步。 + if (isFullscreen || isMaximised) { + window.dispatchEvent(new Event('resize')); + lastFixAt = Date.now(); + return; + } + + const size = await WindowGetSize().catch(() => null); + const width = Math.trunc(Number(size?.w || 0)); + const height = Math.trunc(Number(size?.h || 0)); + if (width <= 0 || height <= 0) { + window.dispatchEvent(new Event('resize')); + lastFixAt = Date.now(); + return; + } + + const nudgedWidth = width > 480 ? width - 1 : width + 1; + WindowSetSize(nudgedWidth, height); + await wait(28); + WindowSetSize(width, height); + window.dispatchEvent(new Event('resize')); + lastFixAt = Date.now(); + } finally { + inFlight = false; + } + }; + + const checkDevicePixelRatio = () => { + if (cancelled) return; + const currentRatio = Number(window.devicePixelRatio) || 1; + if (Math.abs(currentRatio - lastRatio) < 0.02) { + return; + } + lastRatio = currentRatio; + void fixWindowScaleIfNeeded(); + }; + + const pollTimer = window.setInterval(checkDevicePixelRatio, 900); + window.addEventListener('resize', checkDevicePixelRatio); + window.addEventListener('focus', checkDevicePixelRatio); + document.addEventListener('visibilitychange', checkDevicePixelRatio); + + return () => { + cancelled = true; + window.clearInterval(pollTimer); + window.removeEventListener('resize', checkDevicePixelRatio); + window.removeEventListener('focus', checkDevicePixelRatio); + document.removeEventListener('visibilitychange', checkDevicePixelRatio); + }; + }, []); + // Background Helper const getBg = (darkHex: string) => { if (!darkMode) return `rgba(255, 255, 255, ${effectiveOpacity})`; // Light mode usually white @@ -299,7 +387,7 @@ function App() { const [isAboutOpen, setIsAboutOpen] = useState(false); const isAboutOpenRef = React.useRef(false); const [aboutLoading, setAboutLoading] = useState(false); - const [aboutInfo, setAboutInfo] = useState<{ version: string; author: string; buildTime?: string; repoUrl?: string; issueUrl?: string; releaseUrl?: string } | null>(null); + const [aboutInfo, setAboutInfo] = useState<{ version: string; author: string; buildTime?: string; repoUrl?: string; issueUrl?: string; releaseUrl?: string; communityUrl?: string } | null>(null); const [aboutUpdateStatus, setAboutUpdateStatus] = useState(''); const [lastUpdateInfo, setLastUpdateInfo] = useState(null); const [updateDownloadProgress, setUpdateDownloadProgress] = useState<{ @@ -666,7 +754,7 @@ function App() { void message.warning("没有连接可导出"); return; } - const res = await (window as any).go.app.App.ExportData(connections, [], "connections", "json"); + const res = await (window as any).go.app.App.ExportData(connections, ['id','name','config','includeDatabases','includeRedisDatabases'], "connections", "json"); if (res.success) { void message.success("导出成功"); } else if (res.message !== "Cancelled") { @@ -720,10 +808,18 @@ function App() { label: '外观设置...', icon: , onClick: () => setIsAppearanceModalOpen(true) + }, + { + key: 'shortcut-settings', + label: '快捷键管理...', + icon: , + onClick: () => setIsShortcutModalOpen(true) } ]; const [isAppearanceModalOpen, setIsAppearanceModalOpen] = useState(false); + const [isShortcutModalOpen, setIsShortcutModalOpen] = useState(false); + const [capturingShortcutAction, setCapturingShortcutAction] = useState(null); const [isProxyModalOpen, setIsProxyModalOpen] = useState(false); @@ -803,7 +899,7 @@ function App() { }; // Sidebar Resizing - const [sidebarWidth, setSidebarWidth] = useState(300); + const [sidebarWidth, setSidebarWidth] = useState(330); const sidebarDragRef = React.useRef<{ startX: number, startWidth: number } | null>(null); const rafRef = React.useRef(null); const ghostRef = React.useRef(null); @@ -935,6 +1031,113 @@ function App() { }; }, []); + useEffect(() => { + const handleOpenShortcutSettingsEvent = () => { + setIsShortcutModalOpen(true); + }; + window.addEventListener('gonavi:open-shortcut-settings', handleOpenShortcutSettingsEvent as EventListener); + return () => { + window.removeEventListener('gonavi:open-shortcut-settings', handleOpenShortcutSettingsEvent as EventListener); + }; + }, []); + + useEffect(() => { + const handleGlobalShortcut = (event: KeyboardEvent) => { + const matchedAction = SHORTCUT_ACTION_ORDER.find((action) => { + const binding = shortcutOptions[action]; + if (!binding?.enabled) { + return false; + } + if (isEditableElement(event.target) && !SHORTCUT_ACTION_META[action].allowInEditable) { + return false; + } + return isShortcutMatch(event, binding.combo); + }); + + if (!matchedAction) { + return; + } + + event.preventDefault(); + event.stopPropagation(); + + switch (matchedAction) { + case 'runQuery': + window.dispatchEvent(new CustomEvent('gonavi:run-active-query')); + break; + case 'focusSidebarSearch': + window.dispatchEvent(new CustomEvent('gonavi:focus-sidebar-search')); + break; + case 'newQueryTab': + handleNewQuery(); + break; + case 'toggleLogPanel': + setIsLogPanelOpen((prev) => !prev); + break; + case 'toggleTheme': + setTheme(themeMode === 'dark' ? 'light' : 'dark'); + break; + case 'openShortcutManager': + setIsShortcutModalOpen(true); + break; + } + }; + + window.addEventListener('keydown', handleGlobalShortcut); + return () => { + window.removeEventListener('keydown', handleGlobalShortcut); + }; + }, [handleNewQuery, shortcutOptions, themeMode, setTheme]); + + useEffect(() => { + if (!capturingShortcutAction) { + return; + } + + const handleShortcutCapture = (event: KeyboardEvent) => { + event.preventDefault(); + event.stopPropagation(); + + if (event.key === 'Escape') { + setCapturingShortcutAction(null); + return; + } + + const combo = eventToShortcut(event); + if (!combo) { + return; + } + if (!hasModifierKey(combo)) { + void message.warning('快捷键至少包含 Ctrl / Alt / Shift / Meta 之一'); + return; + } + + const normalizedCombo = normalizeShortcutCombo(combo); + const conflictAction = SHORTCUT_ACTION_ORDER.find((action) => { + if (action === capturingShortcutAction) { + return false; + } + const binding = shortcutOptions[action]; + if (!binding?.enabled) { + return false; + } + return normalizeShortcutCombo(binding.combo) === normalizedCombo; + }); + if (conflictAction) { + void message.warning(`与「${SHORTCUT_ACTION_META[conflictAction].label}」冲突,请换一个快捷键`); + return; + } + + updateShortcut(capturingShortcutAction, { combo: normalizedCombo, enabled: true }); + setCapturingShortcutAction(null); + }; + + window.addEventListener('keydown', handleShortcutCapture, true); + return () => { + window.removeEventListener('keydown', handleShortcutCapture, true); + }; + }, [capturingShortcutAction, shortcutOptions, updateShortcut]); + const linuxResizeHandleStyleBase = { position: 'fixed', zIndex: 12000, @@ -1221,6 +1424,9 @@ function App() {
版本:{aboutInfo?.version || '未知'}
作者:{aboutInfo?.author || '未知'}
+ {aboutInfo?.communityUrl ? ( + + ) : null}
更新状态:{aboutUpdateStatus || '未检查'}
@@ -1351,6 +1557,84 @@ function App() {
+ { + setIsShortcutModalOpen(false); + setCapturingShortcutAction(null); + }} + width={720} + footer={[ + , + , + ]} + > +
+
+ 点击“录制”后按下快捷键。按 Esc 可取消录制。建议至少包含一个修饰键(Ctrl/Alt/Shift/Meta)。 +
+ {SHORTCUT_ACTION_ORDER.map((action) => { + const meta = SHORTCUT_ACTION_META[action]; + const binding = shortcutOptions[action] ?? { combo: '', enabled: false }; + const isCapturing = capturingShortcutAction === action; + return ( +
+
+
{meta.label}
+
{meta.description}
+
+
+ + + updateShortcut(action, { enabled: checked })} + /> +
+
+ ); + })} +
+
= { vastbase: ['vastbase'], }; +const sslSupportedTypes = new Set([ + 'mysql', + 'mariadb', + 'diros', + 'sphinx', + 'dameng', + 'clickhouse', + 'postgres', + 'sqlserver', + 'oracle', + 'kingbase', + 'highgo', + 'vastbase', + 'mongodb', + 'redis', + 'tdengine', +]); + +const supportsSSLForType = (type: string) => sslSupportedTypes.has(String(type || '').trim().toLowerCase()); + const isFileDatabaseType = (type: string) => type === 'sqlite' || type === 'duckdb'; type DriverStatusSnapshot = { @@ -78,6 +98,7 @@ const ConnectionModal: React.FC<{ }> = ({ open, onClose, initialValues, onOpenDriverManager }) => { const [form] = Form.useForm(); const [loading, setLoading] = useState(false); + const [useSSL, setUseSSL] = useState(false); const [useSSH, setUseSSH] = useState(false); const [useProxy, setUseProxy] = useState(false); const [dbType, setDbType] = useState('mysql'); @@ -107,6 +128,17 @@ const ConnectionModal: React.FC<{ const mongoTopology = Form.useWatch('mongoTopology', form) || 'single'; const mongoSrv = Form.useWatch('mongoSrv', form) || false; const redisTopology = Form.useWatch('redisTopology', form) || 'single'; + const isMySQLLike = dbType === 'mysql' || dbType === 'mariadb' || dbType === 'diros' || dbType === 'sphinx'; + const isSSLType = supportsSSLForType(dbType); + const sslHintText = isMySQLLike + ? '当 MySQL/MariaDB/Doris/Sphinx 开启安全传输策略时,请启用 SSL;本地自签证书场景可先用 Preferred 或 Skip Verify。' + : dbType === 'dameng' + ? '达梦驱动启用 SSL 需要客户端证书与私钥路径(sslCertPath / sslKeyPath)。' + : dbType === 'sqlserver' + ? 'SQL Server 推荐在生产环境使用 Required,并关闭 TrustServerCertificate。' + : dbType === 'mongodb' + ? 'MongoDB 可通过 TLS 保护连接,证书校验异常时可先用 Skip Verify 验证连通性。' + : '建议优先使用 Required;仅在测试环境或自签证书场景使用 Skip Verify。'; const getSectionBg = (darkHex: string) => { if (!darkMode) { @@ -364,7 +396,7 @@ const ConnectionModal: React.FC<{ uriText: string, expectedSchemes: string[], defaultPort: number, - ): { host: string; port: number; username: string; password: string; database: string } | null => { + ): { host: string; port: number; username: string; password: string; database: string; params: URLSearchParams } | null => { let parsed: ReturnType | null = null; for (const scheme of expectedSchemes) { parsed = parseMultiHostUri(uriText, scheme); @@ -392,6 +424,7 @@ const ConnectionModal: React.FC<{ username: parsed.username, password: parsed.password, database: parsed.database || '', + params: parsed.params, }; }; @@ -425,12 +458,22 @@ const ConnectionModal: React.FC<{ const primary = parseHostPort(hostList[0] || `localhost:${mysqlDefaultPort}`, mysqlDefaultPort); const timeoutValue = Number(parsed.params.get('timeout')); const topology = String(parsed.params.get('topology') || '').toLowerCase(); + const tlsValue = String(parsed.params.get('tls') || '').trim().toLowerCase(); + const sslMode = tlsValue === 'true' + ? 'required' + : tlsValue === 'skip-verify' + ? 'skip-verify' + : tlsValue === 'preferred' + ? 'preferred' + : 'disable'; return { host: primary?.host || 'localhost', port: primary?.port || mysqlDefaultPort, user: parsed.username, password: parsed.password, database: parsed.database || '', + useSSL: sslMode !== 'disable', + sslMode, mysqlTopology: hostList.length > 1 || topology === 'replica' ? 'replica' : 'single', mysqlReplicaHosts: hostList.slice(1), timeout: Number.isFinite(timeoutValue) && timeoutValue > 0 @@ -451,7 +494,7 @@ const ConnectionModal: React.FC<{ } if (type === 'redis') { - const parsed = parseMultiHostUri(trimmedUri, 'redis'); + const parsed = parseMultiHostUri(trimmedUri, 'redis') || parseMultiHostUri(trimmedUri, 'rediss'); if (!parsed) { return null; } @@ -469,10 +512,15 @@ const ConnectionModal: React.FC<{ const topologyParam = String(parsed.params.get('topology') || '').toLowerCase(); const dbText = String(parsed.database || '').trim().replace(/^\//, ''); const dbIndex = Number(dbText); + const isRediss = trimmedUri.toLowerCase().startsWith('rediss://'); + const skipVerifyText = String(parsed.params.get('skip_verify') || '').trim().toLowerCase(); + const skipVerify = skipVerifyText === '1' || skipVerifyText === 'true' || skipVerifyText === 'yes' || skipVerifyText === 'on'; return { host: primary?.host || 'localhost', port: primary?.port || 6379, password: parsed.password || '', + useSSL: isRediss, + sslMode: isRediss ? (skipVerify ? 'skip-verify' : 'required') : 'disable', redisTopology: hostList.length > 1 || topologyParam === 'cluster' ? 'cluster' : 'single', redisHosts: hostList.slice(1), redisDB: Number.isFinite(dbIndex) && dbIndex >= 0 && dbIndex <= 15 ? Math.trunc(dbIndex) : 0, @@ -501,12 +549,18 @@ const ConnectionModal: React.FC<{ ? { host: hostList[0] || 'localhost', port: 27017 } : parseHostPort(hostList[0] || 'localhost:27017', 27017); const timeoutMs = Number(parsed.params.get('connectTimeoutMS') || parsed.params.get('serverSelectionTimeoutMS')); + const tlsText = String(parsed.params.get('tls') || parsed.params.get('ssl') || '').trim().toLowerCase(); + const tlsInsecureText = String(parsed.params.get('tlsInsecure') || parsed.params.get('sslInsecure') || '').trim().toLowerCase(); + const tlsEnabled = tlsText === '1' || tlsText === 'true' || tlsText === 'yes' || tlsText === 'on'; + const tlsInsecure = tlsInsecureText === '1' || tlsInsecureText === 'true' || tlsInsecureText === 'yes' || tlsInsecureText === 'on'; return { host: primary?.host || 'localhost', port: primary?.port || 27017, user: parsed.username, password: parsed.password, database: parsed.database || '', + useSSL: tlsEnabled, + sslMode: tlsEnabled ? (tlsInsecure ? 'skip-verify' : 'required') : 'disable', mongoTopology: hostList.length > 1 || !!parsed.params.get('replicaSet') ? 'replica' : 'single', mongoHosts: hostList.slice(1), mongoSrv: isSrv, @@ -531,13 +585,94 @@ const ConnectionModal: React.FC<{ // Oracle 需要显式 service name,避免 URI 解析后放过必填校验。 return null; } - return { + const parsedValues: Record = { host: parsed.host, port: parsed.port, user: parsed.username, password: parsed.password, database: parsed.database, }; + + if (supportsSSLForType(type)) { + const normalizeBool = (raw: unknown) => { + const text = String(raw ?? '').trim().toLowerCase(); + return text === '1' || text === 'true' || text === 'yes' || text === 'on'; + }; + if (type === 'postgres' || type === 'kingbase' || type === 'highgo' || type === 'vastbase') { + const sslMode = String(parsed.params.get('sslmode') || '').trim().toLowerCase(); + if (sslMode) { + parsedValues.useSSL = sslMode !== 'disable' && sslMode !== 'false'; + parsedValues.sslMode = sslMode === 'disable' || sslMode === 'false' + ? 'disable' + : 'required'; + } + } else if (type === 'sqlserver') { + const encrypt = String(parsed.params.get('encrypt') || '').trim().toLowerCase(); + const trust = String(parsed.params.get('TrustServerCertificate') || parsed.params.get('trustservercertificate') || '').trim().toLowerCase(); + const encrypted = encrypt === 'true' || encrypt === 'mandatory' || encrypt === 'yes' || encrypt === '1' || encrypt === 'strict'; + if (encrypted) { + parsedValues.useSSL = true; + parsedValues.sslMode = trust === 'true' || trust === '1' || trust === 'yes' ? 'skip-verify' : 'required'; + } else if (encrypt) { + parsedValues.useSSL = false; + parsedValues.sslMode = 'disable'; + } + } else if (type === 'clickhouse') { + const secure = String(parsed.params.get('secure') || parsed.params.get('tls') || '').trim().toLowerCase(); + const skipVerify = normalizeBool(parsed.params.get('skip_verify')); + if (secure) { + parsedValues.useSSL = normalizeBool(secure); + parsedValues.sslMode = skipVerify ? 'skip-verify' : (parsedValues.useSSL ? 'required' : 'disable'); + } + } else if (type === 'dameng') { + const certPath = String( + parsed.params.get('SSL_CERT_PATH') + || parsed.params.get('ssl_cert_path') + || parsed.params.get('sslCertPath') + || '' + ).trim(); + const keyPath = String( + parsed.params.get('SSL_KEY_PATH') + || parsed.params.get('ssl_key_path') + || parsed.params.get('sslKeyPath') + || '' + ).trim(); + parsedValues.sslCertPath = certPath; + parsedValues.sslKeyPath = keyPath; + if (certPath || keyPath) { + parsedValues.useSSL = true; + parsedValues.sslMode = 'required'; + } + } else if (type === 'oracle') { + const ssl = String(parsed.params.get('SSL') || parsed.params.get('ssl') || '').trim().toLowerCase(); + const sslVerify = String( + parsed.params.get('SSL VERIFY') + || parsed.params.get('ssl verify') + || parsed.params.get('SSL_VERIFY') + || parsed.params.get('ssl_verify') + || '' + ).trim().toLowerCase(); + if (ssl) { + parsedValues.useSSL = normalizeBool(ssl); + if (!parsedValues.useSSL) { + parsedValues.sslMode = 'disable'; + } else { + parsedValues.sslMode = normalizeBool(sslVerify || 'true') ? 'required' : 'skip-verify'; + } + } + } else if (type === 'tdengine') { + const protocol = String(parsed.params.get('protocol') || '').trim().toLowerCase(); + const skipVerify = normalizeBool(parsed.params.get('skip_verify')); + if (protocol === 'wss') { + parsedValues.useSSL = true; + parsedValues.sslMode = skipVerify ? 'skip-verify' : 'required'; + } else if (protocol === 'ws') { + parsedValues.useSSL = false; + parsedValues.sslMode = 'disable'; + } + } + }; + return parsedValues; } return null; @@ -609,6 +744,16 @@ const ConnectionModal: React.FC<{ if (hosts.length > 1 || values.mysqlTopology === 'replica') { params.set('topology', 'replica'); } + if (values.useSSL) { + const mode = String(values.sslMode || 'preferred').trim().toLowerCase(); + if (mode === 'required') { + params.set('tls', 'true'); + } else if (mode === 'skip-verify') { + params.set('tls', 'skip-verify'); + } else { + params.set('tls', 'preferred'); + } + } if (Number.isFinite(timeout) && timeout > 0) { params.set('timeout', String(timeout)); } @@ -634,8 +779,15 @@ const ConnectionModal: React.FC<{ ? Math.max(0, Math.min(15, Math.trunc(Number(values.redisDB)))) : 0; const dbPath = `/${redisDB}`; + if (values.useSSL) { + const mode = String(values.sslMode || 'preferred').trim().toLowerCase(); + if (mode === 'skip-verify' || mode === 'preferred') { + params.set('skip_verify', 'true'); + } + } const query = params.toString(); - return `redis://${redisAuth}${hosts.join(',')}${dbPath}${query ? `?${query}` : ''}`; + const scheme = values.useSSL ? 'rediss' : 'redis'; + return `${scheme}://${redisAuth}${hosts.join(',')}${dbPath}${query ? `?${query}` : ''}`; } if (isFileDatabaseType(type)) { @@ -675,6 +827,15 @@ const ConnectionModal: React.FC<{ if (authMechanism) { params.set('authMechanism', authMechanism); } + if (values.useSSL) { + const mode = String(values.sslMode || 'preferred').trim().toLowerCase(); + params.set('tls', 'true'); + if (mode === 'skip-verify' || mode === 'preferred') { + params.set('tlsInsecure', 'true'); + } else { + params.delete('tlsInsecure'); + } + } if (Number.isFinite(timeout) && timeout > 0) { params.set('connectTimeoutMS', String(timeout * 1000)); params.set('serverSelectionTimeoutMS', String(timeout * 1000)); @@ -686,7 +847,45 @@ const ConnectionModal: React.FC<{ const scheme = type === 'postgres' ? 'postgresql' : type; const dbPath = database ? `/${encodeURIComponent(database)}` : ''; - return `${scheme}://${encodedAuth}${toAddress(host, port, defaultPort)}${dbPath}`; + const params = new URLSearchParams(); + if (supportsSSLForType(type) && values.useSSL) { + const mode = String(values.sslMode || 'preferred').trim().toLowerCase(); + if (type === 'postgres' || type === 'kingbase' || type === 'highgo' || type === 'vastbase') { + params.set('sslmode', 'require'); + } else if (type === 'sqlserver') { + params.set('encrypt', 'true'); + params.set('TrustServerCertificate', mode === 'skip-verify' || mode === 'preferred' ? 'true' : 'false'); + } else if (type === 'clickhouse') { + params.set('secure', 'true'); + if (mode === 'skip-verify' || mode === 'preferred') { + params.set('skip_verify', 'true'); + } + } else if (type === 'dameng') { + const certPath = String(values.sslCertPath || '').trim(); + const keyPath = String(values.sslKeyPath || '').trim(); + if (certPath) params.set('SSL_CERT_PATH', certPath); + if (keyPath) params.set('SSL_KEY_PATH', keyPath); + } else if (type === 'oracle') { + params.set('SSL', 'TRUE'); + params.set('SSL VERIFY', mode === 'required' ? 'TRUE' : 'FALSE'); + } else if (type === 'tdengine') { + params.set('protocol', 'wss'); + if (mode === 'skip-verify' || mode === 'preferred') { + params.set('skip_verify', 'true'); + } + } + } else if (supportsSSLForType(type)) { + if (type === 'postgres' || type === 'kingbase' || type === 'highgo' || type === 'vastbase') { + params.set('sslmode', 'disable'); + } else if (type === 'sqlserver') { + params.set('encrypt', 'disable'); + params.set('TrustServerCertificate', 'true'); + } else if (type === 'tdengine') { + params.set('protocol', 'ws'); + } + } + const query = params.toString(); + return `${scheme}://${encodedAuth}${toAddress(host, port, defaultPort)}${dbPath}${query ? `?${query}` : ''}`; }; const handleGenerateURI = () => { @@ -838,6 +1037,10 @@ const ConnectionModal: React.FC<{ uri: config.uri || '', includeDatabases: initialValues.includeDatabases, includeRedisDatabases: initialValues.includeRedisDatabases, + useSSL: !!config.useSSL, + sslMode: config.sslMode || 'preferred', + sslCertPath: config.sslCertPath || '', + sslKeyPath: config.sslKeyPath || '', useSSH: config.useSSH, sshHost: config.ssh?.host, sshPort: config.ssh?.port, @@ -871,6 +1074,7 @@ const ConnectionModal: React.FC<{ mongoReplicaUser: config.mongoReplicaUser || '', mongoReplicaPassword: config.mongoReplicaPassword || '' }); + setUseSSL(!!config.useSSL); setUseSSH(config.useSSH || false); setUseProxy(config.useProxy || false); setDbType(configType); @@ -882,6 +1086,7 @@ const ConnectionModal: React.FC<{ // Create mode: Start at step 1 setStep(1); form.resetFields(); + setUseSSL(false); setUseSSH(false); setUseProxy(false); setDbType('mysql'); @@ -932,6 +1137,7 @@ const ConnectionModal: React.FC<{ setLoading(false); form.resetFields(); + setUseSSL(false); setUseSSH(false); setUseProxy(false); setDbType('mysql'); @@ -1073,6 +1279,21 @@ const ConnectionModal: React.FC<{ const type = String(mergedValues.type || '').toLowerCase(); const defaultPort = getDefaultPortByType(type); const isFileDbType = isFileDatabaseType(type); + const sslCapableType = supportsSSLForType(type); + const sslModeRaw = String(mergedValues.sslMode || 'preferred').trim().toLowerCase(); + const sslMode: 'preferred' | 'required' | 'skip-verify' | 'disable' = sslModeRaw === 'required' + ? 'required' + : sslModeRaw === 'skip-verify' + ? 'skip-verify' + : sslModeRaw === 'disable' + ? 'disable' + : 'preferred'; + const effectiveUseSSL = sslCapableType && !!mergedValues.useSSL; + const sslCertPath = sslCapableType ? String(mergedValues.sslCertPath || '').trim() : ''; + const sslKeyPath = sslCapableType ? String(mergedValues.sslKeyPath || '').trim() : ''; + if (type === 'dameng' && effectiveUseSSL && (!sslCertPath || !sslKeyPath)) { + throw new Error('达梦启用 SSL 时必须填写证书路径与私钥路径'); + } let primaryHost = 'localhost'; let primaryPort = defaultPort; @@ -1194,6 +1415,10 @@ const ConnectionModal: React.FC<{ password: keepPassword ? (mergedValues.password || "") : "", savePassword: savePassword, database: mergedValues.database || "", + useSSL: effectiveUseSSL, + sslMode: effectiveUseSSL ? sslMode : 'disable', + sslCertPath: sslCertPath, + sslKeyPath: sslKeyPath, useSSH: !!mergedValues.useSSH, ssh: sshConfig, useProxy: effectiveUseProxy, @@ -1233,6 +1458,7 @@ const ConnectionModal: React.FC<{ const defaultPort = getDefaultPortByType(type); if (isFileDatabaseType(type)) { + setUseSSL(false); setUseSSH(false); setUseProxy(false); form.setFieldsValue({ @@ -1241,6 +1467,10 @@ const ConnectionModal: React.FC<{ user: '', password: '', database: '', + useSSL: false, + sslMode: 'preferred', + sslCertPath: '', + sslKeyPath: '', useSSH: false, sshHost: '', sshPort: 22, @@ -1273,10 +1503,16 @@ const ConnectionModal: React.FC<{ }); } else if (type !== 'custom') { const defaultUser = type === 'clickhouse' ? 'default' : 'root'; + const sslCapableType = supportsSSLForType(type); + setUseSSL(false); form.setFieldsValue({ user: defaultUser, database: '', port: defaultPort, + useSSL: sslCapableType ? false : undefined, + sslMode: sslCapableType ? 'preferred' : undefined, + sslCertPath: sslCapableType ? '' : undefined, + sslKeyPath: sslCapableType ? '' : undefined, mysqlTopology: 'single', redisTopology: 'single', mongoTopology: 'single', @@ -1420,6 +1656,10 @@ const ConnectionModal: React.FC<{ port: 3306, database: '', user: 'root', + useSSL: false, + sslMode: 'preferred', + sslCertPath: '', + sslKeyPath: '', useSSH: false, sshPort: 22, useProxy: false, @@ -1451,6 +1691,7 @@ const ConnectionModal: React.FC<{ if (changed.uri !== undefined || changed.type !== undefined) { setUriFeedback(null); } + if (changed.useSSL !== undefined) setUseSSL(changed.useSSL); if (changed.useSSH !== undefined) setUseSSH(changed.useSSH); if (changed.useProxy !== undefined) setUseProxy(changed.useProxy); if (changed.proxyType !== undefined) { @@ -1835,6 +2076,56 @@ const ConnectionModal: React.FC<{ {!isFileDb && ( <> + {isSSLType && ( + <> + + + 使用 SSL/TLS + + {useSSL && ( +
+ + + + + + + + )} + + {sslHintText} + +
+ )} + + )} + 使用 SSH 隧道 (SSH Tunnel) diff --git a/frontend/src/components/DataGrid.tsx b/frontend/src/components/DataGrid.tsx index 05ccbba..5adf70f 100644 --- a/frontend/src/components/DataGrid.tsx +++ b/frontend/src/components/DataGrid.tsx @@ -537,6 +537,7 @@ const ContextMenuRow = React.memo(({ children, record, ...props }: any) => { { key: 'exp-xlsx', label: 'Excel', onClick: () => handleExportSelected('xlsx', record) }, { key: 'exp-json', label: 'JSON', onClick: () => handleExportSelected('json', record) }, { key: 'exp-md', label: 'Markdown', onClick: () => handleExportSelected('md', record) }, + { key: 'exp-html', label: 'HTML', onClick: () => handleExportSelected('html', record) }, ] } ]; @@ -577,7 +578,9 @@ interface DataGridProps { // Filtering showFilter?: boolean; onToggleFilter?: () => void; + exportSqlWithFilter?: string; onApplyFilter?: (conditions: GridFilterCondition[]) => void; + appliedFilterConditions?: FilterCondition[]; } type GridFilterCondition = FilterCondition & { @@ -595,9 +598,9 @@ type ColumnMeta = { comment: string; }; -const DataGrid: React.FC = ({ +const DataGrid: React.FC = ({ data, columnNames, loading, tableName, exportScope = 'table', resultSql, dbName, connectionId, pkColumns = [], readOnly = false, - onReload, onSort, onPageChange, pagination, onRequestTotalCount, onCancelTotalCount, sortInfoExternal, showFilter, onToggleFilter, onApplyFilter + onReload, onSort, onPageChange, pagination, onRequestTotalCount, onCancelTotalCount, sortInfoExternal, showFilter, onToggleFilter, exportSqlWithFilter, onApplyFilter, appliedFilterConditions }) => { const connections = useStore(state => state.connections); const addSqlLog = useStore(state => state.addSqlLog); @@ -620,6 +623,8 @@ const DataGrid: React.FC = ({ const isQueryResultExport = exportScope === 'queryResult'; const canImport = exportScope === 'table' && !!tableName; const canExport = !!connectionId && (isQueryResultExport || !!tableName); + const filteredExportSql = useMemo(() => String(exportSqlWithFilter || '').trim(), [exportSqlWithFilter]); + const hasFilteredExportSql = exportScope === 'table' && filteredExportSql.length > 0; // Background Helper const getBg = (darkHex: string) => { @@ -1060,10 +1065,40 @@ const DataGrid: React.FC = ({ const [modifiedRows, setModifiedRows] = useState>({}); const [deletedRowKeys, setDeletedRowKeys] = useState>(new Set()); + const normalizeFilterLogic = useCallback((logic: unknown): 'AND' | 'OR' => { + return String(logic || '').trim().toUpperCase() === 'OR' ? 'OR' : 'AND'; + }, []); + + const normalizeGridFilterConditions = useCallback((conditions?: FilterCondition[]): GridFilterCondition[] => { + if (!Array.isArray(conditions)) return []; + return conditions.map((cond, index) => { + const fallbackId = index + 1; + const nextId = Number.isFinite(Number(cond?.id)) ? Number(cond?.id) : fallbackId; + const op = String(cond?.op || '='); + const rawColumn = String(cond?.column || ''); + return { + id: nextId, + enabled: cond?.enabled !== false, + logic: normalizeFilterLogic(cond?.logic), + column: rawColumn || (op === 'CUSTOM' ? '' : String(columnNames[0] || '')), + op, + value: String(cond?.value ?? ''), + value2: String(cond?.value2 ?? ''), + }; + }); + }, [columnNames, normalizeFilterLogic]); + // Filter State const [filterConditions, setFilterConditions] = useState([]); const [nextFilterId, setNextFilterId] = useState(1); + useEffect(() => { + const nextConditions = normalizeGridFilterConditions(appliedFilterConditions); + setFilterConditions(nextConditions); + const maxId = nextConditions.reduce((max, cond) => (cond.id > max ? cond.id : max), 0); + setNextFilterId(Math.max(1, maxId + 1)); + }, [appliedFilterConditions, normalizeGridFilterConditions]); + const selectedRowKeysRef = useRef(selectedRowKeys); const displayDataRef = useRef([]); @@ -2481,6 +2516,23 @@ const DataGrid: React.FC = ({ }); }; + const handleExportFilteredAll = async (format: string) => { + if (!connectionId || !tableName) return; + if (!filteredExportSql) { + message.warning('当前未应用筛选条件'); + return; + } + if (!supportsSqlQueryExport) { + message.error('当前数据源不支持按筛选结果导出'); + return; + } + if (hasChanges) { + message.warning("当前存在未提交修改,筛选结果导出基于数据库已提交数据。"); + } + + await exportByQuery(filteredExportSql, format, `${tableName || 'export'}_filtered`); + }; + const handleImport = async () => { if (!connectionId || !tableName) return; const config = buildConnConfig(); @@ -2526,6 +2578,10 @@ const DataGrid: React.FC = ({ { value: 'NOT_IN', label: '不在列表' }, { value: 'CUSTOM', label: '[自定义]' }, ]), []); + const filterLogicOptions = useMemo(() => ([ + { value: 'AND', label: '且 (AND)' }, + { value: 'OR', label: '或 (OR)' }, + ]), []); const isNoValueOp = useCallback((op: string) => ( op === 'IS_NULL' || op === 'IS_NOT_NULL' || op === 'IS_EMPTY' || op === 'IS_NOT_EMPTY' @@ -2534,7 +2590,18 @@ const DataGrid: React.FC = ({ const isListOp = useCallback((op: string) => op === 'IN' || op === 'NOT_IN', []); const addFilter = () => { - setFilterConditions([...filterConditions, { id: nextFilterId, enabled: true, column: columnNames[0] || '', op: '=', value: '', value2: '' }]); + setFilterConditions([ + ...filterConditions, + { + id: nextFilterId, + enabled: true, + logic: 'AND', + column: columnNames[0] || '', + op: '=', + value: '', + value2: '', + } + ]); setNextFilterId(nextFilterId + 1); }; const updateFilter = (id: number, field: keyof GridFilterCondition, val: string | boolean) => { @@ -2562,11 +2629,28 @@ const DataGrid: React.FC = ({ if (onApplyFilter) onApplyFilter(filterConditions); }; - const exportMenu: MenuProps['items'] = [ + const exportMenu: MenuProps['items'] = hasFilteredExportSql ? [ + { type: 'group', label: '筛选结果', children: [ + { key: 'filtered-csv', label: 'CSV', onClick: () => handleExportFilteredAll('csv') }, + { key: 'filtered-xlsx', label: 'Excel (XLSX)', onClick: () => handleExportFilteredAll('xlsx') }, + { key: 'filtered-json', label: 'JSON', onClick: () => handleExportFilteredAll('json') }, + { key: 'filtered-md', label: 'Markdown', onClick: () => handleExportFilteredAll('md') }, + { key: 'filtered-html', label: 'HTML', onClick: () => handleExportFilteredAll('html') }, + ]}, + { type: 'divider' }, + { type: 'group', label: '全表', children: [ + { key: 'table-csv', label: 'CSV', onClick: () => handleExport('csv') }, + { key: 'table-xlsx', label: 'Excel (XLSX)', onClick: () => handleExport('xlsx') }, + { key: 'table-json', label: 'JSON', onClick: () => handleExport('json') }, + { key: 'table-md', label: 'Markdown', onClick: () => handleExport('md') }, + { key: 'table-html', label: 'HTML', onClick: () => handleExport('html') }, + ]}, + ] : [ { key: 'csv', label: 'CSV', onClick: () => handleExport('csv') }, { key: 'xlsx', label: 'Excel (XLSX)', onClick: () => handleExport('xlsx') }, { key: 'json', label: 'JSON', onClick: () => handleExport('json') }, { key: 'md', label: 'Markdown', onClick: () => handleExport('md') }, + { key: 'html', label: 'HTML', onClick: () => handleExport('html') }, ]; const columnInfoSettingContent = ( @@ -2705,29 +2789,31 @@ const DataGrid: React.FC = ({ horizontalSyncSourceRef.current = ''; }, []); - const handleExternalHorizontalWheel = useCallback((event: React.WheelEvent) => { + // 非虚拟模式:外部水平滚动条的 wheel 处理(通过原生事件绑定,确保 preventDefault 生效) + useEffect(() => { const externalScroll = externalHScrollRef.current; - if (!(externalScroll instanceof HTMLDivElement)) { - return; - } - const dominantDelta = Math.abs(event.deltaX) > Math.abs(event.deltaY) ? event.deltaX : event.deltaY; - if (!Number.isFinite(dominantDelta) || Math.abs(dominantDelta) < 0.5) { - return; - } + if (!externalScroll || !horizontalScrollVisible) return; - const maxScrollLeft = Math.max(0, externalScroll.scrollWidth - externalScroll.clientWidth); - if (maxScrollLeft <= 0) { - return; - } + const handleExternalWheel = (e: WheelEvent) => { + // 鼠标在水平滚动条区域时,始终阻止垂直滚动冒泡 + e.preventDefault(); + e.stopPropagation(); - const nextScrollLeft = Math.max(0, Math.min(maxScrollLeft, externalScroll.scrollLeft + dominantDelta)); - if (Math.abs(nextScrollLeft - externalScroll.scrollLeft) < 0.5) { - return; - } + const dominantDelta = Math.abs(e.deltaX) > Math.abs(e.deltaY) ? e.deltaX : e.deltaY; + if (!Number.isFinite(dominantDelta) || Math.abs(dominantDelta) < 0.5) return; - event.preventDefault(); - externalScroll.scrollLeft = nextScrollLeft; - }, []); + const maxScrollLeft = Math.max(0, externalScroll.scrollWidth - externalScroll.clientWidth); + if (maxScrollLeft <= 0) return; + + const nextScrollLeft = Math.max(0, Math.min(maxScrollLeft, externalScroll.scrollLeft + dominantDelta)); + externalScroll.scrollLeft = nextScrollLeft; + }; + + externalScroll.addEventListener('wheel', handleExternalWheel, { passive: false, capture: true }); + return () => { + externalScroll.removeEventListener('wheel', handleExternalWheel, { capture: true } as EventListenerOptions); + }; + }, [horizontalScrollVisible]); useEffect(() => { if (viewMode !== 'table') return; @@ -2735,19 +2821,24 @@ const DataGrid: React.FC = ({ return () => cancelAnimationFrame(rafId); }, [viewMode, totalWidth, mergedDisplayData.length, recalculateTableMetrics]); - // 虚拟模式下,为 rc-virtual-list 的内置水平滚动条添加鼠标滚轮支持 - // rc-virtual-list 的 ScrollBar 组件原生只支持拖拽,不支持 wheel 事件 - // 方案:使用 MutationObserver 发现滚动条元素后直接绑定 wheel 事件 + // 虚拟模式下,在容器级别监听 wheel 事件,当鼠标在底部水平滚动条区域时拦截并转为水平滚动 useEffect(() => { if (viewMode !== 'table' || !enableVirtual) return; const container = tableContainerRef.current; if (!container) return; - let currentScrollbarEl: HTMLElement | null = null; + // 滚动条区域高度:滚动条高度 + 间距 + 容错 + const scrollbarZoneHeight = floatingScrollbarHeight + floatingScrollbarGap + 8; - const handleScrollbarWheel = (e: WheelEvent) => { - const innerEl = container.querySelector('.rc-virtual-list-holder-inner') as HTMLElement | null; - const holderEl = container.querySelector('.rc-virtual-list-holder') as HTMLElement | null; + const handleContainerWheel = (e: WheelEvent) => { + // 判断鼠标是否在底部滚动条区域 + const containerRect = container.getBoundingClientRect(); + if (e.clientY < containerRect.bottom - scrollbarZoneHeight) return; + + // 适配 antd 的虚拟列表类名 + const holderEl = container.querySelector('.ant-table-tbody-virtual-holder') as HTMLElement | null; + const innerEl = holderEl?.querySelector('.ant-table-tbody-virtual-holder-inner') as HTMLElement | null; + if (!innerEl || !holderEl) return; const dominantDelta = Math.abs(e.deltaX) > Math.abs(e.deltaY) ? e.deltaX : e.deltaY; @@ -2769,12 +2860,13 @@ const DataGrid: React.FC = ({ innerEl.style.marginLeft = `${-newOffset}px`; // 同步 scrollbar thumb 位置 - if (currentScrollbarEl && maxScroll > 0) { - const thumbEl = currentScrollbarEl.querySelector('[class*="scrollbar-thumb"]') as HTMLElement | null; + const scrollbarEl = container.querySelector('.ant-table-tbody-virtual-scrollbar-horizontal') as HTMLElement | null; + if (scrollbarEl && maxScroll > 0) { + const thumbEl = scrollbarEl.querySelector('[class*="scrollbar-thumb"]') as HTMLElement | null; if (thumbEl) { const ratio = newOffset / maxScroll; const thumbWidth = parseFloat(thumbEl.style.width) || thumbEl.offsetWidth; - const trackWidth = currentScrollbarEl.clientWidth; + const trackWidth = scrollbarEl.clientWidth; const thumbMaxOffset = trackWidth - thumbWidth; thumbEl.style.left = `${ratio * thumbMaxOffset}px`; } @@ -2787,33 +2879,12 @@ const DataGrid: React.FC = ({ } }; - const bindScrollbar = () => { - const el = container.querySelector('.ant-table-tbody-virtual-scrollbar-horizontal') as HTMLElement | null; - if (el && el !== currentScrollbarEl) { - if (currentScrollbarEl) { - currentScrollbarEl.removeEventListener('wheel', handleScrollbarWheel); - } - currentScrollbarEl = el; - el.addEventListener('wheel', handleScrollbarWheel, { passive: false }); - } - }; - - // 初次尝试绑定 - bindScrollbar(); - - // 使用 MutationObserver 监听 DOM 变化,确保即使元素延迟渲染也能绑定 - const observer = new MutationObserver(() => { - bindScrollbar(); - }); - observer.observe(container, { childList: true, subtree: true }); + container.addEventListener('wheel', handleContainerWheel, { passive: false, capture: true }); return () => { - observer.disconnect(); - if (currentScrollbarEl) { - currentScrollbarEl.removeEventListener('wheel', handleScrollbarWheel); - } + container.removeEventListener('wheel', handleContainerWheel, { capture: true } as EventListenerOptions); }; - }, [viewMode, enableVirtual, tableScrollX, mergedDisplayData.length]); + }, [viewMode, enableVirtual, tableScrollX, floatingScrollbarHeight, floatingScrollbarGap]); useEffect(() => { if (viewMode !== 'table') return; @@ -3041,7 +3112,7 @@ const DataGrid: React.FC = ({ background: 'transparent', boxSizing: 'border-box', }}> - {filterConditions.map(cond => ( + {filterConditions.map((cond, condIndex) => (
= ({ > 启用 - updateFilter(cond.id, 'logic', v)} + options={filterLogicOptions as any} + /> + )} + = ({ className="data-grid-external-hscroll" aria-hidden={!horizontalScrollVisible} onScroll={applyExternalScrollToTableTargets} - onWheel={handleExternalHorizontalWheel} style={{ opacity: horizontalScrollVisible ? 1 : 0, pointerEvents: horizontalScrollVisible ? 'auto' : 'none', @@ -3552,6 +3642,21 @@ const DataGrid: React.FC = ({ > 导出为 JSON
+
e.currentTarget.style.background = darkMode ? '#303030' : '#f5f5f5'} + onMouseLeave={(e) => e.currentTarget.style.background = 'transparent'} + onClick={() => { + if (cellContextMenu.record) handleExportSelected('html', cellContextMenu.record); + setCellContextMenu(prev => ({ ...prev, visible: false })); + }} + > + 导出为 HTML +
, document.body )} diff --git a/frontend/src/components/DataViewer.tsx b/frontend/src/components/DataViewer.tsx index 8950629..adc2240 100644 --- a/frontend/src/components/DataViewer.tsx +++ b/frontend/src/components/DataViewer.tsx @@ -1,10 +1,11 @@ -import React, { useEffect, useState, useCallback, useRef } from 'react'; +import React, { useEffect, useState, useCallback, useRef, useMemo } from 'react'; import { message } from 'antd'; import { TabData, ColumnDefinition } from '../types'; import { useStore } from '../store'; import { DBQuery, DBGetColumns } from '../../wailsjs/go/app/App'; import DataGrid, { GONAVI_ROW_KEY } from './DataGrid'; import { buildOrderBySQL, buildWhereSQL, quoteIdentPart, quoteQualifiedIdent, withSortBufferTuningSQL, type FilterCondition } from '../utils/sql'; +import { buildMongoCountCommand, buildMongoFilter, buildMongoFindCommand, buildMongoSort } from '../utils/mongodb'; import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities'; type ViewerPaginationState = { @@ -151,6 +152,37 @@ const reverseOrderBySQL = (orderBySQL: string): string => { return ` ORDER BY ${parts.join(', ')}`; }; +type ViewerFilterSnapshot = { + showFilter: boolean; + conditions: FilterCondition[]; +}; + +const viewerFilterSnapshotsByTab = new Map(); + +const normalizeViewerFilterConditions = (conditions: FilterCondition[] | undefined): FilterCondition[] => { + if (!Array.isArray(conditions)) return []; + return conditions.map((cond) => ({ + id: Number.isFinite(Number(cond?.id)) ? Number(cond?.id) : undefined, + enabled: cond?.enabled !== false, + logic: String(cond?.logic || '').trim().toUpperCase() === 'OR' ? 'OR' : 'AND', + column: String(cond?.column || ''), + op: String(cond?.op || '='), + value: String(cond?.value ?? ''), + value2: String(cond?.value2 ?? ''), + })); +}; + +const getViewerFilterSnapshot = (tabId: string): ViewerFilterSnapshot => { + const cached = viewerFilterSnapshotsByTab.get(String(tabId || '').trim()); + if (!cached) { + return { showFilter: false, conditions: [] }; + } + return { + showFilter: cached.showFilter === true, + conditions: normalizeViewerFilterConditions(cached.conditions), + }; +}; + const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => { const [data, setData] = useState([]); const [columnNames, setColumnNames] = useState([]); @@ -185,14 +217,27 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => { const [sortInfo, setSortInfo] = useState<{ columnKey: string, order: string } | null>(null); - const [showFilter, setShowFilter] = useState(false); - const [filterConditions, setFilterConditions] = useState([]); + const [showFilter, setShowFilter] = useState(() => getViewerFilterSnapshot(tab.id).showFilter); + const [filterConditions, setFilterConditions] = useState(() => getViewerFilterSnapshot(tab.id).conditions); const duckdbSafeSelectCacheRef = useRef>({}); const currentConnConfig = connections.find(c => c.id === tab.connectionId)?.config; const currentConnCaps = getDataSourceCapabilities(currentConnConfig); const currentConnType = currentConnCaps.type; const forceReadOnly = currentConnCaps.forceReadOnlyQueryResult; + useEffect(() => { + const snapshot = getViewerFilterSnapshot(tab.id); + setShowFilter(snapshot.showFilter); + setFilterConditions(snapshot.conditions); + }, [tab.id]); + + useEffect(() => { + viewerFilterSnapshotsByTab.set(tab.id, { + showFilter, + conditions: normalizeViewerFilterConditions(filterConditions), + }); + }, [tab.id, showFilter, filterConditions]); + useEffect(() => { setPkColumns([]); pkKeyRef.current = ''; @@ -315,42 +360,67 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => { const dbName = tab.dbName || ''; const tableName = tab.tableName || ''; + const isMongoDB = dbTypeLower === 'mongodb'; + let mongoFilter: Record | undefined; + if (isMongoDB) { + try { + mongoFilter = buildMongoFilter(filterConditions); + } catch (e: any) { + message.error(`Mongo 筛选条件无效:${String(e?.message || e || '解析失败')}`); + if (fetchSeqRef.current === seq) setLoading(false); + return; + } + } - const whereSQL = buildWhereSQL(dbType, filterConditions); - - const countSql = `SELECT COUNT(*) as total FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`; - - const baseSql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`; - const orderBySQL = buildOrderBySQL(dbType, sortInfo, pkColumns); - let sql = `${baseSql}${orderBySQL}`; + const whereSQL = isMongoDB + ? JSON.stringify(mongoFilter || {}) + : buildWhereSQL(dbType, filterConditions); + const countSql = isMongoDB + ? buildMongoCountCommand(tableName, mongoFilter || {}) + : `SELECT COUNT(*) as total FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`; + const orderBySQL = isMongoDB ? '' : buildOrderBySQL(dbType, sortInfo, pkColumns); const totalRows = Number(pagination.total); const hasFiniteTotal = Number.isFinite(totalRows) && totalRows >= 0; const totalKnown = pagination.totalKnown && hasFiniteTotal; const totalPages = hasFiniteTotal ? Math.max(1, Math.ceil(totalRows / size)) : 0; const currentPage = totalPages > 0 ? Math.min(Math.max(1, page), totalPages) : Math.max(1, page); const offset = (currentPage - 1) * size; - const isClickHouse = dbTypeLower === 'clickhouse'; + const isClickHouse = !isMongoDB && dbTypeLower === 'clickhouse'; const reverseOrderSQL = isClickHouse ? reverseOrderBySQL(orderBySQL) : ''; let useClickHouseReversePagination = false; let clickHouseReverseLimit = 0; let clickHouseReverseHasMore = false; - // ClickHouse 深分页在超大 OFFSET 下容易超时。对于总数已知且存在 ORDER BY 的场景, - // 当“尾部偏移”小于“头部偏移”时,改为反向 ORDER BY + 小 OFFSET,并在前端翻转结果。 - if (isClickHouse && totalKnown && offset > 0 && reverseOrderSQL) { - const pageRowCount = Math.max(0, Math.min(size, totalRows - offset)); - if (pageRowCount > 0) { - const tailOffset = Math.max(0, totalRows - (offset + pageRowCount)); - if (tailOffset < offset) { - sql = `${baseSql}${reverseOrderSQL} LIMIT ${pageRowCount} OFFSET ${tailOffset}`; - useClickHouseReversePagination = true; - clickHouseReverseLimit = pageRowCount; - clickHouseReverseHasMore = currentPage < totalPages; + let sql = ''; + if (isMongoDB) { + const mongoSort = buildMongoSort(sortInfo, pkColumns); + sql = buildMongoFindCommand({ + collection: tableName, + filter: mongoFilter || {}, + sort: mongoSort, + limit: size + 1, + skip: offset, + }); + } else { + const baseSql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`; + sql = `${baseSql}${orderBySQL}`; + // ClickHouse 深分页在超大 OFFSET 下容易超时。对于总数已知且存在 ORDER BY 的场景, + // 当“尾部偏移”小于“头部偏移”时,改为反向 ORDER BY + 小 OFFSET,并在前端翻转结果。 + if (isClickHouse && totalKnown && offset > 0 && reverseOrderSQL) { + const pageRowCount = Math.max(0, Math.min(size, totalRows - offset)); + if (pageRowCount > 0) { + const tailOffset = Math.max(0, totalRows - (offset + pageRowCount)); + if (tailOffset < offset) { + sql = `${baseSql}${reverseOrderSQL} LIMIT ${pageRowCount} OFFSET ${tailOffset}`; + useClickHouseReversePagination = true; + clickHouseReverseLimit = pageRowCount; + clickHouseReverseHasMore = currentPage < totalPages; + } } } - } - if (!useClickHouseReversePagination) { - // 大表性能:打开表不阻塞在 COUNT(*),先通过多取 1 条判断是否还有下一页;总数在后台统计并异步回填。 - sql += ` LIMIT ${size + 1} OFFSET ${offset}`; + if (!useClickHouseReversePagination) { + // 大表性能:打开表不阻塞在 COUNT(*),先通过多取 1 条判断是否还有下一页;总数在后台统计并异步回填。 + sql += ` LIMIT ${size + 1} OFFSET ${offset}`; + } } const requestStartTime = Date.now(); @@ -676,6 +746,24 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => { const handleToggleFilter = useCallback(() => setShowFilter(prev => !prev), []); const handleApplyFilter = useCallback((conditions: FilterCondition[]) => setFilterConditions(conditions), []); + const exportSqlWithFilter = useMemo(() => { + const tableName = String(tab.tableName || '').trim(); + const dbType = String(currentConnConfig?.type || '').trim(); + if (!tableName || !dbType) return ''; + + const whereSQL = buildWhereSQL(dbType, filterConditions); + if (!whereSQL) return ''; + + let sql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`; + sql += buildOrderBySQL(dbType, sortInfo, pkColumns); + const normalizedType = dbType.toLowerCase(); + const hasExplicitSort = !!sortInfo?.columnKey && (sortInfo?.order === 'ascend' || sortInfo?.order === 'descend'); + if (hasExplicitSort && (normalizedType === 'mysql' || normalizedType === 'mariadb')) { + sql = withSortBufferTuningSQL(normalizedType, sql, 32 * 1024 * 1024); + } + return sql; + }, [tab.tableName, currentConnConfig?.type, filterConditions, sortInfo, pkColumns]); + useEffect(() => { fetchData(1, pagination.pageSize); }, [tab, sortInfo, filterConditions]); // Initial load and re-load on sort/filter @@ -700,8 +788,10 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => { showFilter={showFilter} onToggleFilter={handleToggleFilter} onApplyFilter={handleApplyFilter} + appliedFilterConditions={filterConditions} readOnly={forceReadOnly} sortInfoExternal={sortInfo} + exportSqlWithFilter={exportSqlWithFilter || undefined} /> ); diff --git a/frontend/src/components/QueryEditor.tsx b/frontend/src/components/QueryEditor.tsx index 2d6a36e..2e66344 100644 --- a/frontend/src/components/QueryEditor.tsx +++ b/frontend/src/components/QueryEditor.tsx @@ -1,13 +1,16 @@ import React, { useState, useEffect, useRef, useMemo } from 'react'; import Editor, { OnMount } from '@monaco-editor/react'; import { Button, message, Modal, Input, Form, Dropdown, MenuProps, Tooltip, Select, Tabs } from 'antd'; -import { PlayCircleOutlined, SaveOutlined, FormatPainterOutlined, SettingOutlined, CloseOutlined } from '@ant-design/icons'; +import { PlayCircleOutlined, SaveOutlined, FormatPainterOutlined, SettingOutlined, CloseOutlined, StopOutlined } from '@ant-design/icons'; import { format } from 'sql-formatter'; +import { v4 as uuidv4 } from 'uuid'; import { TabData, ColumnDefinition } from '../types'; import { useStore } from '../store'; -import { DBQuery, DBGetTables, DBGetAllColumns, DBGetDatabases, DBGetColumns } from '../../wailsjs/go/app/App'; +import { DBQuery, DBQueryWithCancel, DBGetTables, DBGetAllColumns, DBGetDatabases, DBGetColumns, CancelQuery, GenerateQueryID } from '../../wailsjs/go/app/App'; import DataGrid, { GONAVI_ROW_KEY } from './DataGrid'; import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities'; +import { convertMongoShellToJsonCommand } from '../utils/mongodb'; +import { getShortcutDisplay, isEditableElement, isShortcutMatch } from '../utils/shortcuts'; const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => { const [query, setQuery] = useState(tab.query || 'SELECT * FROM '); @@ -30,7 +33,9 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => { const [activeResultKey, setActiveResultKey] = useState(''); const [loading, setLoading] = useState(false); + const [currentQueryId, setCurrentQueryId] = useState(''); const runSeqRef = useRef(0); + const currentQueryIdRef = useRef(''); const [isSaveModalOpen, setIsSaveModalOpen] = useState(false); const [saveForm] = Form.useForm(); @@ -65,6 +70,8 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => { const setSqlFormatOptions = useStore(state => state.setSqlFormatOptions); const queryOptions = useStore(state => state.queryOptions); const setQueryOptions = useStore(state => state.setQueryOptions); + const shortcutOptions = useStore(state => state.shortcutOptions); + const activeTabId = useStore(state => state.activeTabId); useEffect(() => { currentConnectionIdRef.current = currentConnectionId; @@ -186,6 +193,17 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => { fetchMetadata(); }, [currentConnectionId, connections, dbList]); // dbList 变化时触发重新加载 + // Query ID management helpers + const setQueryId = (id: string) => { + currentQueryIdRef.current = id; + setCurrentQueryId(id); + }; + + const clearQueryId = () => { + currentQueryIdRef.current = ''; + setCurrentQueryId(''); + }; + // Handle Resizing const handleMouseDown = (e: React.MouseEvent) => { e.preventDefault(); @@ -254,6 +272,19 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => { return parts[parts.length - 1] || raw; }; + const splitSchemaAndTable = (qualified: string): { schema: string; table: string } => { + const raw = normalizeQualifiedName(qualified); + if (!raw) return { schema: '', table: '' }; + const parts = raw.split('.').filter(Boolean); + if (parts.length >= 2) { + return { + schema: parts[parts.length - 2] || '', + table: parts[parts.length - 1] || '', + }; + } + return { schema: '', table: parts[0] || '' }; + }; + const buildConnConfig = () => { const connId = currentConnectionIdRef.current; const conn = connectionsRef.current.find(c => c.id === connId); @@ -326,13 +357,14 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => { if (qualifierMatch) { const qualifier = stripQuotes(qualifierMatch[1]); const prefix = (qualifierMatch[2] || '').toLowerCase(); + const qualifierLower = qualifier.toLowerCase(); // 首先检查 qualifier 是否是数据库名(跨库表提示) const visibleDbs = visibleDbsRef.current; - if (visibleDbs.some(db => db.toLowerCase() === qualifier.toLowerCase())) { + if (visibleDbs.some(db => db.toLowerCase() === qualifierLower)) { // qualifier 是数据库名,提示该库的表 const tables = tablesRef.current.filter(t => - (t.dbName || '').toLowerCase() === qualifier.toLowerCase() + (t.dbName || '').toLowerCase() === qualifierLower ); const filtered = prefix ? tables.filter(t => (t.tableName || '').toLowerCase().startsWith(prefix)) @@ -349,6 +381,34 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => { return { suggestions }; } + // qualifier 是 schema(如 dbo/public)时,仅补全表名,避免输入 dbo. 后再补成 dbo.dbo.table + const schemaTables = tablesRef.current + .map(t => { + const parsed = splitSchemaAndTable(t.tableName || ''); + return { + dbName: t.dbName || '', + schema: parsed.schema, + table: parsed.table, + }; + }) + .filter(t => t.schema.toLowerCase() === qualifierLower && !!t.table); + + if (schemaTables.length > 0) { + const filtered = prefix + ? schemaTables.filter(t => t.table.toLowerCase().startsWith(prefix)) + : schemaTables; + + const suggestions = filtered.map(t => ({ + label: t.table, + kind: monaco.languages.CompletionItemKind.Class, + insertText: t.table, + detail: `Table (${t.dbName}${t.schema ? '.' + t.schema : ''})`, + range, + sortText: '0' + t.table + })); + return { suggestions }; + } + // 否则检查是否是表别名或表名,提示列 const reserved = new Set([ 'where', 'on', 'group', 'order', 'limit', 'having', @@ -517,6 +577,12 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => { icon: sqlFormatOptions.keywordCase === 'lower' ? '✓' : undefined, onClick: () => setSqlFormatOptions({ keywordCase: 'lower' }) }, + { type: 'divider' }, + { + key: 'shortcut-settings', + label: '快捷键管理...', + onClick: () => window.dispatchEvent(new CustomEvent('gonavi:open-shortcut-settings')), + }, ]; const splitSQLStatements = (sql: string): string[] => { @@ -984,6 +1050,16 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => { message.error("请先选择数据库"); return; } + // 如果已有查询在运行,先取消它 + if (currentQueryIdRef.current) { + try { + await CancelQuery(currentQueryIdRef.current); + } catch (error) { + // 忽略取消错误,可能查询已完成 + } + // 清除旧查询ID + clearQueryId(); + } const runSeq = ++runSeqRef.current; setLoading(true); const runStartTime = Date.now(); @@ -1011,7 +1087,15 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => { try { const rawSQL = getSelectedSQL() || query; - const statements = splitSQLStatements(rawSQL); + const dbType = String((config as any).type || 'mysql'); + const normalizedDbType = dbType.trim().toLowerCase(); + const normalizedRawSQL = String(rawSQL || '').replace(/;/g, ';'); + const splitInput = normalizedDbType === 'mongodb' + ? normalizedRawSQL + .replace(/^\s*\/\/.*$/gm, '') + .replace(/^\s*#.*$/gm, '') + : normalizedRawSQL; + const statements = splitSQLStatements(splitInput); if (statements.length === 0) { message.info('没有可执行的 SQL。'); setResultSets([]); @@ -1021,7 +1105,6 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => { const nextResultSets: ResultSet[] = []; const maxRows = Number(queryOptions?.maxRows) || 0; - const dbType = String((config as any).type || 'mysql'); const forceReadOnlyResult = connCaps.forceReadOnlyQueryResult; const wantsLimitProbe = Number.isFinite(maxRows) && maxRows > 0; const probeLimit = wantsLimitProbe ? (maxRows + 1) : 0; @@ -1035,9 +1118,35 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => { const limitApplied = shouldAutoLimit && wantsLimitProbe; const limited = limitApplied ? applyAutoLimit(rawStatement, dbType, probeLimit) : { sql: rawStatement, applied: false, maxRows: probeLimit }; - const executedSql = limited.sql; + let executedSql = limited.sql; + if (String(dbType || '').trim().toLowerCase() === 'mongodb') { + const shellConvert = convertMongoShellToJsonCommand(executedSql); + if (shellConvert.recognized) { + if (shellConvert.error) { + const prefix = statements.length > 1 ? `第 ${idx + 1} 条语句执行失败:` : ''; + message.error(prefix + shellConvert.error); + setResultSets([]); + setActiveResultKey(''); + return; + } + if (shellConvert.command) { + executedSql = shellConvert.command; + } + } + } const startTime = Date.now(); - const res = await DBQuery(config as any, currentDb, executedSql); + + // Generate query ID for cancellation using backend UUID with fallback + let queryId: string; + try { + queryId = await GenerateQueryID(); + } catch (error) { + console.warn('GenerateQueryID failed, using local UUID fallback:', error); + queryId = 'query-' + uuidv4(); + } + setQueryId(queryId); + + const res = await DBQueryWithCancel(config as any, currentDb, executedSql, queryId); const duration = Date.now() - startTime; addSqlLog({ @@ -1052,6 +1161,32 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => { }); if (!res.success) { + // 检查是否为查询取消错误 + const errorMsg = res.message.toLowerCase(); + const isCancelledError = errorMsg.includes('context canceled') || + errorMsg.includes('查询已取消') || + errorMsg.includes('canceled') || + errorMsg.includes('cancelled') || + errorMsg.includes('statement canceled') || + errorMsg.includes('sql: statement canceled'); + + // 确保不是超时错误 + const isTimeoutError = errorMsg.includes('context deadline exceeded') || + errorMsg.includes('timeout') || + errorMsg.includes('超时') || + errorMsg.includes('deadline exceeded'); + + if (isCancelledError && !isTimeoutError) { + // 查询已被用户取消,不显示错误消息,清理状态 + setResultSets([]); + setActiveResultKey(''); + // 清除查询ID,与handleCancel保持一致 + if (currentQueryIdRef.current) { + clearQueryId(); + } + return; + } + const prefix = statements.length > 1 ? `第 ${idx + 1} 条语句执行失败:` : ''; message.error(prefix + res.message); setResultSets([]); @@ -1157,9 +1292,75 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => { setActiveResultKey(''); } finally { if (runSeqRef.current === runSeq) setLoading(false); + // Clear query ID after execution completes + clearQueryId(); } }; + const handleCancel = async () => { + if (!currentQueryIdRef.current) { + message.warning('没有正在运行的查询可取消'); + return; + } + const queryIdToCancel = currentQueryIdRef.current; + try { + const res = await CancelQuery(queryIdToCancel); + if (res.success) { + message.success('查询已取消'); + // Clear query ID after successful cancellation + if (currentQueryIdRef.current === queryIdToCancel) { + clearQueryId() + } + } else { + message.warning(res.message); + } + } catch (error: any) { + message.error('取消查询失败: ' + error.message); + } + }; + + useEffect(() => { + const binding = shortcutOptions.runQuery; + if (!binding?.enabled || !binding.combo) { + return; + } + + const handleRunShortcut = (event: KeyboardEvent) => { + if (activeTabId !== tab.id) { + return; + } + if (!isShortcutMatch(event, binding.combo)) { + return; + } + const editorHasFocus = !!editorRef.current?.hasTextFocus?.(); + if (!editorHasFocus && !isEditableElement(event.target)) { + return; + } + event.preventDefault(); + event.stopPropagation(); + void handleRun(); + }; + + window.addEventListener('keydown', handleRunShortcut); + return () => { + window.removeEventListener('keydown', handleRunShortcut); + }; + }, [activeTabId, tab.id, shortcutOptions.runQuery, handleRun]); + + useEffect(() => { + const handleRunActiveQuery = () => { + if (activeTabId !== tab.id) { + return; + } + void handleRun(); + }; + + window.addEventListener('gonavi:run-active-query', handleRunActiveQuery as EventListener); + return () => { + window.removeEventListener('gonavi:run-active-query', handleRunActiveQuery as EventListener); + }; + }, [activeTabId, tab.id, handleRun]); + const handleSave = async () => { try { const values = await saveForm.validateFields(); @@ -1271,9 +1472,24 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => { ]} /> - + + + + + {loading && ( + + )} + + + + - {/* Toolbar for batch operations - always visible */} -
+ {/* Toolbar */} +
+ @@ -2772,7 +3330,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> size="small" icon={} onClick={() => openBatchDatabaseModal()} - style={{ flex: 1 }} + style={{ flex: '1 1 auto' }} > 批量操作库 @@ -2781,6 +3339,11 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
node.type === 'connection' || node.type === 'tag' + }} + onDrop={handleDrop} loadData={onLoadData} treeData={displayTreeData} onDoubleClick={onDoubleClick} @@ -2809,6 +3372,60 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> )} + { + createTagForm.validateFields().then(values => { + if (renameViewTarget?.type === 'tag') { + // Rename + updateConnectionTag({ + ...renameViewTarget.dataRef, + name: values.name, + connectionIds: values.connectionIds || [] + }); + // update cross-connections + const allOtherTagsIds = connectionTags.filter(t => t.id !== renameViewTarget.dataRef.id).flatMap(t => t.connectionIds); + (values.connectionIds || []).forEach((cid: string) => { + if (allOtherTagsIds.includes(cid)) { + moveConnectionToTag(cid, renameViewTarget.dataRef.id); + } + }); + } else { + // Create + const tagId = Date.now().toString(); + addConnectionTag({ + id: tagId, + name: values.name, + connectionIds: values.connectionIds || [] + }); + (values.connectionIds || []).forEach((cid: string) => { + moveConnectionToTag(cid, tagId); + }); + } + setIsCreateTagModalOpen(false); + }); + }} + onCancel={() => setIsCreateTagModalOpen(false)} + > +
+ + + + + + + {connections.map(conn => ( + + {conn.name} {conn.config.host ? `(${conn.config.host})` : ''} + + ))} + + + +
+
+ = ({ tab }) => { const darkMode = theme === 'dark'; const resizeGuideColor = darkMode ? '#f6c453' : '#1890ff'; const readOnly = !!tab.readOnly; + const panelRadius = 10; + const panelFrameColor = darkMode ? 'rgba(0, 0, 0, 0.18)' : 'rgba(0, 0, 0, 0.12)'; + const panelToolbarBorder = darkMode ? 'rgba(255, 255, 255, 0.12)' : 'rgba(0, 0, 0, 0.10)'; + const panelToolbarBg = darkMode ? 'rgba(20, 20, 20, 0.35)' : 'rgba(255, 255, 255, 0.72)'; + const panelBodyBg = darkMode ? 'rgba(0, 0, 0, 0.24)' : 'rgba(255, 255, 255, 0.82)'; + const focusRowBg = darkMode ? 'rgba(246, 196, 83, 0.22)' : 'rgba(24, 144, 255, 0.12)'; const [tableHeight, setTableHeight] = useState(500); const containerRef = useRef(null); + const pendingFocusColumnKeyRef = useRef(null); + const focusHighlightTimerRef = useRef(null); + const [focusColumnKey, setFocusColumnKey] = useState(''); const openCommentEditor = useCallback((record: EditableColumn) => { if (!record?._key) return; @@ -346,6 +355,61 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => { setSelectedColumnRowKeys(prev => prev.filter(key => columns.some(c => c._key === key))); }, [columns]); + useEffect(() => { + return () => { + if (focusHighlightTimerRef.current !== null) { + window.clearTimeout(focusHighlightTimerRef.current); + } + }; + }, []); + + const focusColumnRow = useCallback((targetKey: string): boolean => { + if (activeKey !== 'columns') return false; + const tableBody = containerRef.current?.querySelector('.ant-table-body') as HTMLElement | null; + if (!tableBody) return false; + const row = tableBody.querySelector(`tr[data-row-key="${targetKey}"]`) as HTMLTableRowElement | null; + if (!row) return false; + + row.scrollIntoView({ behavior: 'smooth', block: 'nearest' }); + setFocusColumnKey(targetKey); + if (focusHighlightTimerRef.current !== null) { + window.clearTimeout(focusHighlightTimerRef.current); + } + focusHighlightTimerRef.current = window.setTimeout(() => { + setFocusColumnKey(prev => (prev === targetKey ? '' : prev)); + }, 1600); + + if (!readOnly) { + const firstInput = row.querySelector('input') as HTMLInputElement | null; + if (firstInput) { + firstInput.focus(); + firstInput.select(); + } + } + return true; + }, [activeKey, readOnly]); + + useEffect(() => { + const pendingKey = pendingFocusColumnKeyRef.current; + if (!pendingKey || activeKey !== 'columns') return; + + let cancelled = false; + const tryFocus = () => { + if (cancelled) return; + if (focusColumnRow(pendingKey)) { + pendingFocusColumnKeyRef.current = null; + } + }; + + const timerA = window.setTimeout(tryFocus, 0); + const timerB = window.setTimeout(tryFocus, 96); + return () => { + cancelled = true; + window.clearTimeout(timerA); + window.clearTimeout(timerB); + }; + }, [activeKey, columns, focusColumnRow]); + // Initial Columns Definition useEffect(() => { const initialCols = [ @@ -886,21 +950,46 @@ ${selectedTrigger.statement}`; })); }; - const handleAddColumn = () => { - const newCol: EditableColumn = { - name: isNewTable ? 'new_column' : `new_col_${columns.length + 1}`, - type: 'varchar(255)', - nullable: 'YES', - key: '', - extra: '', - comment: '', - default: '', - _key: `new-${Date.now()}`, - isNew: true, - isAutoIncrement: false - }; - setColumns([...columns, newCol]); - }; + const createNewColumn = useCallback((indexHint: number): EditableColumn => ({ + name: isNewTable ? 'new_column' : `new_col_${indexHint}`, + type: 'varchar(255)', + nullable: 'YES', + key: '', + extra: '', + comment: '', + default: '', + _key: `new-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`, + isNew: true, + isAutoIncrement: false + }), [isNewTable]); + + const handleAddColumn = useCallback((insertAfterKey?: string) => { + const newCol = createNewColumn(columns.length + 1); + setColumns(prev => { + const next = [...prev]; + if (insertAfterKey) { + const insertIndex = next.findIndex(col => col._key === insertAfterKey); + if (insertIndex >= 0) { + next.splice(insertIndex + 1, 0, newCol); + return next; + } + } + next.push(newCol); + return next; + }); + setSelectedColumnRowKeys([newCol._key]); + pendingFocusColumnKeyRef.current = newCol._key; + }, [columns.length, createNewColumn]); + + const handleAddColumnAfterSelected = useCallback(() => { + const selectedSet = new Set(selectedColumnRowKeys); + const anchor = columns.find(col => selectedSet.has(col._key)); + if (!anchor) { + message.warning('请先选择一个字段,再执行插入。'); + return; + } + handleAddColumn(anchor._key); + }, [columns, handleAddColumn, selectedColumnRowKeys]); const handleDeleteColumn = (key: string) => { setColumns(prev => prev.filter(c => c._key !== key)); @@ -1920,22 +2009,35 @@ END;`; })); const columnsTabContent = ( -
+
{readOnly ? ( record._key === focusColumnKey ? 'table-designer-focus-row' : ''} size="small" pagination={false} loading={loading} scroll={{ y: tableHeight }} - bordered + bordered={false} components={{ header: { cell: ResizableTitle, @@ -1953,11 +2055,12 @@ END;`; onChange: (nextSelectedRowKeys) => setSelectedColumnRowKeys(nextSelectedRowKeys as string[]), }} rowKey="_key" + rowClassName={(record: EditableColumn) => record._key === focusColumnKey ? 'table-designer-focus-row' : ''} size="small" pagination={false} loading={loading} scroll={{ y: tableHeight }} - bordered + bordered={false} components={{ body: { row: SortableRow }, header: { cell: ResizableTitle } @@ -1985,8 +2088,63 @@ END;`; ); return ( -
-
+
+ +
{isNewTable && ( <> )} - {!readOnly && } - {!isNewTable && } + {!readOnly && } + {!isNewTable && } {!isNewTable && !readOnly && supportsTableCommentOps() && ( - + )} - {!readOnly && } + {!readOnly && } {!readOnly && ( + )} + {!readOnly && ( +
+ `); err != nil { + return err + } + + for _, col := range columns { + if _, err := fmt.Fprintf(w, "", html.EscapeString(col)); err != nil { + return err + } + } + + if _, err := w.WriteString(``); err != nil { + return err + } + + if len(data) == 0 { + colspan := len(columns) + if colspan <= 0 { + colspan = 1 + } + if _, err := fmt.Fprintf(w, ``, colspan); err != nil { + return err + } + } else { + for _, rowMap := range data { + if _, err := w.WriteString(""); err != nil { + return err + } + for _, col := range columns { + if _, err := fmt.Fprintf(w, "", formatExportHTMLCell(rowMap[col])); err != nil { + return err + } + } + if _, err := w.WriteString(""); err != nil { + return err + } + } + } + + if _, err := w.WriteString(`
%s
(0 rows)
%s
+
+
+ +`); err != nil { + return err + } + + return w.Flush() +} + func formatExportCellText(val interface{}) string { if val == nil { return "NULL" diff --git a/internal/app/methods_file_export_test.go b/internal/app/methods_file_export_test.go index 6a0b1b4..5ddaf9c 100644 --- a/internal/app/methods_file_export_test.go +++ b/internal/app/methods_file_export_test.go @@ -203,3 +203,73 @@ func TestGetExportQueryTimeout_CustomClickHouseUsesLongerMinimum(t *testing.T) { t.Fatalf("custom clickhouse 导出超时下限异常,want=%s got=%s", minClickHouseExportQueryTimeout, timeout) } } + +func TestWriteRowsToFile_HTML_EscapeAndStyle(t *testing.T) { + f, err := os.CreateTemp("", "gonavi-export-*.html") + if err != nil { + t.Fatalf("创建临时文件失败: %v", err) + } + defer os.Remove(f.Name()) + defer f.Close() + + data := []map[string]interface{}{ + { + "name": "", + "note": "line1\nline2", + "nullable": nil, + }, + } + columns := []string{"name", "note", "nullable"} + + if err := writeRowsToFile(f, data, columns, "html"); err != nil { + t.Fatalf("写入 html 失败: %v", err) + } + + contentBytes, err := os.ReadFile(f.Name()) + if err != nil { + t.Fatalf("读取 html 失败: %v", err) + } + content := string(contentBytes) + + if !strings.Contains(content, "") { + t.Fatalf("html 导出缺少 doctype: %s", content) + } + if !strings.Contains(content, "position: sticky") { + t.Fatalf("html 导出缺少表头吸顶样式: %s", content) + } + if !strings.Contains(content, "tbody tr:nth-child(even)") { + t.Fatalf("html 导出缺少斑马纹样式: %s", content) + } + if !strings.Contains(content, "<script>alert(1)</script>") { + t.Fatalf("html 导出未进行 XSS 转义: %s", content) + } + if strings.Contains(content, "") { + t.Fatalf("html 导出包含未转义脚本: %s", content) + } + if !strings.Contains(content, "line1
line2") { + t.Fatalf("html 导出换行未转为
: %s", content) + } + if !strings.Contains(content, "NULL") { + t.Fatalf("html 导出空值显示异常: %s", content) + } +} + +func TestWriteRowsToFile_HTML_EscapeHeader(t *testing.T) { + f, err := os.CreateTemp("", "gonavi-export-*.html") + if err != nil { + t.Fatalf("创建临时文件失败: %v", err) + } + defer os.Remove(f.Name()) + defer f.Close() + + columnName := "name" + data := []map[string]interface{}{{columnName: "ok"}} + if err := writeRowsToFile(f, data, []string{columnName}, "html"); err != nil { + t.Fatalf("写入 html 失败: %v", err) + } + contentBytes, _ := os.ReadFile(f.Name()) + content := string(contentBytes) + if !strings.Contains(content, "<b>name</b>") || strings.Contains(content, "name") { + t.Fatalf("html 表头未正确转义: %s", content) + } +} diff --git a/internal/app/methods_update.go b/internal/app/methods_update.go index bd98ace..ae2bdfe 100644 --- a/internal/app/methods_update.go +++ b/internal/app/methods_update.go @@ -51,12 +51,13 @@ type UpdateInfo struct { } type AppInfo struct { - Version string `json:"version"` - Author string `json:"author"` - RepoURL string `json:"repoUrl,omitempty"` - IssueURL string `json:"issueUrl,omitempty"` - ReleaseURL string `json:"releaseUrl,omitempty"` - BuildTime string `json:"buildTime,omitempty"` + Version string `json:"version"` + Author string `json:"author"` + RepoURL string `json:"repoUrl,omitempty"` + IssueURL string `json:"issueUrl,omitempty"` + ReleaseURL string `json:"releaseUrl,omitempty"` + CommunityURL string `json:"communityUrl,omitempty"` + BuildTime string `json:"buildTime,omitempty"` } type updateDownloadResult struct { @@ -137,12 +138,13 @@ func (a *App) CheckForUpdates() connection.QueryResult { func (a *App) GetAppInfo() connection.QueryResult { info := AppInfo{ - Version: getCurrentVersion(), - Author: getCurrentAuthor(), - RepoURL: "https://github.com/" + updateRepo, - IssueURL: "https://github.com/" + updateRepo + "/issues", - ReleaseURL: "https://github.com/" + updateRepo + "/releases", - BuildTime: strings.TrimSpace(AppBuildTime), + Version: getCurrentVersion(), + Author: getCurrentAuthor(), + RepoURL: "https://github.com/" + updateRepo, + IssueURL: "https://github.com/" + updateRepo + "/issues", + ReleaseURL: "https://github.com/" + updateRepo + "/releases", + CommunityURL: "https://aibook.ren", + BuildTime: strings.TrimSpace(AppBuildTime), } return connection.QueryResult{Success: true, Message: "OK", Data: info} } diff --git a/internal/connection/types.go b/internal/connection/types.go index 20b4cbb..bc88873 100644 --- a/internal/connection/types.go +++ b/internal/connection/types.go @@ -27,6 +27,10 @@ type ConnectionConfig struct { Password string `json:"password"` SavePassword bool `json:"savePassword,omitempty"` // Persist password in saved connection Database string `json:"database"` + UseSSL bool `json:"useSSL,omitempty"` // MySQL-like SSL/TLS switch + SSLMode string `json:"sslMode,omitempty"` // preferred | required | skip-verify | disable + SSLCertPath string `json:"sslCertPath,omitempty"` // TLS client certificate path (e.g., Dameng) + SSLKeyPath string `json:"sslKeyPath,omitempty"` // TLS client private key path (e.g., Dameng) UseSSH bool `json:"useSSH"` SSH SSHConfig `json:"ssh"` UseProxy bool `json:"useProxy,omitempty"` @@ -55,6 +59,7 @@ type QueryResult struct { Message string `json:"message"` Data interface{} `json:"data"` Fields []string `json:"fields,omitempty"` + QueryID string `json:"queryId,omitempty"` // Unique ID for query cancellation } // ColumnDefinition represents a table column diff --git a/internal/db/clickhouse_impl.go b/internal/db/clickhouse_impl.go index 8061041..dcf18e6 100644 --- a/internal/db/clickhouse_impl.go +++ b/internal/db/clickhouse_impl.go @@ -107,7 +107,7 @@ func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig if readTimeout < minClickHouseReadTimeout { readTimeout = minClickHouseReadTimeout } - return &clickhouse.Options{ + opts := &clickhouse.Options{ Addr: []string{ net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), }, @@ -119,6 +119,10 @@ func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig DialTimeout: connectTimeout, ReadTimeout: readTimeout, } + if tlsConfig := resolveGenericTLSConfig(config); tlsConfig != nil { + opts.TLS = tlsConfig + } + return opts } func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error { @@ -165,13 +169,30 @@ func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error { logger.Infof("ClickHouse 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port) } - c.conn = clickhouse.OpenDB(c.buildClickHouseOptions(runConfig)) - - if err := c.Ping(); err != nil { - _ = c.Close() - return fmt.Errorf("连接建立后验证失败:%w", err) + attempts := []connection.ConnectionConfig{runConfig} + if shouldTrySSLPreferredFallback(runConfig) { + attempts = append(attempts, withSSLDisabled(runConfig)) } - return nil + + var failures []string + for idx, attempt := range attempts { + c.conn = clickhouse.OpenDB(c.buildClickHouseOptions(attempt)) + if err := c.Ping(); err != nil { + failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err)) + if c.conn != nil { + _ = c.conn.Close() + c.conn = nil + } + continue + } + if idx > 0 { + logger.Warnf("ClickHouse SSL 优先连接失败,已回退至明文连接") + } + return nil + } + + _ = c.Close() + return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";")) } func (c *ClickHouseDB) Close() error { diff --git a/internal/db/dameng_impl.go b/internal/db/dameng_impl.go index 6aa06b0..5080540 100644 --- a/internal/db/dameng_impl.go +++ b/internal/db/dameng_impl.go @@ -36,6 +36,14 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string { if config.Database != "" { q.Set("schema", config.Database) } + if config.UseSSL { + if certPath := strings.TrimSpace(config.SSLCertPath); certPath != "" { + q.Set("SSL_CERT_PATH", certPath) + } + if keyPath := strings.TrimSpace(config.SSLKeyPath); keyPath != "" { + q.Set("SSL_KEY_PATH", keyPath) + } + } if escapedPassword != config.Password { // 达梦驱动要求:密码包含特殊字符时,password 需 PathEscape,并添加 escapeProcess=true 让驱动解码。 q.Set("escapeProcess", "true") @@ -50,8 +58,12 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string { } func (d *DamengDB) Connect(config connection.ConnectionConfig) error { - var dsn string - var err error + runConfig := config + if runConfig.UseSSL { + if strings.TrimSpace(runConfig.SSLCertPath) == "" || strings.TrimSpace(runConfig.SSLKeyPath) == "" { + return fmt.Errorf("达梦启用 SSL 需要同时配置证书路径(sslCertPath)与私钥路径(sslKeyPath)") + } + } if config.UseSSH { // Create SSH tunnel with local port forwarding @@ -80,22 +92,37 @@ func (d *DamengDB) Connect(config connection.ConnectionConfig) error { localConfig.Port = port localConfig.UseSSH = false - dsn = d.getDSN(localConfig) + runConfig = localConfig logger.Infof("达梦数据库通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port) - } else { - dsn = d.getDSN(config) } - db, err := sql.Open("dm", dsn) - if err != nil { - return fmt.Errorf("打开数据库连接失败:%w", err) + attempts := []connection.ConnectionConfig{runConfig} + if shouldTrySSLPreferredFallback(runConfig) { + attempts = append(attempts, withSSLDisabled(runConfig)) } - d.conn = db - d.pingTimeout = getConnectTimeout(config) - if err := d.Ping(); err != nil { - return fmt.Errorf("连接建立后验证失败:%w", err) + + var failures []string + for idx, attempt := range attempts { + dsn := d.getDSN(attempt) + db, err := sql.Open("dm", dsn) + if err != nil { + failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err)) + continue + } + d.conn = db + d.pingTimeout = getConnectTimeout(attempt) + if err := d.Ping(); err != nil { + _ = db.Close() + d.conn = nil + failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err)) + continue + } + if idx > 0 { + logger.Warnf("达梦 SSL 优先连接失败,已回退至明文连接") + } + return nil } - return nil + return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";")) } func (d *DamengDB) Close() error { diff --git a/internal/db/diros_impl.go b/internal/db/diros_impl.go index 38ac270..07bed73 100644 --- a/internal/db/diros_impl.go +++ b/internal/db/diros_impl.go @@ -151,9 +151,10 @@ func (d *DirosDB) getDSN(config connection.ConnectionConfig) string { } timeout := getConnectTimeoutSeconds(config) + tlsMode := resolveMySQLTLSMode(config) - return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds", - config.User, config.Password, protocol, address, database, timeout) + return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s", + config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode)) } func resolveDirosCredential(config connection.ConnectionConfig, addressIndex int) (string, string) { diff --git a/internal/db/dsn_test.go b/internal/db/dsn_test.go index 4dadf3a..6bed9fe 100644 --- a/internal/db/dsn_test.go +++ b/internal/db/dsn_test.go @@ -33,6 +33,44 @@ func TestPostgresDSN_EscapesPassword(t *testing.T) { } } +func TestPostgresDSN_SSLModeRequireWhenEnabled(t *testing.T) { + p := &PostgresDB{} + cfg := connection.ConnectionConfig{ + Type: "postgres", + Host: "127.0.0.1", + Port: 5432, + User: "user", + Password: "pass", + Database: "db", + UseSSL: true, + SSLMode: "required", + } + + dsn := p.getDSN(cfg) + if !strings.Contains(dsn, "sslmode=require") { + t.Fatalf("dsn 缺少 sslmode=require 参数:%s", dsn) + } +} + +func TestMySQLDSN_UsesTLSParamWhenSSLEnabled(t *testing.T) { + m := &MySQLDB{} + cfg := connection.ConnectionConfig{ + Type: "mysql", + Host: "127.0.0.1", + Port: 3306, + User: "root", + Password: "pass", + Database: "db", + UseSSL: true, + SSLMode: "required", + } + + dsn := m.getDSN(cfg) + if !strings.Contains(dsn, "tls=true") { + t.Fatalf("dsn 缺少 tls=true 参数:%s", dsn) + } +} + func TestOracleDSN_EscapesUserAndPassword(t *testing.T) { o := &OracleDB{} cfg := connection.ConnectionConfig{ @@ -82,6 +120,30 @@ func TestDamengDSN_EscapesPasswordAndEnablesEscapeProcess(t *testing.T) { } } +func TestDamengDSN_AppendsSSLCertAndKeyParams(t *testing.T) { + d := &DamengDB{} + cfg := connection.ConnectionConfig{ + Type: "dameng", + Host: "127.0.0.1", + Port: 5236, + User: "SYSDBA", + Password: "pass", + Database: "DBName", + UseSSL: true, + SSLMode: "required", + SSLCertPath: "C:\\certs\\client-cert.pem", + SSLKeyPath: "C:\\certs\\client-key.pem", + } + + dsn := d.getDSN(cfg) + if !strings.Contains(dsn, "SSL_CERT_PATH=") { + t.Fatalf("dsn 缺少 SSL_CERT_PATH 参数:%s", dsn) + } + if !strings.Contains(dsn, "SSL_KEY_PATH=") { + t.Fatalf("dsn 缺少 SSL_KEY_PATH 参数:%s", dsn) + } +} + func TestKingbaseDSN_QuotesPasswordWithSpaces(t *testing.T) { k := &KingbaseDB{} cfg := connection.ConnectionConfig{ @@ -116,6 +178,47 @@ func TestTDengineDSN_UsesWebSocketFormat(t *testing.T) { } } +func TestTDengineDSN_UsesSecureWebSocketWhenSSLEnabled(t *testing.T) { + td := &TDengineDB{} + cfg := connection.ConnectionConfig{ + Type: "tdengine", + Host: "127.0.0.1", + Port: 6041, + User: "root", + Password: "taosdata", + Database: "power", + UseSSL: true, + SSLMode: "required", + } + + dsn := td.getDSN(cfg) + if !strings.HasPrefix(dsn, "root:taosdata@wss(127.0.0.1:6041)/power") { + t.Fatalf("tdengine ssl dsn 格式不正确:%s", dsn) + } +} + +func TestSQLServerDSN_EncryptMapping(t *testing.T) { + s := &SqlServerDB{} + cfg := connection.ConnectionConfig{ + Type: "sqlserver", + Host: "127.0.0.1", + Port: 1433, + User: "sa", + Password: "pass", + Database: "master", + UseSSL: true, + SSLMode: "required", + } + + dsn := s.getDSN(cfg) + if !strings.Contains(strings.ToLower(dsn), "encrypt=true") { + t.Fatalf("sqlserver dsn 缺少 encrypt=true:%s", dsn) + } + if !strings.Contains(strings.ToLower(dsn), "trustservercertificate=false") { + t.Fatalf("sqlserver dsn 缺少 TrustServerCertificate=false:%s", dsn) + } +} + func TestClickHouseOptions_UsesStructuredTimeoutAndAuth(t *testing.T) { c := &ClickHouseDB{} cfg := normalizeClickHouseConfig(connection.ConnectionConfig{ diff --git a/internal/db/highgo_impl.go b/internal/db/highgo_impl.go index 8e982a6..0343565 100644 --- a/internal/db/highgo_impl.go +++ b/internal/db/highgo_impl.go @@ -42,7 +42,7 @@ func (h *HighGoDB) getDSN(config connection.ConnectionConfig) string { } u.User = url.UserPassword(config.User, config.Password) q := url.Values{} - q.Set("sslmode", "disable") + q.Set("sslmode", resolvePostgresSSLMode(config)) q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config))) u.RawQuery = q.Encode() @@ -50,7 +50,7 @@ func (h *HighGoDB) getDSN(config connection.ConnectionConfig) string { } func (h *HighGoDB) Connect(config connection.ConnectionConfig) error { - var dsn string + runConfig := config if config.UseSSH { logger.Infof("HighGo 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User) @@ -76,23 +76,37 @@ func (h *HighGoDB) Connect(config connection.ConnectionConfig) error { localConfig.Port = port localConfig.UseSSH = false - dsn = h.getDSN(localConfig) + runConfig = localConfig logger.Infof("HighGo 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port) - } else { - dsn = h.getDSN(config) } - db, err := sql.Open("highgo", dsn) - if err != nil { - return fmt.Errorf("打开数据库连接失败:%w", err) + attempts := []connection.ConnectionConfig{runConfig} + if shouldTrySSLPreferredFallback(runConfig) { + attempts = append(attempts, withSSLDisabled(runConfig)) } - h.conn = db - h.pingTimeout = getConnectTimeout(config) - if err := h.Ping(); err != nil { - return fmt.Errorf("连接建立后验证失败:%w", err) + var failures []string + for idx, attempt := range attempts { + dsn := h.getDSN(attempt) + db, err := sql.Open("highgo", dsn) + if err != nil { + failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err)) + continue + } + h.conn = db + h.pingTimeout = getConnectTimeout(attempt) + if err := h.Ping(); err != nil { + _ = db.Close() + h.conn = nil + failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err)) + continue + } + if idx > 0 { + logger.Warnf("HighGo SSL 优先连接失败,已回退至明文连接") + } + return nil } - return nil + return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";")) } func (h *HighGoDB) Close() error { diff --git a/internal/db/kingbase_impl.go b/internal/db/kingbase_impl.go index 2726665..f1357a8 100644 --- a/internal/db/kingbase_impl.go +++ b/internal/db/kingbase_impl.go @@ -65,12 +65,13 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string { port := config.Port // Construct DSN - dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable connect_timeout=%d", + dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s connect_timeout=%d", quoteConnValue(address), port, quoteConnValue(config.User), quoteConnValue(config.Password), quoteConnValue(config.Database), + quoteConnValue(resolvePostgresSSLMode(config)), getConnectTimeoutSeconds(config), ) @@ -78,8 +79,7 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string { } func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error { - var dsn string - var err error + runConfig := config if config.UseSSH { // Create SSH tunnel with local port forwarding @@ -108,23 +108,37 @@ func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error { localConfig.Port = port localConfig.UseSSH = false - dsn = k.getDSN(localConfig) + runConfig = localConfig logger.Infof("人大金仓通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port) - } else { - dsn = k.getDSN(config) } - // Open using "kingbase" driver - db, err := sql.Open("kingbase", dsn) - if err != nil { - return fmt.Errorf("打开数据库连接失败:%w", err) + attempts := []connection.ConnectionConfig{runConfig} + if shouldTrySSLPreferredFallback(runConfig) { + attempts = append(attempts, withSSLDisabled(runConfig)) } - k.conn = db - k.pingTimeout = getConnectTimeout(config) - if err := k.Ping(); err != nil { - return fmt.Errorf("连接建立后验证失败:%w", err) + + var failures []string + for idx, attempt := range attempts { + dsn := k.getDSN(attempt) + db, err := sql.Open("kingbase", dsn) + if err != nil { + failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err)) + continue + } + k.conn = db + k.pingTimeout = getConnectTimeout(attempt) + if err := k.Ping(); err != nil { + _ = db.Close() + k.conn = nil + failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err)) + continue + } + if idx > 0 { + logger.Warnf("人大金仓 SSL 优先连接失败,已回退至明文连接") + } + return nil } - return nil + return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";")) } func (k *KingbaseDB) Close() error { diff --git a/internal/db/mariadb_impl.go b/internal/db/mariadb_impl.go index 4d8457b..1e316ad 100644 --- a/internal/db/mariadb_impl.go +++ b/internal/db/mariadb_impl.go @@ -6,6 +6,7 @@ import ( "context" "database/sql" "fmt" + "net/url" "strings" "time" @@ -40,9 +41,10 @@ func (m *MariaDB) getDSN(config connection.ConnectionConfig) string { } timeout := getConnectTimeoutSeconds(config) + tlsMode := resolveMySQLTLSMode(config) - return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds", - config.User, config.Password, protocol, address, database, timeout) + return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s", + config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode)) } func (m *MariaDB) Connect(config connection.ConnectionConfig) error { diff --git a/internal/db/mongodb_impl.go b/internal/db/mongodb_impl.go index 4b2f630..5c853f6 100644 --- a/internal/db/mongodb_impl.go +++ b/internal/db/mongodb_impl.go @@ -4,6 +4,7 @@ package db import ( "context" + "crypto/tls" "fmt" "net" "net/url" @@ -327,35 +328,60 @@ func (m *MongoDB) Connect(config connection.ConnectionConfig) error { m.database = "admin" } - attemptConfigs := buildMongoAuthAttempts(connectConfig) + sslAttempts := []connection.ConnectionConfig{connectConfig} + if shouldTrySSLPreferredFallback(connectConfig) { + sslAttempts = append(sslAttempts, withSSLDisabled(connectConfig)) + } + var errorDetails []string - for index, attemptConfig := range attemptConfigs { - authLabel := "主库凭据" - if index > 0 { - authLabel = "从库凭据" + for sslIndex, sslConfig := range sslAttempts { + sslLabel := "SSL" + if sslIndex > 0 { + sslLabel = "明文回退" } - uri := m.getURI(attemptConfig) - clientOpts := options.Client().ApplyURI(uri) - if attemptConfig.UseProxy { - clientOpts.SetDialer(&mongoProxyDialer{proxyConfig: attemptConfig.Proxy}) - } - client, err := mongo.Connect(clientOpts) - if err != nil { - errorDetails = append(errorDetails, fmt.Sprintf("%s连接失败: %v", authLabel, err)) - continue - } + attemptConfigs := buildMongoAuthAttempts(sslConfig) + for index, attemptConfig := range attemptConfigs { + authLabel := "主库凭据" + if index > 0 { + authLabel = "从库凭据" + } - m.client = client - if err := m.Ping(); err != nil { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - _ = client.Disconnect(ctx) - cancel() - m.client = nil - errorDetails = append(errorDetails, fmt.Sprintf("%s验证失败: %v", authLabel, err)) - continue + if sslIndex > 0 { + attemptConfig.URI = "" + } + uri := m.getURI(attemptConfig) + clientOpts := options.Client().ApplyURI(uri) + tlsEnabled, tlsInsecure := resolveMongoTLSSettings(attemptConfig) + if tlsEnabled { + clientOpts.SetTLSConfig(&tls.Config{ + MinVersion: tls.VersionTLS12, + InsecureSkipVerify: tlsInsecure, + }) + } + if attemptConfig.UseProxy { + clientOpts.SetDialer(&mongoProxyDialer{proxyConfig: attemptConfig.Proxy}) + } + client, err := mongo.Connect(clientOpts) + if err != nil { + errorDetails = append(errorDetails, fmt.Sprintf("%s %s连接失败: %v", sslLabel, authLabel, err)) + continue + } + + m.client = client + if err := m.Ping(); err != nil { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + _ = client.Disconnect(ctx) + cancel() + m.client = nil + errorDetails = append(errorDetails, fmt.Sprintf("%s %s验证失败: %v", sslLabel, authLabel, err)) + continue + } + if sslIndex > 0 { + logger.Warnf("MongoDB SSL 优先连接失败,已回退至明文连接") + } + return nil } - return nil } if len(errorDetails) > 0 { diff --git a/internal/db/mongodb_impl_v1.go b/internal/db/mongodb_impl_v1.go new file mode 100644 index 0000000..26e110a --- /dev/null +++ b/internal/db/mongodb_impl_v1.go @@ -0,0 +1,1187 @@ +//go:build gonavi_mongodb_driver_v1 + +package db + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/url" + "sort" + "strconv" + "strings" + "time" + + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/logger" + proxytunnel "GoNavi-Wails/internal/proxy" + "GoNavi-Wails/internal/ssh" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/mongo/readpref" +) + +type MongoDBV1 struct { + client *mongo.Client + database string + pingTimeout time.Duration + forwarder *ssh.LocalForwarder +} + +type mongoProxyDialer struct { + proxyConfig connection.ProxyConfig +} + +func (d *mongoProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return proxytunnel.DialContext(ctx, d.proxyConfig, network, address) +} + +const defaultMongoPort = 27017 + +func normalizeMongoAddress(host string, port int) string { + h := strings.TrimSpace(host) + if h == "" { + h = "localhost" + } + p := port + if p <= 0 { + p = defaultMongoPort + } + return fmt.Sprintf("%s:%d", h, p) +} + +func normalizeMongoSeed(raw string, defaultPort int, useSRV bool) (string, bool) { + host, port, ok := parseHostPortWithDefault(raw, defaultPort) + if !ok { + return "", false + } + + if useSRV { + normalized := strings.TrimSpace(host) + if normalized == "" { + return "", false + } + return normalized, true + } + + return normalizeMongoAddress(host, port), true +} + +func collectMongoSeeds(config connection.ConnectionConfig) []string { + defaultPort := config.Port + if defaultPort <= 0 { + defaultPort = defaultMongoPort + } + useSRV := config.MongoSRV + + candidates := make([]string, 0, len(config.Hosts)+1) + if len(config.Hosts) > 0 { + candidates = append(candidates, config.Hosts...) + } else { + if useSRV { + candidates = append(candidates, strings.TrimSpace(config.Host)) + } else { + candidates = append(candidates, normalizeMongoAddress(config.Host, defaultPort)) + } + } + + result := make([]string, 0, len(candidates)) + seen := make(map[string]struct{}, len(candidates)) + for _, entry := range candidates { + normalized, ok := normalizeMongoSeed(entry, defaultPort, useSRV) + if !ok { + continue + } + if _, exists := seen[normalized]; exists { + continue + } + seen[normalized] = struct{}{} + result = append(result, normalized) + } + + return result +} + +func applyMongoURI(config connection.ConnectionConfig) connection.ConnectionConfig { + uriText := strings.TrimSpace(config.URI) + if uriText == "" { + return config + } + lowerURI := strings.ToLower(uriText) + if strings.HasPrefix(lowerURI, "mongodb+srv://") { + config.MongoSRV = true + } + if !strings.HasPrefix(lowerURI, "mongodb://") && !strings.HasPrefix(lowerURI, "mongodb+srv://") { + return config + } + + parsed, err := url.Parse(uriText) + if err != nil { + return config + } + + if parsed.User != nil { + if config.User == "" { + config.User = parsed.User.Username() + } + if pass, ok := parsed.User.Password(); ok && config.Password == "" { + config.Password = pass + } + } + + if dbName := strings.TrimPrefix(parsed.Path, "/"); dbName != "" && config.Database == "" { + config.Database = dbName + } + + defaultPort := config.Port + if defaultPort <= 0 { + defaultPort = defaultMongoPort + } + hostsFromURI := make([]string, 0, 4) + hostText := strings.TrimSpace(parsed.Host) + if hostText != "" { + for _, entry := range strings.Split(hostText, ",") { + normalized, ok := normalizeMongoSeed(entry, defaultPort, config.MongoSRV) + if ok { + hostsFromURI = append(hostsFromURI, normalized) + } + } + } + + if len(config.Hosts) == 0 && len(hostsFromURI) > 0 { + config.Hosts = hostsFromURI + } + if strings.TrimSpace(config.Host) == "" && len(hostsFromURI) > 0 { + host, port, ok := parseHostPortWithDefault(hostsFromURI[0], defaultPort) + if ok { + config.Host = host + config.Port = port + } + } + + query := parsed.Query() + if config.AuthSource == "" { + config.AuthSource = strings.TrimSpace(query.Get("authSource")) + } + if config.ReadPreference == "" { + config.ReadPreference = strings.TrimSpace(query.Get("readPreference")) + } + if config.ReplicaSet == "" { + config.ReplicaSet = strings.TrimSpace(query.Get("replicaSet")) + } + if config.MongoAuthMechanism == "" { + config.MongoAuthMechanism = strings.TrimSpace(query.Get("authMechanism")) + } + if config.Topology == "" { + if len(config.Hosts) > 1 || strings.TrimSpace(config.ReplicaSet) != "" { + config.Topology = "replica" + } else { + config.Topology = "single" + } + } + + return config +} + +func (m *MongoDBV1) getURI(config connection.ConnectionConfig) string { + if strings.TrimSpace(config.URI) != "" { + return strings.TrimSpace(config.URI) + } + + seeds := collectMongoSeeds(config) + if len(seeds) == 0 { + if config.MongoSRV { + seed := strings.TrimSpace(config.Host) + if seed == "" { + seed = "localhost" + } + seeds = append(seeds, seed) + } else { + seeds = append(seeds, normalizeMongoAddress(config.Host, config.Port)) + } + } + + scheme := "mongodb" + if config.MongoSRV { + scheme = "mongodb+srv" + } + hostText := strings.Join(seeds, ",") + uri := fmt.Sprintf("%s://%s", scheme, hostText) + + if config.User != "" { + var userinfo *url.Userinfo + if config.Password != "" { + userinfo = url.UserPassword(config.User, config.Password) + } else { + userinfo = url.User(config.User) + } + uri = fmt.Sprintf("%s://%s@%s", scheme, userinfo.String(), hostText) + } + + path := "/" + if strings.TrimSpace(config.Database) != "" { + path = "/" + url.PathEscape(strings.TrimSpace(config.Database)) + } + uri += path + + params := url.Values{} + timeout := getConnectTimeoutSeconds(config) + params.Set("connectTimeoutMS", strconv.Itoa(timeout*1000)) + params.Set("serverSelectionTimeoutMS", strconv.Itoa(timeout*1000)) + + authSource := strings.TrimSpace(config.AuthSource) + if authSource == "" && strings.TrimSpace(config.Database) != "" { + authSource = strings.TrimSpace(config.Database) + } + if authSource == "" { + authSource = "admin" + } + params.Set("authSource", authSource) + + if replicaSet := strings.TrimSpace(config.ReplicaSet); replicaSet != "" { + params.Set("replicaSet", replicaSet) + } + if readPreference := strings.TrimSpace(config.ReadPreference); readPreference != "" { + params.Set("readPreference", readPreference) + } + if authMechanism := strings.TrimSpace(config.MongoAuthMechanism); authMechanism != "" { + params.Set("authMechanism", authMechanism) + } + + if encoded := params.Encode(); encoded != "" { + uri += "?" + encoded + } + + return uri +} + +func buildMongoAuthAttempts(config connection.ConnectionConfig) []connection.ConnectionConfig { + attempts := []connection.ConnectionConfig{config} + replicaUser := strings.TrimSpace(config.MongoReplicaUser) + if replicaUser == "" { + return attempts + } + if replicaUser == strings.TrimSpace(config.User) && config.MongoReplicaPassword == config.Password { + return attempts + } + + replicaConfig := config + replicaConfig.URI = "" + replicaConfig.User = replicaUser + replicaConfig.Password = config.MongoReplicaPassword + attempts = append(attempts, replicaConfig) + return attempts +} + +func (m *MongoDBV1) Connect(config connection.ConnectionConfig) error { + runConfig := applyMongoURI(config) + connectConfig := runConfig + + if runConfig.UseSSH && runConfig.MongoSRV { + return fmt.Errorf("MongoDB SRV 记录模式暂不支持 SSH 隧道") + } + + if runConfig.UseSSH { + seeds := collectMongoSeeds(runConfig) + if len(seeds) == 0 { + seeds = append(seeds, normalizeMongoAddress(runConfig.Host, runConfig.Port)) + } + targetHost, targetPort, ok := parseHostPortWithDefault(seeds[0], defaultMongoPort) + if !ok { + return fmt.Errorf("MongoDB 连接失败:无效地址 %s", seeds[0]) + } + + logger.Infof("MongoDB 使用 SSH 连接:地址=%s:%d", targetHost, targetPort) + + forwarder, err := ssh.GetOrCreateLocalForwarder(runConfig.SSH, targetHost, targetPort) + if err != nil { + return fmt.Errorf("创建 SSH 隧道失败:%w", err) + } + m.forwarder = forwarder + + host, portStr, err := net.SplitHostPort(forwarder.LocalAddr) + if err != nil { + return fmt.Errorf("解析本地转发地址失败:%w", err) + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("解析本地端口失败:%w", err) + } + + localConfig := runConfig + localConfig.Host = host + localConfig.Port = port + localConfig.UseSSH = false + localConfig.URI = "" + localConfig.Hosts = []string{normalizeMongoAddress(host, port)} + connectConfig = localConfig + logger.Infof("MongoDB 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, targetHost, targetPort) + } + + m.pingTimeout = getConnectTimeout(connectConfig) + m.database = connectConfig.Database + if m.database == "" { + m.database = "admin" + } + + sslAttempts := []connection.ConnectionConfig{connectConfig} + if shouldTrySSLPreferredFallback(connectConfig) { + sslAttempts = append(sslAttempts, withSSLDisabled(connectConfig)) + } + + var errorDetails []string + for sslIndex, sslConfig := range sslAttempts { + sslLabel := "SSL" + if sslIndex > 0 { + sslLabel = "明文回退" + } + + attemptConfigs := buildMongoAuthAttempts(sslConfig) + for index, attemptConfig := range attemptConfigs { + authLabel := "主库凭据" + if index > 0 { + authLabel = "从库凭据" + } + + if sslIndex > 0 { + attemptConfig.URI = "" + } + uri := m.getURI(attemptConfig) + clientOpts := options.Client().ApplyURI(uri) + tlsEnabled, tlsInsecure := resolveMongoTLSSettings(attemptConfig) + if tlsEnabled { + clientOpts.SetTLSConfig(&tls.Config{ + MinVersion: tls.VersionTLS12, + InsecureSkipVerify: tlsInsecure, + }) + } + if attemptConfig.UseProxy { + clientOpts.SetDialer(&mongoProxyDialer{proxyConfig: attemptConfig.Proxy}) + } + connectCtx, connectCancel := context.WithTimeout(context.Background(), m.pingTimeout) + client, err := mongo.Connect(connectCtx, clientOpts) + connectCancel() + if err != nil { + errorDetails = append(errorDetails, fmt.Sprintf("%s %s连接失败: %v", sslLabel, authLabel, err)) + continue + } + + m.client = client + if err := m.Ping(); err != nil { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + _ = client.Disconnect(ctx) + cancel() + m.client = nil + errorDetails = append(errorDetails, fmt.Sprintf("%s %s验证失败: %v", sslLabel, authLabel, err)) + continue + } + if sslIndex > 0 { + logger.Warnf("MongoDB(v1) SSL 优先连接失败,已回退至明文连接") + } + return nil + } + } + + if len(errorDetails) > 0 { + return fmt.Errorf("MongoDB 连接失败:%s", strings.Join(errorDetails, ";")) + } + + return fmt.Errorf("MongoDB 连接失败:无可用连接方案") +} + +func (m *MongoDBV1) Close() error { + if m.forwarder != nil { + if err := m.forwarder.Close(); err != nil { + logger.Warnf("关闭 MongoDB SSH 端口转发失败:%v", err) + } + m.forwarder = nil + } + + if m.client != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return m.client.Disconnect(ctx) + } + return nil +} + +func (m *MongoDBV1) Ping() error { + if m.client == nil { + return fmt.Errorf("connection not open") + } + timeout := m.pingTimeout + if timeout <= 0 { + timeout = 5 * time.Second + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return m.client.Ping(ctx, readpref.Primary()) +} + +func asMongoStringList(raw interface{}) []string { + values, ok := raw.(bson.A) + if !ok { + return nil + } + result := make([]string, 0, len(values)) + for _, entry := range values { + text := strings.TrimSpace(fmt.Sprintf("%v", entry)) + if text != "" { + result = append(result, text) + } + } + return result +} + +func asMongoString(raw interface{}) string { + if raw == nil { + return "" + } + if value, ok := raw.(string); ok { + return strings.TrimSpace(value) + } + return strings.TrimSpace(fmt.Sprintf("%v", raw)) +} + +func asMongoInt(raw interface{}) int { + switch value := raw.(type) { + case int: + return value + case int32: + return int(value) + case int64: + return int(value) + case float32: + return int(value) + case float64: + return int(value) + default: + return 0 + } +} + +func asMongoBool(raw interface{}) bool { + switch value := raw.(type) { + case bool: + return value + case int: + return value != 0 + case int32: + return value != 0 + case int64: + return value != 0 + case float32: + return value != 0 + case float64: + return value != 0 + default: + return false + } +} + +func asMongoInt64(raw interface{}) int64 { + switch value := raw.(type) { + case int: + return int64(value) + case int32: + return int64(value) + case int64: + return value + case float32: + return int64(value) + case float64: + return int64(value) + default: + return 0 + } +} + +func mongoStateByCode(code int) string { + switch code { + case 1: + return "PRIMARY" + case 2: + return "SECONDARY" + case 3: + return "RECOVERING" + case 5: + return "STARTUP2" + case 6: + return "UNKNOWN" + case 7: + return "ARBITER" + case 8: + return "DOWN" + case 9: + return "ROLLBACK" + case 10: + return "REMOVED" + default: + return "UNKNOWN" + } +} + +func normalizeMongoStateLabel(state string, stateCode int) string { + normalized := strings.ToUpper(strings.TrimSpace(state)) + if normalized != "" { + return normalized + } + return mongoStateByCode(stateCode) +} + +func buildMembersFromReplStatus(raw bson.M) []connection.MongoMemberInfo { + items, ok := raw["members"].(bson.A) + if !ok { + return nil + } + + members := make([]connection.MongoMemberInfo, 0, len(items)) + for _, entry := range items { + member, ok := entry.(bson.M) + if !ok { + continue + } + host := asMongoString(member["name"]) + if host == "" { + continue + } + stateCode := asMongoInt(member["state"]) + state := normalizeMongoStateLabel(asMongoString(member["stateStr"]), stateCode) + members = append(members, connection.MongoMemberInfo{ + Host: host, + Role: state, + State: state, + StateCode: stateCode, + Healthy: asMongoInt(member["health"]) > 0 || asMongoBool(member["health"]), + IsSelf: asMongoBool(member["self"]), + }) + } + + sort.Slice(members, func(i, j int) bool { + return members[i].Host < members[j].Host + }) + return members +} + +func buildMembersFromHello(raw bson.M) []connection.MongoMemberInfo { + hosts := asMongoStringList(raw["hosts"]) + if len(hosts) == 0 { + return nil + } + primary := asMongoString(raw["primary"]) + selfHost := asMongoString(raw["me"]) + passiveSet := make(map[string]struct{}) + for _, host := range asMongoStringList(raw["passives"]) { + passiveSet[host] = struct{}{} + } + arbiterSet := make(map[string]struct{}) + for _, host := range asMongoStringList(raw["arbiters"]) { + arbiterSet[host] = struct{}{} + } + + members := make([]connection.MongoMemberInfo, 0, len(hosts)) + for _, host := range hosts { + state := "SECONDARY" + stateCode := 2 + if host == primary { + state = "PRIMARY" + stateCode = 1 + } else if _, ok := arbiterSet[host]; ok { + state = "ARBITER" + stateCode = 7 + } else if _, ok := passiveSet[host]; ok { + state = "PASSIVE" + stateCode = 6 + } + members = append(members, connection.MongoMemberInfo{ + Host: host, + Role: state, + State: state, + StateCode: stateCode, + Healthy: true, + IsSelf: host == selfHost, + }) + } + + sort.Slice(members, func(i, j int) bool { + return members[i].Host < members[j].Host + }) + return members +} + +func (m *MongoDBV1) DiscoverMembers() (string, []connection.MongoMemberInfo, error) { + if m.client == nil { + return "", nil, fmt.Errorf("connection not open") + } + + timeout := m.pingTimeout + if timeout <= 0 { + timeout = 10 * time.Second + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + adminDB := m.client.Database("admin") + + var replStatus bson.M + replErr := adminDB.RunCommand(ctx, bson.D{{Key: "replSetGetStatus", Value: 1}}).Decode(&replStatus) + if replErr == nil { + replicaSet := asMongoString(replStatus["set"]) + members := buildMembersFromReplStatus(replStatus) + if len(members) > 0 { + return replicaSet, members, nil + } + } + + var helloResult bson.M + helloErr := adminDB.RunCommand(ctx, bson.D{{Key: "hello", Value: 1}}).Decode(&helloResult) + if helloErr != nil { + if err := adminDB.RunCommand(ctx, bson.D{{Key: "isMaster", Value: 1}}).Decode(&helloResult); err != nil { + if replErr != nil { + return "", nil, fmt.Errorf("成员发现失败:replSetGetStatus=%v;hello=%v", replErr, err) + } + return "", nil, fmt.Errorf("成员发现失败:hello=%w", err) + } + } + + replicaSet := asMongoString(helloResult["setName"]) + members := buildMembersFromHello(helloResult) + if len(members) == 0 { + if replErr != nil { + return replicaSet, nil, fmt.Errorf("未获取到成员信息:replSetGetStatus=%v", replErr) + } + return replicaSet, nil, fmt.Errorf("未获取到成员信息") + } + return replicaSet, members, nil +} + +// Query executes a MongoDB command and returns results +// Supports JSON format commands like: {"find": "collection", "filter": {}} +func (m *MongoDBV1) Query(query string) ([]map[string]interface{}, []string, error) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + return m.queryWithContext(ctx, query) +} + +// QueryContext executes a MongoDB command with the given context for timeout control +func (m *MongoDBV1) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + return m.queryWithContext(ctx, query) +} + +// sqlToMongoFind 将前端生成的简单 SQL 转换为 MongoDB find 命令 JSON。 +// 支持:SELECT * FROM "coll" LIMIT n OFFSET m / SELECT COUNT(*) as total FROM "coll" +func sqlToMongoFind(sql string) (string, bool) { + lower := strings.ToLower(strings.TrimSpace(sql)) + + // SELECT COUNT(*) as total FROM "coll" ... + if strings.HasPrefix(lower, "select count(") { + coll := extractCollectionFromSQL(sql) + if coll == "" { + return "", false + } + return fmt.Sprintf(`{"count":"%s","query":{}}`, coll), true + } + + // SELECT * FROM "coll" ... LIMIT n OFFSET m + if !strings.HasPrefix(lower, "select") { + return "", false + } + coll := extractCollectionFromSQL(sql) + if coll == "" { + return "", false + } + + limit := int64(0) + skip := int64(0) + + // 提取 LIMIT + if idx := strings.Index(lower, "limit "); idx >= 0 { + after := strings.TrimSpace(lower[idx+6:]) + parts := strings.Fields(after) + if len(parts) > 0 { + if n, err := strconv.ParseInt(parts[0], 10, 64); err == nil { + limit = n + } + } + } + + // 提取 OFFSET + if idx := strings.Index(lower, "offset "); idx >= 0 { + after := strings.TrimSpace(lower[idx+7:]) + parts := strings.Fields(after) + if len(parts) > 0 { + if n, err := strconv.ParseInt(parts[0], 10, 64); err == nil { + skip = n + } + } + } + + cmd := fmt.Sprintf(`{"find":"%s","filter":{}`, coll) + if limit > 0 { + cmd += fmt.Sprintf(`,"limit":%d`, limit) + } + if skip > 0 { + cmd += fmt.Sprintf(`,"skip":%d`, skip) + } + cmd += "}" + return cmd, true +} + +// extractCollectionFromSQL 从 SQL 中提取 FROM 后的 collection 名称。 +func extractCollectionFromSQL(sql string) string { + lower := strings.ToLower(sql) + idx := strings.Index(lower, "from ") + if idx < 0 { + return "" + } + after := strings.TrimSpace(sql[idx+5:]) + + // 去掉引号包裹 + var coll string + if len(after) > 0 && after[0] == '"' { + end := strings.Index(after[1:], "\"") + if end < 0 { + return "" + } + coll = after[1 : end+1] + } else if len(after) > 0 && after[0] == '`' { + end := strings.Index(after[1:], "`") + if end < 0 { + return "" + } + coll = after[1 : end+1] + } else { + parts := strings.Fields(after) + if len(parts) == 0 { + return "" + } + coll = parts[0] + } + return strings.TrimSpace(coll) +} + +func (m *MongoDBV1) queryWithContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + if m.client == nil { + return nil, nil, fmt.Errorf("connection not open") + } + + query = strings.TrimSpace(query) + if query == "" { + return nil, nil, fmt.Errorf("empty query") + } + + // 如果输入是 SQL 语句(前端 DataViewer 统一生成),自动转换为 MongoDB JSON 命令 + lowerQuery := strings.ToLower(query) + if strings.HasPrefix(lowerQuery, "select") || strings.HasPrefix(lowerQuery, "show") { + if converted, ok := sqlToMongoFind(query); ok { + query = converted + } + } + + // Parse JSON command + var cmd bson.D + if err := bson.UnmarshalExtJSON([]byte(query), true, &cmd); err != nil { + return nil, nil, fmt.Errorf("invalid JSON command: %w", err) + } + + // 对 find 和 count 命令使用原生 driver API,避免 RunCommand 的 firstBatch 限制 + if len(cmd) > 0 { + switch cmd[0].Key { + case "find": + return m.execFind(ctx, cmd) + case "count": + return m.execCount(ctx, cmd) + } + } + + // 其他命令走 RunCommand + db := m.client.Database(m.database) + var result bson.M + if err := db.RunCommand(ctx, cmd).Decode(&result); err != nil { + return nil, nil, err + } + + // Handle COUNT result (e.g. delete/update returns "n") + if n, ok := result["n"]; ok { + if _, hasCursor := result["cursor"]; !hasCursor { + return []map[string]interface{}{{"total": n}}, []string{"total"}, nil + } + } + + // Convert result to standard format + data := []map[string]interface{}{{"result": result}} + columns := []string{"result"} + + // If result contains cursor with documents, extract them + if cursor, ok := result["cursor"].(bson.M); ok { + if batch, ok := cursor["firstBatch"].(bson.A); ok { + data = make([]map[string]interface{}, 0, len(batch)) + columnSet := make(map[string]bool) + for _, doc := range batch { + if docMap, ok := doc.(bson.M); ok { + row := make(map[string]interface{}) + for k, v := range docMap { + row[k] = v + columnSet[k] = true + } + data = append(data, row) + } + } + columns = make([]string, 0, len(columnSet)) + for k := range columnSet { + columns = append(columns, k) + } + } + } + + return data, columns, nil +} + +// execFind 使用原生 Collection.Find() 执行查询,正确处理游标迭代 +func (m *MongoDBV1) execFind(ctx context.Context, cmd bson.D) ([]map[string]interface{}, []string, error) { + var collName string + var filter interface{} + var limit int64 + var skip int64 + var sortDoc interface{} + var projection interface{} + + for _, elem := range cmd { + switch elem.Key { + case "find": + collName = fmt.Sprintf("%v", elem.Value) + case "filter": + filter = elem.Value + case "limit": + limit = asMongoInt64(elem.Value) + case "skip": + skip = asMongoInt64(elem.Value) + case "sort": + sortDoc = elem.Value + case "projection": + projection = elem.Value + } + } + + if collName == "" { + return nil, nil, fmt.Errorf("find command missing collection name") + } + if filter == nil { + filter = bson.D{} + } + + collection := m.client.Database(m.database).Collection(collName) + opts := options.Find() + if limit > 0 { + opts.SetLimit(limit) + } + if skip > 0 { + opts.SetSkip(skip) + } + if sortDoc != nil { + opts.SetSort(sortDoc) + } + if projection != nil { + opts.SetProjection(projection) + } + + cursor, err := collection.Find(ctx, filter, opts) + if err != nil { + return nil, nil, err + } + defer cursor.Close(ctx) + + var data []map[string]interface{} + columnSet := make(map[string]bool) + + for cursor.Next(ctx) { + var doc bson.M + if err := cursor.Decode(&doc); err != nil { + continue + } + row := make(map[string]interface{}) + for k, v := range doc { + row[k] = convertBsonValue(v) + columnSet[k] = true + } + data = append(data, row) + } + + if err := cursor.Err(); err != nil { + return nil, nil, err + } + + columns := make([]string, 0, len(columnSet)) + for k := range columnSet { + columns = append(columns, k) + } + sort.Strings(columns) + + // 将 _id 列置首 + for i, col := range columns { + if col == "_id" && i > 0 { + columns = append(columns[:i], columns[i+1:]...) + columns = append([]string{"_id"}, columns...) + break + } + } + + return data, columns, nil +} + +// execCount 使用原生 Collection.CountDocuments() 执行计数 +func (m *MongoDBV1) execCount(ctx context.Context, cmd bson.D) ([]map[string]interface{}, []string, error) { + var collName string + var filter interface{} + + for _, elem := range cmd { + switch elem.Key { + case "count": + collName = fmt.Sprintf("%v", elem.Value) + case "query": + filter = elem.Value + } + } + + if collName == "" { + return nil, nil, fmt.Errorf("count command missing collection name") + } + if filter == nil { + filter = bson.D{} + } + + collection := m.client.Database(m.database).Collection(collName) + n, err := collection.CountDocuments(ctx, filter) + if err != nil { + return nil, nil, err + } + + return []map[string]interface{}{{"total": n}}, []string{"total"}, nil +} + +// convertBsonValue 将 BSON 特殊类型转换为前端可读的 JSON 友好值 +func convertBsonValue(v interface{}) interface{} { + switch val := v.(type) { + case primitive.ObjectID: + return val.Hex() + case bson.M: + result := make(map[string]interface{}, len(val)) + for k, v2 := range val { + result[k] = convertBsonValue(v2) + } + return result + case bson.D: + result := make(map[string]interface{}, len(val)) + for _, elem := range val { + result[elem.Key] = convertBsonValue(elem.Value) + } + return result + case bson.A: + result := make([]interface{}, len(val)) + for i, v2 := range val { + result[i] = convertBsonValue(v2) + } + return result + default: + return v + } +} + +func (m *MongoDBV1) Exec(query string) (int64, error) { + _, _, err := m.Query(query) + if err != nil { + return 0, err + } + return 1, nil +} + +// ExecContext executes a MongoDB command with the given context for timeout control +func (m *MongoDBV1) ExecContext(ctx context.Context, query string) (int64, error) { + _, _, err := m.QueryContext(ctx, query) + if err != nil { + return 0, err + } + return 1, nil +} + +func (m *MongoDBV1) GetDatabases() ([]string, error) { + if m.client == nil { + return nil, fmt.Errorf("connection not open") + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + dbs, err := m.client.ListDatabaseNames(ctx, bson.M{}) + if err != nil { + return nil, err + } + return dbs, nil +} + +func (m *MongoDBV1) GetTables(dbName string) ([]string, error) { + if m.client == nil { + return nil, fmt.Errorf("connection not open") + } + + targetDB := dbName + if targetDB == "" { + targetDB = m.database + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + collections, err := m.client.Database(targetDB).ListCollectionNames(ctx, bson.M{}) + if err != nil { + return nil, err + } + return collections, nil +} + +func (m *MongoDBV1) GetCreateStatement(dbName, tableName string) (string, error) { + return fmt.Sprintf("// MongoDB collection: %s.%s\n// MongoDB is schemaless - no CREATE statement available", dbName, tableName), nil +} + +// GetColumns returns empty for MongoDB (schemaless) +func (m *MongoDBV1) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { + // MongoDB is schemaless, return empty + return []connection.ColumnDefinition{}, nil +} + +// GetAllColumns returns empty for MongoDB (schemaless) +func (m *MongoDBV1) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { + return []connection.ColumnDefinitionWithTable{}, nil +} + +// GetIndexes returns indexes for a MongoDB collection +func (m *MongoDBV1) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) { + if m.client == nil { + return nil, fmt.Errorf("connection not open") + } + + targetDB := dbName + if targetDB == "" { + targetDB = m.database + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + collection := m.client.Database(targetDB).Collection(tableName) + cursor, err := collection.Indexes().List(ctx) + if err != nil { + return nil, err + } + defer cursor.Close(ctx) + + var indexes []connection.IndexDefinition + for cursor.Next(ctx) { + var idx bson.M + if err := cursor.Decode(&idx); err != nil { + continue + } + + name := fmt.Sprintf("%v", idx["name"]) + unique := false + if u, ok := idx["unique"].(bool); ok { + unique = u + } + + // Extract key fields + if key, ok := idx["key"].(bson.M); ok { + seq := 1 + for field := range key { + nonUnique := 1 + if unique { + nonUnique = 0 + } + indexes = append(indexes, connection.IndexDefinition{ + Name: name, + ColumnName: field, + NonUnique: nonUnique, + SeqInIndex: seq, + IndexType: "BTREE", + }) + seq++ + } + } + } + + return indexes, nil +} + +func (m *MongoDBV1) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) { + // MongoDB doesn't have foreign keys + return []connection.ForeignKeyDefinition{}, nil +} + +func (m *MongoDBV1) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) { + // MongoDB doesn't have triggers in the traditional sense + return []connection.TriggerDefinition{}, nil +} + +// ApplyChanges implements batch changes for MongoDB +func (m *MongoDBV1) ApplyChanges(tableName string, changes connection.ChangeSet) error { + if m.client == nil { + return fmt.Errorf("connection not open") + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + collection := m.client.Database(m.database).Collection(tableName) + + // Process deletes + for _, pk := range changes.Deletes { + filter := bson.M{} + for k, v := range pk { + filter[k] = v + } + if len(filter) > 0 { + if _, err := collection.DeleteOne(ctx, filter); err != nil { + return fmt.Errorf("delete error: %v", err) + } + } + } + + // Process updates + for _, update := range changes.Updates { + filter := bson.M{} + for k, v := range update.Keys { + filter[k] = v + } + if len(filter) == 0 { + return fmt.Errorf("update requires keys") + } + + updateDoc := bson.M{"$set": bson.M{}} + for k, v := range update.Values { + updateDoc["$set"].(bson.M)[k] = v + } + + if _, err := collection.UpdateOne(ctx, filter, updateDoc); err != nil { + return fmt.Errorf("update error: %v", err) + } + } + + // Process inserts + for _, row := range changes.Inserts { + doc := bson.M{} + for k, v := range row { + doc[k] = v + } + if len(doc) > 0 { + if _, err := collection.InsertOne(ctx, doc); err != nil { + return fmt.Errorf("insert error: %v", err) + } + } + } + + return nil +} diff --git a/internal/db/mysql_agent_path.go b/internal/db/mysql_agent_path.go index d5a79f7..bc5a76c 100644 --- a/internal/db/mysql_agent_path.go +++ b/internal/db/mysql_agent_path.go @@ -35,6 +35,32 @@ func ResolveOptionalDriverAgentExecutablePath(downloadDir string, driverType str return filepath.Join(root, normalized, optionalDriverAgentExecutableName(normalized)), nil } +func ResolveOptionalDriverAgentExecutablePathForVersion(downloadDir string, driverType string, version string) (string, error) { + normalized := normalizeRuntimeDriverType(driverType) + if strings.TrimSpace(normalized) == "" { + return "", fmt.Errorf("驱动类型为空") + } + root, err := resolveExternalDriverRoot(downloadDir) + if err != nil { + return "", err + } + + if normalized != "mongodb" { + return filepath.Join(root, normalized, optionalDriverAgentExecutableName(normalized)), nil + } + + baseName := optionalDriverAgentExecutableName(normalized) + ext := filepath.Ext(baseName) + stem := strings.TrimSuffix(baseName, ext) + major := 2 + trimmed := strings.TrimSpace(version) + trimmed = strings.TrimPrefix(trimmed, "v") + if strings.HasPrefix(trimmed, "1.") || trimmed == "1" { + major = 1 + } + versionedName := fmt.Sprintf("%s-v%d%s", stem, major, ext) + return filepath.Join(root, normalized, versionedName), nil +} func ResolveMySQLAgentExecutablePath(downloadDir string) (string, error) { return ResolveOptionalDriverAgentExecutablePath(downloadDir, "mysql") } diff --git a/internal/db/mysql_impl.go b/internal/db/mysql_impl.go index 2c6a332..4aefa29 100644 --- a/internal/db/mysql_impl.go +++ b/internal/db/mysql_impl.go @@ -184,9 +184,10 @@ func (m *MySQLDB) getDSN(config connection.ConnectionConfig) string { } timeout := getConnectTimeoutSeconds(config) + tlsMode := resolveMySQLTLSMode(config) - return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds", - config.User, config.Password, protocol, address, database, timeout) + return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s", + config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode)) } func resolveMySQLCredential(config connection.ConnectionConfig, addressIndex int) (string, string) { diff --git a/internal/db/oracle_impl.go b/internal/db/oracle_impl.go index 727e82c..f0d03eb 100644 --- a/internal/db/oracle_impl.go +++ b/internal/db/oracle_impl.go @@ -35,12 +35,23 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string { } u.User = url.UserPassword(config.User, config.Password) u.RawPath = "/" + url.PathEscape(database) + q := url.Values{} + switch normalizedSSLMode(config) { + case sslModeRequired: + q.Set("SSL", "TRUE") + q.Set("SSL VERIFY", "TRUE") + case sslModeSkipVerify, sslModePreferred: + q.Set("SSL", "TRUE") + q.Set("SSL VERIFY", "FALSE") + } + if encoded := q.Encode(); encoded != "" { + u.RawQuery = encoded + } return u.String() } func (o *OracleDB) Connect(config connection.ConnectionConfig) error { - var dsn string - var err error + runConfig := config serviceName := strings.TrimSpace(config.Database) if serviceName == "" { return fmt.Errorf("Oracle 连接缺少服务名(Service Name),请在连接配置中填写,例如 ORCLPDB1") @@ -73,22 +84,37 @@ func (o *OracleDB) Connect(config connection.ConnectionConfig) error { localConfig.Port = port localConfig.UseSSH = false - dsn = o.getDSN(localConfig) + runConfig = localConfig logger.Infof("Oracle 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port) - } else { - dsn = o.getDSN(config) } - db, err := sql.Open("oracle", dsn) - if err != nil { - return fmt.Errorf("打开数据库连接失败:%w", err) + attempts := []connection.ConnectionConfig{runConfig} + if shouldTrySSLPreferredFallback(runConfig) { + attempts = append(attempts, withSSLDisabled(runConfig)) } - o.conn = db - o.pingTimeout = getConnectTimeout(config) - if err := o.Ping(); err != nil { - return fmt.Errorf("连接建立后验证失败:%w", err) + + var failures []string + for idx, attempt := range attempts { + dsn := o.getDSN(attempt) + db, err := sql.Open("oracle", dsn) + if err != nil { + failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err)) + continue + } + o.conn = db + o.pingTimeout = getConnectTimeout(attempt) + if err := o.Ping(); err != nil { + _ = db.Close() + o.conn = nil + failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err)) + continue + } + if idx > 0 { + logger.Warnf("Oracle SSL 优先连接失败,已回退至明文连接") + } + return nil } - return nil + return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";")) } func (o *OracleDB) Close() error { diff --git a/internal/db/postgres_impl.go b/internal/db/postgres_impl.go index 7727773..b6cd82e 100644 --- a/internal/db/postgres_impl.go +++ b/internal/db/postgres_impl.go @@ -62,7 +62,7 @@ func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string { } u.User = url.UserPassword(config.User, config.Password) q := url.Values{} - q.Set("sslmode", "disable") + q.Set("sslmode", resolvePostgresSSLMode(config)) q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config))) u.RawQuery = q.Encode() @@ -126,34 +126,49 @@ func (p *PostgresDB) Connect(config connection.ConnectionConfig) error { logger.Infof("PostgreSQL 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port) } - attemptDBs := resolvePostgresConnectDatabases(runConfig) + sslAttempts := []connection.ConnectionConfig{runConfig} + if shouldTrySSLPreferredFallback(runConfig) { + sslAttempts = append(sslAttempts, withSSLDisabled(runConfig)) + } + var failures []string - for _, dbName := range attemptDBs { - attemptConfig := runConfig - attemptConfig.Database = dbName - dsn := p.getDSN(attemptConfig) - - dbConn, err := sql.Open("postgres", dsn) - if err != nil { - failures = append(failures, fmt.Sprintf("数据库=%s 打开连接失败: %v", dbName, err)) - continue - } - p.conn = dbConn - - // Force verification - if err := p.Ping(); err != nil { - failures = append(failures, fmt.Sprintf("数据库=%s 验证失败: %v", dbName, err)) - _ = dbConn.Close() - p.conn = nil - continue + for sslIndex, sslConfig := range sslAttempts { + sslLabel := "SSL" + if sslIndex > 0 { + sslLabel = "明文回退" } - if strings.TrimSpace(config.Database) == "" && !strings.EqualFold(dbName, "postgres") { - logger.Infof("PostgreSQL 自动选择连接数据库:%s", dbName) - } + attemptDBs := resolvePostgresConnectDatabases(sslConfig) + for _, dbName := range attemptDBs { + attemptConfig := sslConfig + attemptConfig.Database = dbName + dsn := p.getDSN(attemptConfig) - cleanupOnFailure = false - return nil + dbConn, err := sql.Open("postgres", dsn) + if err != nil { + failures = append(failures, fmt.Sprintf("%s 数据库=%s 打开连接失败: %v", sslLabel, dbName, err)) + continue + } + p.conn = dbConn + + // Force verification + if err := p.Ping(); err != nil { + failures = append(failures, fmt.Sprintf("%s 数据库=%s 验证失败: %v", sslLabel, dbName, err)) + _ = dbConn.Close() + p.conn = nil + continue + } + + if sslIndex > 0 { + logger.Warnf("PostgreSQL SSL 优先连接失败,已回退至明文连接") + } + if strings.TrimSpace(config.Database) == "" && !strings.EqualFold(dbName, "postgres") { + logger.Infof("PostgreSQL 自动选择连接数据库:%s", dbName) + } + + cleanupOnFailure = false + return nil + } } if len(failures) == 0 { diff --git a/internal/db/sqlserver_impl.go b/internal/db/sqlserver_impl.go index bac36e9..a0458ab 100644 --- a/internal/db/sqlserver_impl.go +++ b/internal/db/sqlserver_impl.go @@ -47,8 +47,9 @@ func (s *SqlServerDB) getDSN(config connection.ConnectionConfig) string { q := url.Values{} q.Set("database", dbname) q.Set("connection timeout", strconv.Itoa(getConnectTimeoutSeconds(config))) - q.Set("encrypt", "disable") - q.Set("TrustServerCertificate", "true") + encrypt, trustServerCertificate := resolveSQLServerTLSSettings(config) + q.Set("encrypt", encrypt) + q.Set("TrustServerCertificate", trustServerCertificate) u.RawQuery = q.Encode() return u.String() diff --git a/internal/db/ssl_mode.go b/internal/db/ssl_mode.go new file mode 100644 index 0000000..050db53 --- /dev/null +++ b/internal/db/ssl_mode.go @@ -0,0 +1,122 @@ +package db + +import ( + "crypto/tls" + "strings" + + "GoNavi-Wails/internal/connection" +) + +const ( + sslModeDisable = "disable" + sslModePreferred = "preferred" + sslModeRequired = "required" + sslModeSkipVerify = "skip-verify" +) + +func normalizeSSLModeValue(raw string) string { + mode := strings.ToLower(strings.TrimSpace(raw)) + switch mode { + case "", sslModePreferred, "prefer": + return sslModePreferred + case sslModeRequired, "require", "on", "true", "mandatory", "strict": + return sslModeRequired + case sslModeSkipVerify, "insecure", "skipverify", "skip_verify", "insecure-skip-verify": + return sslModeSkipVerify + case sslModeDisable, "disabled", "off", "false", "none": + return sslModeDisable + default: + return sslModePreferred + } +} + +func normalizedSSLMode(config connection.ConnectionConfig) string { + if !config.UseSSL { + return sslModeDisable + } + return normalizeSSLModeValue(config.SSLMode) +} + +func shouldTrySSLPreferredFallback(config connection.ConnectionConfig) bool { + return config.UseSSL && normalizeSSLModeValue(config.SSLMode) == sslModePreferred +} + +func withSSLDisabled(config connection.ConnectionConfig) connection.ConnectionConfig { + next := config + next.UseSSL = false + next.SSLMode = sslModeDisable + return next +} + +func resolveMySQLTLSMode(config connection.ConnectionConfig) string { + switch normalizedSSLMode(config) { + case sslModeDisable: + return "false" + case sslModeRequired: + return "true" + case sslModeSkipVerify: + return "skip-verify" + default: + return "preferred" + } +} + +func resolvePostgresSSLMode(config connection.ConnectionConfig) string { + switch normalizedSSLMode(config) { + case sslModeDisable: + return "disable" + case sslModeRequired: + return "require" + case sslModeSkipVerify: + return "require" + default: + return "require" + } +} + +func resolveSQLServerTLSSettings(config connection.ConnectionConfig) (encrypt string, trustServerCertificate string) { + switch normalizedSSLMode(config) { + case sslModeDisable: + return "disable", "true" + case sslModeRequired: + return "true", "false" + case sslModeSkipVerify: + return "true", "true" + default: + return "false", "true" + } +} + +func resolveGenericTLSConfig(config connection.ConnectionConfig) *tls.Config { + switch normalizedSSLMode(config) { + case sslModeDisable: + return nil + case sslModeRequired: + return &tls.Config{MinVersion: tls.VersionTLS12} + case sslModeSkipVerify: + return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true} + default: + // Preferred: 先尝试 TLS(为提升兼容性默认跳过证书校验),失败时由调用方按需回退明文。 + return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true} + } +} + +func resolveMongoTLSSettings(config connection.ConnectionConfig) (enabled bool, insecure bool) { + switch normalizedSSLMode(config) { + case sslModeDisable: + return false, false + case sslModeRequired: + return true, false + case sslModeSkipVerify: + return true, true + default: + return true, true + } +} + +func resolveTDengineNet(config connection.ConnectionConfig) string { + if normalizedSSLMode(config) == sslModeDisable { + return "ws" + } + return "wss" +} diff --git a/internal/db/tdengine_impl.go b/internal/db/tdengine_impl.go index 640c97f..300cfb0 100644 --- a/internal/db/tdengine_impl.go +++ b/internal/db/tdengine_impl.go @@ -40,11 +40,12 @@ func (t *TDengineDB) getDSN(config connection.ConnectionConfig) string { path = "/" + dbName } - return fmt.Sprintf("%s:%s@ws(%s)%s", user, pass, net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), path) + netType := resolveTDengineNet(config) + return fmt.Sprintf("%s:%s@%s(%s)%s", user, pass, netType, net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), path) } func (t *TDengineDB) Connect(config connection.ConnectionConfig) error { - var dsn string + runConfig := config if config.UseSSH { logger.Infof("TDengine 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User) @@ -68,23 +69,38 @@ func (t *TDengineDB) Connect(config connection.ConnectionConfig) error { localConfig.Host = host localConfig.Port = port localConfig.UseSSH = false - dsn = t.getDSN(localConfig) + runConfig = localConfig logger.Infof("TDengine 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port) - } else { - dsn = t.getDSN(config) } - db, err := sql.Open("taosWS", dsn) - if err != nil { - return fmt.Errorf("打开数据库连接失败:%w", err) + attempts := []connection.ConnectionConfig{runConfig} + if shouldTrySSLPreferredFallback(runConfig) { + attempts = append(attempts, withSSLDisabled(runConfig)) } - t.conn = db - t.pingTimeout = getConnectTimeout(config) - if err := t.Ping(); err != nil { - return fmt.Errorf("连接建立后验证失败:%w", err) + var failures []string + for idx, attempt := range attempts { + dsn := t.getDSN(attempt) + db, err := sql.Open("taosWS", dsn) + if err != nil { + failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err)) + continue + } + t.conn = db + t.pingTimeout = getConnectTimeout(attempt) + + if err := t.Ping(); err != nil { + _ = db.Close() + t.conn = nil + failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err)) + continue + } + if idx > 0 { + logger.Warnf("TDengine SSL 优先连接失败,已回退至明文连接") + } + return nil } - return nil + return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";")) } func (t *TDengineDB) Close() error { diff --git a/internal/db/vastbase_impl.go b/internal/db/vastbase_impl.go index 250fe73..971a171 100644 --- a/internal/db/vastbase_impl.go +++ b/internal/db/vastbase_impl.go @@ -41,7 +41,7 @@ func (v *VastbaseDB) getDSN(config connection.ConnectionConfig) string { } u.User = url.UserPassword(config.User, config.Password) q := url.Values{} - q.Set("sslmode", "disable") + q.Set("sslmode", resolvePostgresSSLMode(config)) q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config))) u.RawQuery = q.Encode() @@ -49,7 +49,7 @@ func (v *VastbaseDB) getDSN(config connection.ConnectionConfig) string { } func (v *VastbaseDB) Connect(config connection.ConnectionConfig) error { - var dsn string + runConfig := config if config.UseSSH { logger.Infof("Vastbase 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User) @@ -75,23 +75,37 @@ func (v *VastbaseDB) Connect(config connection.ConnectionConfig) error { localConfig.Port = port localConfig.UseSSH = false - dsn = v.getDSN(localConfig) + runConfig = localConfig logger.Infof("Vastbase 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port) - } else { - dsn = v.getDSN(config) } - db, err := sql.Open("postgres", dsn) - if err != nil { - return fmt.Errorf("打开数据库连接失败:%w", err) + attempts := []connection.ConnectionConfig{runConfig} + if shouldTrySSLPreferredFallback(runConfig) { + attempts = append(attempts, withSSLDisabled(runConfig)) } - v.conn = db - v.pingTimeout = getConnectTimeout(config) - if err := v.Ping(); err != nil { - return fmt.Errorf("连接建立后验证失败:%w", err) + var failures []string + for idx, attempt := range attempts { + dsn := v.getDSN(attempt) + db, err := sql.Open("postgres", dsn) + if err != nil { + failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err)) + continue + } + v.conn = db + v.pingTimeout = getConnectTimeout(attempt) + if err := v.Ping(); err != nil { + _ = db.Close() + v.conn = nil + failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err)) + continue + } + if idx > 0 { + logger.Warnf("Vastbase SSL 优先连接失败,已回退至明文连接") + } + return nil } - return nil + return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";")) } func (v *VastbaseDB) Close() error { diff --git a/internal/redis/redis_impl.go b/internal/redis/redis_impl.go index f08b4f5..8d41a28 100644 --- a/internal/redis/redis_impl.go +++ b/internal/redis/redis_impl.go @@ -2,6 +2,7 @@ package redis import ( "context" + "crypto/tls" "fmt" "net" "strconv" @@ -201,25 +202,48 @@ func (r *RedisClientImpl) Connect(config connection.ConnectionConfig) error { timeout := normalizeRedisTimeout(config.Timeout) if r.isCluster { - opts := &redis.ClusterOptions{ - Addrs: seedAddrs, - Username: strings.TrimSpace(config.User), - Password: config.Password, - DialTimeout: timeout, - ReadTimeout: timeout, - WriteTimeout: timeout, + attempts := []connection.ConnectionConfig{config} + if shouldTryRedisSSLPreferredFallback(config) { + attempts = append(attempts, withRedisSSLDisabled(config)) } - clusterClient := redis.NewClusterClient(opts) - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - if err := clusterClient.Ping(ctx).Err(); err != nil { - clusterClient.Close() - return fmt.Errorf("Redis 集群连接失败: %w", err) + + var failures []string + for idx, attempt := range attempts { + var tlsConfig *tls.Config + if cfg := resolveRedisTLSConfig(attempt); cfg != nil { + if host, _, err := net.SplitHostPort(seedAddrs[0]); err == nil && host != "" { + cfg.ServerName = host + } + tlsConfig = cfg + } + opts := &redis.ClusterOptions{ + Addrs: seedAddrs, + Username: strings.TrimSpace(attempt.User), + Password: attempt.Password, + DialTimeout: timeout, + ReadTimeout: timeout, + WriteTimeout: timeout, + TLSConfig: tlsConfig, + } + clusterClient := redis.NewClusterClient(opts) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + pingErr := clusterClient.Ping(ctx).Err() + cancel() + if pingErr != nil { + clusterClient.Close() + failures = append(failures, fmt.Sprintf("第%d次连接失败: %v", idx+1, pingErr)) + continue + } + r.client = clusterClient + r.clusterClient = clusterClient + r.config = attempt + if idx > 0 { + logger.Warnf("Redis 集群 SSL 优先连接失败,已回退至明文连接") + } + logger.Infof("Redis 集群连接成功: seeds=%s 逻辑库=db%d", strings.Join(seedAddrs, ","), r.currentDB) + return nil } - r.client = clusterClient - r.clusterClient = clusterClient - logger.Infof("Redis 集群连接成功: seeds=%s 逻辑库=db%d", strings.Join(seedAddrs, ","), r.currentDB) - return nil + return fmt.Errorf("Redis 集群连接失败: %s", strings.Join(failures, ";")) } addr := seedAddrs[0] @@ -233,29 +257,53 @@ func (r *RedisClientImpl) Connect(config connection.ConnectionConfig) error { logger.Infof("Redis 通过 SSH 隧道连接: %s -> %s:%d", addr, config.Host, config.Port) } - opts := &redis.Options{ - Addr: addr, - Username: strings.TrimSpace(config.User), - Password: config.Password, - DB: r.currentDB, - DialTimeout: timeout, - ReadTimeout: timeout, - WriteTimeout: timeout, + attempts := []connection.ConnectionConfig{config} + if shouldTryRedisSSLPreferredFallback(config) { + attempts = append(attempts, withRedisSSLDisabled(config)) } - singleClient := redis.NewClient(opts) - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() + var failures []string + for idx, attempt := range attempts { + var tlsConfig *tls.Config + if cfg := resolveRedisTLSConfig(attempt); cfg != nil { + if host, _, err := net.SplitHostPort(addr); err == nil && host != "" { + cfg.ServerName = host + } + tlsConfig = cfg + } - if err := singleClient.Ping(ctx).Err(); err != nil { - singleClient.Close() - return fmt.Errorf("Redis 连接失败: %w", err) + opts := &redis.Options{ + Addr: addr, + Username: strings.TrimSpace(attempt.User), + Password: attempt.Password, + DB: r.currentDB, + DialTimeout: timeout, + ReadTimeout: timeout, + WriteTimeout: timeout, + TLSConfig: tlsConfig, + } + + singleClient := redis.NewClient(opts) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + pingErr := singleClient.Ping(ctx).Err() + cancel() + if pingErr != nil { + singleClient.Close() + failures = append(failures, fmt.Sprintf("第%d次连接失败: %v", idx+1, pingErr)) + continue + } + + r.client = singleClient + r.singleClient = singleClient + r.config = attempt + if idx > 0 { + logger.Warnf("Redis SSL 优先连接失败,已回退至明文连接") + } + logger.Infof("Redis 连接成功: %s DB=%d", addr, r.currentDB) + return nil } - r.client = singleClient - r.singleClient = singleClient - logger.Infof("Redis 连接成功: %s DB=%d", addr, r.currentDB) - return nil + return fmt.Errorf("Redis 连接失败: %s", strings.Join(failures, ";")) } // Close closes the Redis connection diff --git a/internal/redis/ssl_mode.go b/internal/redis/ssl_mode.go new file mode 100644 index 0000000..11f3e9f --- /dev/null +++ b/internal/redis/ssl_mode.go @@ -0,0 +1,55 @@ +package redis + +import ( + "crypto/tls" + "strings" + + "GoNavi-Wails/internal/connection" +) + +func normalizeRedisSSLMode(raw string) string { + mode := strings.ToLower(strings.TrimSpace(raw)) + switch mode { + case "", "preferred", "prefer": + return "preferred" + case "required", "require", "on", "true", "mandatory", "strict": + return "required" + case "skip-verify", "insecure", "skipverify", "skip_verify", "insecure-skip-verify": + return "skip-verify" + case "disable", "disabled", "off", "false", "none": + return "disable" + default: + return "preferred" + } +} + +func redisSSLMode(config connection.ConnectionConfig) string { + if !config.UseSSL { + return "disable" + } + return normalizeRedisSSLMode(config.SSLMode) +} + +func shouldTryRedisSSLPreferredFallback(config connection.ConnectionConfig) bool { + return config.UseSSL && normalizeRedisSSLMode(config.SSLMode) == "preferred" +} + +func withRedisSSLDisabled(config connection.ConnectionConfig) connection.ConnectionConfig { + next := config + next.UseSSL = false + next.SSLMode = "disable" + return next +} + +func resolveRedisTLSConfig(config connection.ConnectionConfig) *tls.Config { + switch redisSSLMode(config) { + case "disable": + return nil + case "required": + return &tls.Config{MinVersion: tls.VersionTLS12} + case "skip-verify": + return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true} + default: + return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true} + } +}