From 04f8b266d370a9e9afaa17b06de50784c0f9ba1c Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 6 Mar 2026 13:57:11 +0800 Subject: [PATCH] =?UTF-8?q?=20=20-=20feat(connection,metadata,kingbase):?= =?UTF-8?q?=20=E5=A2=9E=E5=BC=BA=E5=A4=9A=E6=95=B0=E6=8D=AE=E6=BA=90?= =?UTF-8?q?=E8=BF=9E=E6=8E=A5=E8=83=BD=E5=8A=9B=E5=B9=B6=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E9=87=91=E4=BB=93/=E8=BE=BE=E6=A2=A6/Oracle/ClickHouse?= =?UTF-8?q?=E5=85=BC=E5=AE=B9=E6=80=A7=E9=97=AE=E9=A2=98=20(#188)=20(#190)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(http-tunnel): 支持独立 HTTP 隧道连接并覆盖多数据源 refs #168 * fix(kingbase-data-grid): 修复金仓打开表卡顿并降低对象渲染开销 refs #178 * fix(kingbase-transaction): 修复金仓事务提交重复引号导致语法错误 refs #176 * fix(driver-agent): 修复老版本 Win10 升级后金仓驱动代理启动失败 refs #177 * chore(ci): 新增手动触发的 macOS 测试构建工作流 * chore(ci): 允许测试工作流在当前分支自动触发 * fix(query-editor): 修复 SQL 编辑中光标随机跳到末尾 refs #185 * feat(data-sync): 增加差异 SQL 预览能力便于审核 refs #174 * fix(clickhouse-connect): 自动识别并回退 HTTP/Native 协议连接 refs #181 * fix(oracle-metadata): 修复视图与函数加载按 schema 过滤异常 refs #155 * fix(dameng-databases): 修复显示全部库时数据库列表不完整 refs #154 * fix(connection,db-list): 统一处理空列表返回并修复达梦连接测试报错 refs #157 Co-authored-by: 辣条 <69459608+tianqijiuyun-latiao@users.noreply.github.com> --- .github/workflows/test-macos-build.yml | 91 ++++++++ frontend/package.json.md5 | 2 +- frontend/src/components/ConnectionModal.tsx | 136 ++++++++++-- frontend/src/components/DataGrid.tsx | 31 ++- frontend/src/components/DataSyncModal.tsx | 231 +++++++++++++++++--- frontend/src/components/QueryEditor.tsx | 38 +++- frontend/src/components/Sidebar.tsx | 40 +++- frontend/src/store.ts | 16 +- frontend/src/types.ts | 9 + frontend/wailsjs/go/models.ts | 23 ++ internal/app/app.go | 11 + internal/app/db_proxy.go | 29 +++ internal/app/global_proxy.go | 2 +- internal/app/methods_db.go | 21 +- internal/app/methods_driver.go | 36 ++- internal/app/methods_redis.go | 47 +++- internal/connection/types.go | 72 +++--- internal/db/clickhouse_impl.go | 84 ++++++- internal/db/dameng_impl.go | 81 ++++++- internal/db/driver_agent_binary_check.go | 74 +++++++ internal/db/driver_support.go | 3 + internal/db/driver_support_test.go | 19 +- internal/db/kingbase_impl.go | 91 ++++++-- internal/db/kingbase_impl_test.go | 74 +++++++ internal/db/optional_driver_agent_impl.go | 29 +++ internal/db/query_value.go | 11 + internal/db/query_value_test.go | 30 +++ 27 files changed, 1162 insertions(+), 169 deletions(-) create mode 100644 .github/workflows/test-macos-build.yml create mode 100644 internal/db/driver_agent_binary_check.go create mode 100644 internal/db/kingbase_impl_test.go diff --git a/.github/workflows/test-macos-build.yml b/.github/workflows/test-macos-build.yml new file mode 100644 index 0000000..1dd01af --- /dev/null +++ b/.github/workflows/test-macos-build.yml @@ -0,0 +1,91 @@ +name: Test Build macOS (Manual) + +on: + workflow_dispatch: + inputs: + build_label: + description: "测试包标识(仅用于文件名)" + required: false + default: "test" + push: + branches: + - feature/kingbase_opt + paths: + - ".github/workflows/test-macos-build.yml" + +permissions: + contents: read + +jobs: + build-macos: + name: Build macOS ${{ matrix.arch }} + runs-on: macos-latest + strategy: + fail-fast: false + matrix: + include: + - platform: darwin/amd64 + arch: amd64 + - platform: darwin/arm64 + arch: arm64 + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: "1.24.3" + check-latest: true + + - name: Setup Node + uses: actions/setup-node@v4 + with: + node-version: "20" + + - name: Install Wails + run: go install github.com/wailsapp/wails/v2/cmd/wails@v2.11.0 + + - name: Build App + run: | + set -euo pipefail + OUTPUT_NAME="gonavi-test-${{ matrix.arch }}" + BUILD_LABEL="${{ inputs.build_label }}" + if [ -z "$BUILD_LABEL" ]; then + BUILD_LABEL="test" + fi + APP_VERSION="${BUILD_LABEL}-${GITHUB_RUN_NUMBER}" + wails build \ + -platform "${{ matrix.platform }}" \ + -clean \ + -o "$OUTPUT_NAME" \ + -ldflags "-s -w -X GoNavi-Wails/internal/app.AppVersion=${APP_VERSION}" + + - name: Package Zip + run: | + set -euo pipefail + APP_PATH="build/bin/gonavi-test-${{ matrix.arch }}.app" + if [ ! -d "$APP_PATH" ]; then + APP_PATH=$(find build/bin -maxdepth 1 -name "*.app" | head -n 1 || true) + fi + if [ -z "$APP_PATH" ] || [ ! -d "$APP_PATH" ]; then + echo "未找到 .app 产物" + ls -la build/bin || true + exit 1 + fi + LABEL="${{ inputs.build_label }}" + if [ -z "$LABEL" ]; then + LABEL="test" + fi + ZIP_NAME="GoNavi-${LABEL}-macos-${{ matrix.arch }}-run${GITHUB_RUN_NUMBER}.zip" + mkdir -p artifacts + ditto -c -k --sequesterRsrc --keepParent "$APP_PATH" "artifacts/$ZIP_NAME" + shasum -a 256 "artifacts/$ZIP_NAME" > "artifacts/$ZIP_NAME.sha256" + + - name: Upload Artifact + uses: actions/upload-artifact@v4 + with: + name: gonavi-macos-${{ matrix.arch }}-run${{ github.run_number }} + path: artifacts/* + if-no-files-found: error diff --git a/frontend/package.json.md5 b/frontend/package.json.md5 index 0f8f4fe..a7661c0 100755 --- a/frontend/package.json.md5 +++ b/frontend/package.json.md5 @@ -1 +1 @@ -5b8157374dae5f9340e31b2d0bd2c00e \ No newline at end of file +d0f9366af59a6367ad3c7e2d4185ead4 \ No newline at end of file diff --git a/frontend/src/components/ConnectionModal.tsx b/frontend/src/components/ConnectionModal.tsx index 45ef1a8..85aa4c6 100644 --- a/frontend/src/components/ConnectionModal.tsx +++ b/frontend/src/components/ConnectionModal.tsx @@ -101,6 +101,7 @@ const ConnectionModal: React.FC<{ const [useSSL, setUseSSL] = useState(false); const [useSSH, setUseSSH] = useState(false); const [useProxy, setUseProxy] = useState(false); + const [useHttpTunnel, setUseHttpTunnel] = useState(false); const [dbType, setDbType] = useState('mysql'); const [step, setStep] = useState(1); // 1: Select Type, 2: Configure const [activeGroup, setActiveGroup] = useState(0); // Active category index in step 1 @@ -1026,6 +1027,8 @@ const ConnectionModal: React.FC<{ const mysqlIsReplica = String(config.topology || '').toLowerCase() === 'replica' || mysqlReplicaHosts.length > 0; const mongoIsReplica = String(config.topology || '').toLowerCase() === 'replica' || mongoHosts.length > 0 || !!config.replicaSet; const redisIsCluster = String(config.topology || '').toLowerCase() === 'cluster' || redisHosts.length > 0; + const hasHttpTunnel = !!config.useHttpTunnel; + const hasProxy = !hasHttpTunnel && !!config.useProxy; form.setFieldsValue({ type: configType, name: initialValues.name, @@ -1047,12 +1050,17 @@ const ConnectionModal: React.FC<{ sshUser: config.ssh?.user, sshPassword: config.ssh?.password, sshKeyPath: config.ssh?.keyPath, - useProxy: config.useProxy, + useProxy: hasProxy, proxyType: config.proxy?.type || 'socks5', proxyHost: config.proxy?.host, proxyPort: config.proxy?.port, proxyUser: config.proxy?.user, proxyPassword: config.proxy?.password, + useHttpTunnel: hasHttpTunnel, + httpTunnelHost: config.httpTunnel?.host, + httpTunnelPort: config.httpTunnel?.port || 8080, + httpTunnelUser: config.httpTunnel?.user, + httpTunnelPassword: config.httpTunnel?.password, driver: config.driver, dsn: config.dsn, timeout: config.timeout || 30, @@ -1076,7 +1084,8 @@ const ConnectionModal: React.FC<{ }); setUseSSL(!!config.useSSL); setUseSSH(config.useSSH || false); - setUseProxy(config.useProxy || false); + setUseProxy(hasProxy); + setUseHttpTunnel(hasHttpTunnel); setDbType(configType); // 如果是 Redis 编辑模式,设置已保存的 Redis 数据库列表 if (configType === 'redis') { @@ -1089,6 +1098,7 @@ const ConnectionModal: React.FC<{ setUseSSL(false); setUseSSH(false); setUseProxy(false); + setUseHttpTunnel(false); setDbType('mysql'); setActiveGroup(0); } @@ -1140,6 +1150,7 @@ const ConnectionModal: React.FC<{ setUseSSL(false); setUseSSH(false); setUseProxy(false); + setUseHttpTunnel(false); setDbType('mysql'); setStep(1); onClose(); @@ -1185,19 +1196,24 @@ const ConnectionModal: React.FC<{ ? await RedisConnect(config as any) : await TestConnection(config as any); - if (res.success) { - setTestResult({ type: 'success', message: res.message }); - if (isRedisType) { - setRedisDbList(Array.from({ length: 16 }, (_, i) => i)); - } else { - // Other databases: fetch database list - const dbRes = await DBGetDatabases(config as any); - if (dbRes.success) { - const dbs = (dbRes.data as any[]).map((row: any) => row.Database || row.database); - setDbList(dbs); - } - } - } else { + if (res.success) { + setTestResult({ type: 'success', message: res.message }); + if (isRedisType) { + setRedisDbList(Array.from({ length: 16 }, (_, i) => i)); + } else { + // Other databases: fetch database list + const dbRes = await DBGetDatabases(config as any); + if (dbRes.success) { + const dbRows = Array.isArray(dbRes.data) ? dbRes.data : []; + const dbs = dbRows + .map((row: any) => row?.Database || row?.database) + .filter((name: any) => typeof name === 'string' && name.trim() !== ''); + setDbList(dbs); + } else { + setDbList([]); + } + } + } else { const failMessage = buildTestFailureMessage( res?.message, '连接被拒绝或参数无效,请检查后重试' @@ -1388,7 +1404,8 @@ const ConnectionModal: React.FC<{ password: mergedValues.sshPassword || "", keyPath: mergedValues.sshKeyPath || "" } : { host: "", port: 22, user: "", password: "", keyPath: "" }; - const effectiveUseProxy = !isFileDbType && !!mergedValues.useProxy; + const effectiveUseHttpTunnel = !isFileDbType && !!mergedValues.useHttpTunnel; + const effectiveUseProxy = !isFileDbType && !!mergedValues.useProxy && !effectiveUseHttpTunnel; const proxyTypeRaw = String(mergedValues.proxyType || 'socks5').toLowerCase(); const proxyType: 'socks5' | 'http' = proxyTypeRaw === 'http' ? 'http' : 'socks5'; const proxyConfig: NonNullable = effectiveUseProxy ? { @@ -1404,6 +1421,25 @@ const ConnectionModal: React.FC<{ user: '', password: '', }; + const httpTunnelConfig: NonNullable = effectiveUseHttpTunnel ? { + host: String(mergedValues.httpTunnelHost || '').trim(), + port: Number(mergedValues.httpTunnelPort || 8080), + user: String(mergedValues.httpTunnelUser || '').trim(), + password: mergedValues.httpTunnelPassword || "", + } : { + host: '', + port: 8080, + user: '', + password: '', + }; + if (effectiveUseHttpTunnel) { + if (!httpTunnelConfig.host) { + throw new Error('HTTP 隧道主机不能为空'); + } + if (!Number.isFinite(httpTunnelConfig.port) || httpTunnelConfig.port <= 0 || httpTunnelConfig.port > 65535) { + throw new Error('HTTP 隧道端口必须在 1-65535 之间'); + } + } const keepPassword = !forPersist || savePassword; @@ -1423,6 +1459,8 @@ const ConnectionModal: React.FC<{ ssh: sshConfig, useProxy: effectiveUseProxy, proxy: proxyConfig, + useHttpTunnel: effectiveUseHttpTunnel, + httpTunnel: httpTunnelConfig, driver: mergedValues.driver, dsn: mergedValues.dsn, timeout: Number(mergedValues.timeout || 30), @@ -1461,6 +1499,7 @@ const ConnectionModal: React.FC<{ setUseSSL(false); setUseSSH(false); setUseProxy(false); + setUseHttpTunnel(false); form.setFieldsValue({ host: '', port: 0, @@ -1483,6 +1522,11 @@ const ConnectionModal: React.FC<{ proxyPort: 1080, proxyUser: '', proxyPassword: '', + useHttpTunnel: false, + httpTunnelHost: '', + httpTunnelPort: 8080, + httpTunnelUser: '', + httpTunnelPassword: '', mysqlTopology: 'single', redisTopology: 'single', mongoTopology: 'single', @@ -1505,6 +1549,7 @@ const ConnectionModal: React.FC<{ const defaultUser = type === 'clickhouse' ? 'default' : 'root'; const sslCapableType = supportsSSLForType(type); setUseSSL(false); + setUseHttpTunnel(false); form.setFieldsValue({ user: defaultUser, database: '', @@ -1513,6 +1558,11 @@ const ConnectionModal: React.FC<{ sslMode: sslCapableType ? 'preferred' : undefined, sslCertPath: sslCapableType ? '' : undefined, sslKeyPath: sslCapableType ? '' : undefined, + useHttpTunnel: false, + httpTunnelHost: '', + httpTunnelPort: 8080, + httpTunnelUser: '', + httpTunnelPassword: '', mysqlTopology: 'single', redisTopology: 'single', mongoTopology: 'single', @@ -1665,6 +1715,8 @@ const ConnectionModal: React.FC<{ useProxy: false, proxyType: 'socks5', proxyPort: 1080, + useHttpTunnel: false, + httpTunnelPort: 8080, timeout: 30, uri: '', mysqlTopology: 'single', @@ -1693,7 +1745,14 @@ const ConnectionModal: React.FC<{ } if (changed.useSSL !== undefined) setUseSSL(changed.useSSL); if (changed.useSSH !== undefined) setUseSSH(changed.useSSH); - if (changed.useProxy !== undefined) setUseProxy(changed.useProxy); + if (changed.useProxy !== undefined) { + const enabledProxy = !!changed.useProxy; + setUseProxy(enabledProxy); + if (enabledProxy && form.getFieldValue('useHttpTunnel')) { + form.setFieldValue('useHttpTunnel', false); + setUseHttpTunnel(false); + } + } if (changed.proxyType !== undefined) { const nextType = String(changed.proxyType || 'socks5').toLowerCase(); if (nextType === 'http') { @@ -1708,6 +1767,20 @@ const ConnectionModal: React.FC<{ } } } + if (changed.useHttpTunnel !== undefined) { + const enabledHttpTunnel = !!changed.useHttpTunnel; + setUseHttpTunnel(enabledHttpTunnel); + if (enabledHttpTunnel && form.getFieldValue('useProxy')) { + form.setFieldValue('useProxy', false); + setUseProxy(false); + } + if (enabledHttpTunnel) { + const currentPort = Number(form.getFieldValue('httpTunnelPort') || 0); + if (!currentPort || currentPort <= 0) { + form.setFieldValue('httpTunnelPort', 8080); + } + } + } // Type change handled by step 1, but keep sync if select changes (hidden now) if (changed.type !== undefined) setDbType(changed.type); if (changed.redisTopology !== undefined) { @@ -2194,6 +2267,35 @@ const ConnectionModal: React.FC<{ )} + + + 使用 HTTP 隧道(独立代理) + + + {useHttpTunnel && ( +
+
+ + + + + + +
+
+ + + + + + +
+ + 与“使用代理”互斥,启用后将通过 HTTP CONNECT 建立独立隧道。 + +
+ )} + { try { if (val === null) return NULL; if (typeof val === 'object') { + if (!Array.isArray(val) && !isPlainObject(val)) { + return String(val); + } const cached = objectCellPreviewCache.get(val); if (cached !== undefined) { return cached; } + const topLevelSize = Array.isArray(val) ? val.length : Object.keys(val || {}).length; + if (topLevelSize > 80) { + const summary = Array.isArray(val) ? `[Array(${topLevelSize})]` : `{Object(${topLevelSize})}`; + objectCellPreviewCache.set(val, summary); + return summary; + } try { const nextText = JSON.stringify(val); const previewText = nextText.length > TABLE_CELL_PREVIEW_MAX_CHARS ? `${nextText.slice(0, TABLE_CELL_PREVIEW_MAX_CHARS)}…` : nextText; @@ -191,6 +200,26 @@ const isCellValueEqualForDiff = (left: any, right: any): boolean => { return toFormText(left) === toFormText(right); }; +// 渲染阶段轻量比较:避免对象值在 shouldCellUpdate 中反复深度序列化导致卡顿。 +const isCellValueEqualForRender = (left: any, right: any): boolean => { + if (left === right) return true; + const leftNullish = left === null || left === undefined; + const rightNullish = right === null || right === undefined; + if (leftNullish || rightNullish) return leftNullish && rightNullish; + + const leftType = typeof left; + const rightType = typeof right; + if (leftType === 'object' || rightType === 'object') { + // 对象仅按引用比较;真正的值差异在提交保存时再做严格比对。 + return false; + } + + if (leftType === 'string' || rightType === 'string') { + return normalizeDateTimeString(String(left)) === normalizeDateTimeString(String(right)); + } + return left === right; +}; + const INLINE_EDIT_MAX_CHARS = 2000; const shouldOpenModalEditor = (val: any): boolean => { @@ -2067,7 +2096,7 @@ const DataGrid: React.FC = ({ shouldCellUpdate: (record: Item, prevRecord: Item) => { const rowKeyChanged = record?.[GONAVI_ROW_KEY] !== prevRecord?.[GONAVI_ROW_KEY]; if (rowKeyChanged) return true; - return !isCellValueEqualForDiff(record?.[key], prevRecord?.[key]); + return !isCellValueEqualForRender(record?.[key], prevRecord?.[key]); }, onHeaderCell: (column: any) => ({ width: column.width, diff --git a/frontend/src/components/DataSyncModal.tsx b/frontend/src/components/DataSyncModal.tsx index 769885a..57c4033 100644 --- a/frontend/src/components/DataSyncModal.tsx +++ b/frontend/src/components/DataSyncModal.tsx @@ -1,4 +1,4 @@ -import React, { useState, useEffect, useRef } from 'react'; +import React, { useState, useEffect, useMemo, useRef } from 'react'; import { Modal, Form, Select, Button, message, Steps, Transfer, Card, Alert, Divider, Typography, Progress, Checkbox, Table, Drawer, Tabs } from 'antd'; import { useStore } from '../store'; import { DBGetDatabases, DBGetTables, DataSync, DataSyncAnalyze, DataSyncPreview } from '../../wailsjs/go/app/App'; @@ -31,6 +31,118 @@ type TableOps = { selectedDeletePks?: string[]; }; +const quoteSqlIdent = (dbType: string, ident: string): string => { + const raw = String(ident || '').trim(); + if (!raw) return raw; + const t = String(dbType || '').toLowerCase(); + if (t === 'mysql' || t === 'mariadb' || t === 'diros' || t === 'sphinx' || t === 'clickhouse' || t === 'tdengine') { + return `\`${raw.replace(/`/g, '``')}\``; + } + if (t === 'sqlserver') { + return `[${raw.replace(/]/g, ']]')}]`; + } + return `"${raw.replace(/"/g, '""')}"`; +}; + +const quoteSqlTable = (dbType: string, tableName: string): string => { + const raw = String(tableName || '').trim(); + if (!raw) return raw; + if (!raw.includes('.')) return quoteSqlIdent(dbType, raw); + return raw + .split('.') + .map((part) => quoteSqlIdent(dbType, part)) + .join('.'); +}; + +const toSqlLiteral = (value: any, dbType: string): string => { + if (value === null || value === undefined) return 'NULL'; + if (typeof value === 'number') return Number.isFinite(value) ? String(value) : 'NULL'; + if (typeof value === 'bigint') return value.toString(); + if (typeof value === 'boolean') { + const t = String(dbType || '').toLowerCase(); + if (t === 'sqlserver') return value ? '1' : '0'; + return value ? 'TRUE' : 'FALSE'; + } + if (value instanceof Date) { + return `'${value.toISOString().replace(/'/g, "''")}'`; + } + if (typeof value === 'object') { + try { + return `'${JSON.stringify(value).replace(/'/g, "''")}'`; + } catch { + return `'${String(value).replace(/'/g, "''")}'`; + } + } + return `'${String(value).replace(/'/g, "''")}'`; +}; + +const buildSqlPreview = ( + previewData: any, + tableName: string, + dbType: string, + ops?: TableOps, +): { sqlText: string; statementCount: number } => { + if (!previewData || !tableName) return { sqlText: '', statementCount: 0 }; + const tableExpr = quoteSqlTable(dbType, tableName); + const pkCol = String(previewData.pkColumn || 'id'); + const statements: string[] = []; + + const insertRows = Array.isArray(previewData.inserts) ? previewData.inserts : []; + const updateRows = Array.isArray(previewData.updates) ? previewData.updates : []; + const deleteRows = Array.isArray(previewData.deletes) ? previewData.deletes : []; + + const selectedInsert = new Set((ops?.selectedInsertPks || []).map((v) => String(v))); + const selectedUpdate = new Set((ops?.selectedUpdatePks || []).map((v) => String(v))); + const selectedDelete = new Set((ops?.selectedDeletePks || []).map((v) => String(v))); + + if (ops?.insert !== false) { + insertRows.forEach((rowWrap: any) => { + const pk = String(rowWrap?.pk ?? ''); + if (selectedInsert.size > 0 && !selectedInsert.has(pk)) return; + const row = rowWrap?.row || {}; + const columns = Object.keys(row); + if (columns.length === 0) return; + const colExpr = columns.map((c) => quoteSqlIdent(dbType, c)).join(', '); + const valExpr = columns.map((c) => toSqlLiteral(row[c], dbType)).join(', '); + statements.push(`INSERT INTO ${tableExpr} (${colExpr}) VALUES (${valExpr});`); + }); + } + + if (ops?.update !== false) { + updateRows.forEach((rowWrap: any) => { + const pk = String(rowWrap?.pk ?? ''); + if (selectedUpdate.size > 0 && !selectedUpdate.has(pk)) return; + const source = rowWrap?.source || {}; + const changedColumns = Array.isArray(rowWrap?.changedColumns) + ? rowWrap.changedColumns + : Object.keys(source).filter((k) => k !== pkCol); + const setCols = changedColumns.filter((c: string) => String(c) !== pkCol); + if (setCols.length === 0) return; + const setExpr = setCols + .map((c: string) => `${quoteSqlIdent(dbType, c)} = ${toSqlLiteral(source[c], dbType)}`) + .join(', '); + statements.push( + `UPDATE ${tableExpr} SET ${setExpr} WHERE ${quoteSqlIdent(dbType, pkCol)} = ${toSqlLiteral(pk, dbType)};`, + ); + }); + } + + if (ops?.delete) { + deleteRows.forEach((rowWrap: any) => { + const pk = String(rowWrap?.pk ?? ''); + if (selectedDelete.size > 0 && !selectedDelete.has(pk)) return; + statements.push( + `DELETE FROM ${tableExpr} WHERE ${quoteSqlIdent(dbType, pkCol)} = ${toSqlLiteral(pk, dbType)};`, + ); + }); + } + + return { + sqlText: statements.join('\n'), + statementCount: statements.length, + }; +}; + const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open, onClose }) => { const connections = useStore((state) => state.connections); const [currentStep, setCurrentStep] = useState(0); @@ -152,32 +264,38 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open, setSourceConnId(connId); setSourceDb(''); const conn = connections.find(c => c.id === connId); - if (conn) { - setLoading(true); - try { - const res = await DBGetDatabases(normalizeConnConfig(conn) as any); - if (res.success) { - setSourceDbs((res.data as any[]).map((r: any) => r.Database || r.database || r.username)); - } - } catch(e) { message.error("Failed to fetch source databases"); } - setLoading(false); - } + if (conn) { + setLoading(true); + try { + const res = await DBGetDatabases(normalizeConnConfig(conn) as any); + if (res.success) { + const dbRows = Array.isArray(res.data) ? res.data : []; + setSourceDbs(dbRows + .map((r: any) => r?.Database || r?.database || r?.username) + .filter((name: any) => typeof name === 'string' && name.trim() !== '')); + } + } catch(e) { message.error("Failed to fetch source databases"); } + setLoading(false); + } }; const handleTargetConnChange = async (connId: string) => { setTargetConnId(connId); setTargetDb(''); const conn = connections.find(c => c.id === connId); - if (conn) { - setLoading(true); - try { - const res = await DBGetDatabases(normalizeConnConfig(conn) as any); - if (res.success) { - setTargetDbs((res.data as any[]).map((r: any) => r.Database || r.database || r.username)); - } - } catch(e) { message.error("Failed to fetch target databases"); } - setLoading(false); - } + if (conn) { + setLoading(true); + try { + const res = await DBGetDatabases(normalizeConnConfig(conn) as any); + if (res.success) { + const dbRows = Array.isArray(res.data) ? res.data : []; + setTargetDbs(dbRows + .map((r: any) => r?.Database || r?.database || r?.username) + .filter((name: any) => typeof name === 'string' && name.trim() !== '')); + } + } catch(e) { message.error("Failed to fetch target databases"); } + setLoading(false); + } }; const nextToTables = async () => { @@ -189,14 +307,17 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open, try { const conn = connections.find(c => c.id === sourceConnId); if (conn) { - const config = normalizeConnConfig(conn, sourceDb); - const res = await DBGetTables(config as any, sourceDb); - if (res.success) { - // DBGetTables returns [{Table: "name"}, ...] - const tables = (res.data as any[]).map((row: any) => row.Table || row.table || row.TABLE_NAME || Object.values(row)[0]); - setAllTables(tables as string[]); - setCurrentStep(1); - } else { + const config = normalizeConnConfig(conn, sourceDb); + const res = await DBGetTables(config as any, sourceDb); + if (res.success) { + // DBGetTables returns [{Table: "name"}, ...] + const tableRows = Array.isArray(res.data) ? res.data : []; + const tables = tableRows + .map((row: any) => row?.Table || row?.table || row?.TABLE_NAME || Object.values(row || {})[0]) + .filter((name: any) => typeof name === 'string' && name.trim() !== ''); + setAllTables(tables as string[]); + setCurrentStep(1); + } else { message.error(res.message); } } @@ -402,6 +523,13 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open, ); }; + const previewSql = useMemo(() => { + if (!previewData || !previewTable) return { sqlText: '', statementCount: 0 }; + const targetType = String(connections.find(c => c.id === targetConnId)?.config?.type || ''); + const ops = tableOptions[previewTable] || { insert: true, update: true, delete: false }; + return buildSqlPreview(previewData, previewTable, targetType, ops); + }, [previewData, previewTable, targetConnId, connections, tableOptions]); + return ( <> void }> = ({ open, /> ) + }, + { + key: 'sql', + label: `SQL(${previewSql.statementCount})`, + children: ( +
+ +
+ 共 {previewSql.statementCount} 条语句(预览数据最多 200 条/类型) + +
+
+                                        {previewSql.sqlText || '-- 当前勾选范围下无 SQL 可预览'}
+                                    
+
+ ) } ]} /> diff --git a/frontend/src/components/QueryEditor.tsx b/frontend/src/components/QueryEditor.tsx index 2e66344..69294d1 100644 --- a/frontend/src/components/QueryEditor.tsx +++ b/frontend/src/components/QueryEditor.tsx @@ -48,6 +48,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => { const [editorHeight, setEditorHeight] = useState(300); const editorRef = useRef(null); const monacoRef = useRef(null); + const lastExternalQueryRef = useRef(tab.query || ''); const dragRef = useRef<{ startY: number, startHeight: number } | null>(null); const tablesRef = useRef<{dbName: string, tableName: string}[]>([]); // Store tables for autocomplete (cross-db) const allColumnsRef = useRef<{dbName: string, tableName: string, name: string, type: string}[]>([]); // Store all columns (cross-db) @@ -95,10 +96,30 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => { connectionsRef.current = connections; }, [connections]); + const getCurrentQuery = () => { + const val = editorRef.current?.getValue?.(); + if (typeof val === 'string') return val; + return query || ''; + }; + + const syncQueryToEditor = (sql: string) => { + const next = sql || ''; + setQuery(next); + const editor = editorRef.current; + if (editor && editor.getValue?.() !== next) { + editor.setValue(next); + } + }; + // If opening a saved query, load its SQL useEffect(() => { - if (tab.query) setQuery(tab.query); - }, [tab.query]); + const incoming = tab.query || ''; + if (incoming === lastExternalQueryRef.current) { + return; + } + lastExternalQueryRef.current = incoming; + syncQueryToEditor(incoming || 'SELECT * FROM '); + }, [tab.id, tab.query]); // Fetch Database List useEffect(() => { @@ -557,8 +578,8 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => { const handleFormat = () => { try { - const formatted = format(query, { language: 'mysql', keywordCase: sqlFormatOptions.keywordCase }); - setQuery(formatted); + const formatted = format(getCurrentQuery(), { language: 'mysql', keywordCase: sqlFormatOptions.keywordCase }); + syncQueryToEditor(formatted); } catch (e) { message.error("格式化失败: SQL 语法可能有误"); } @@ -1045,7 +1066,8 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => { }; const handleRun = async () => { - if (!query.trim()) return; + const currentQuery = getCurrentQuery(); + if (!currentQuery.trim()) return; if (!currentDb) { message.error("请先选择数据库"); return; @@ -1086,7 +1108,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => { }; try { - const rawSQL = getSelectedSQL() || query; + const rawSQL = getSelectedSQL() || currentQuery; const dbType = String((config as any).type || 'mysql'); const normalizedDbType = dbType.trim().toLowerCase(); const normalizedRawSQL = String(rawSQL || '').replace(/;/g, ';'); @@ -1367,7 +1389,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => { saveQuery({ id: tab.id.startsWith('saved-') ? tab.id : `saved-${Date.now()}`, name: values.name, - sql: query, + sql: getCurrentQuery(), connectionId: currentConnectionId, dbName: currentDb || tab.dbName || '', createdAt: Date.now() @@ -1512,7 +1534,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => { height="100%" defaultLanguage="sql" theme={darkMode ? "transparent-dark" : "transparent-light"} - value={query} + defaultValue={query} onChange={(val) => setQuery(val || '')} onMount={handleEditorDidMount} options={{ diff --git a/frontend/src/components/Sidebar.tsx b/frontend/src/components/Sidebar.tsx index 6420955..2fa3fdd 100644 --- a/frontend/src/components/Sidebar.tsx +++ b/frontend/src/components/Sidebar.tsx @@ -382,6 +382,16 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> password: readString(rawProxy.password, rawProxy.Password, cloned.proxyPassword, cloned.ProxyPassword), }; const hasProxyDetail = Boolean(normalizedProxy.host || normalizedProxy.user || normalizedProxy.password); + const rawHttpTunnel = (cloned.httpTunnel ?? cloned.HTTPTunnel ?? {}) as Record; + const normalizedHttpTunnel = { + host: readString(rawHttpTunnel.host, rawHttpTunnel.Host, cloned.httpTunnelHost, cloned.HttpTunnelHost), + port: readNumber(8080, rawHttpTunnel.port, rawHttpTunnel.Port, cloned.httpTunnelPort, cloned.HttpTunnelPort), + user: readString(rawHttpTunnel.user, rawHttpTunnel.User, cloned.httpTunnelUser, cloned.HttpTunnelUser), + password: readString(rawHttpTunnel.password, rawHttpTunnel.Password, cloned.httpTunnelPassword, cloned.HttpTunnelPassword), + }; + const hasHttpTunnelDetail = Boolean(normalizedHttpTunnel.host || normalizedHttpTunnel.user || normalizedHttpTunnel.password); + const normalizedUseHttpTunnel = readBool(hasHttpTunnelDetail, cloned.useHttpTunnel, cloned.UseHTTPTunnel); + const normalizedUseProxy = !normalizedUseHttpTunnel && readBool(hasProxyDetail, cloned.useProxy, cloned.UseProxy); const rawHosts = Array.isArray(cloned.hosts) ? cloned.hosts @@ -394,8 +404,10 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> ...(cloned as SavedConnection['config']), useSSH: readBool(hasSSHDetail, cloned.useSSH, cloned.UseSSH), ssh: normalizedSSH, - useProxy: readBool(hasProxyDetail, cloned.useProxy, cloned.UseProxy), + useProxy: normalizedUseProxy, proxy: normalizedProxy, + useHttpTunnel: normalizedUseHttpTunnel, + httpTunnel: normalizedHttpTunnel, hosts: normalizedHosts, timeout: readNumber(30, cloned.timeout, cloned.Timeout), }; @@ -645,10 +657,15 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> } case 'oracle': case 'dm': - if (!safeDbName) { - return [{ sql: `SELECT VIEW_NAME AS view_name FROM USER_VIEWS ORDER BY VIEW_NAME` }]; - } - return [{ sql: `SELECT OWNER AS schema_name, VIEW_NAME AS view_name FROM ALL_VIEWS WHERE OWNER = '${safeDbName.toUpperCase()}' ORDER BY VIEW_NAME` }]; + return normalizeMetadataQuerySpecs([ + { sql: `SELECT VIEW_NAME AS view_name FROM USER_VIEWS ORDER BY VIEW_NAME` }, + { sql: `SELECT OWNER AS schema_name, VIEW_NAME AS view_name FROM ALL_VIEWS WHERE OWNER = USER ORDER BY VIEW_NAME` }, + { + sql: safeDbName + ? `SELECT OWNER AS schema_name, VIEW_NAME AS view_name FROM ALL_VIEWS WHERE OWNER = '${safeDbName.toUpperCase()}' ORDER BY VIEW_NAME` + : '', + }, + ]); case 'sqlite': return [{ sql: `SELECT name AS view_name FROM sqlite_master WHERE type = 'view' ORDER BY name` }]; case 'duckdb': @@ -731,10 +748,15 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> } case 'oracle': case 'dm': - if (!safeDbName) { - return [{ sql: `SELECT OBJECT_NAME AS routine_name, OBJECT_TYPE AS routine_type FROM USER_OBJECTS WHERE OBJECT_TYPE IN ('FUNCTION','PROCEDURE') ORDER BY OBJECT_TYPE, OBJECT_NAME` }]; - } - return [{ sql: `SELECT OWNER AS schema_name, OBJECT_NAME AS routine_name, OBJECT_TYPE AS routine_type FROM ALL_OBJECTS WHERE OWNER = '${safeDbName.toUpperCase()}' AND OBJECT_TYPE IN ('FUNCTION','PROCEDURE') ORDER BY OBJECT_TYPE, OBJECT_NAME` }]; + return normalizeMetadataQuerySpecs([ + { sql: `SELECT OBJECT_NAME AS routine_name, OBJECT_TYPE AS routine_type FROM USER_OBJECTS WHERE OBJECT_TYPE IN ('FUNCTION','PROCEDURE') ORDER BY OBJECT_TYPE, OBJECT_NAME` }, + { sql: `SELECT OWNER AS schema_name, OBJECT_NAME AS routine_name, OBJECT_TYPE AS routine_type FROM ALL_OBJECTS WHERE OWNER = USER AND OBJECT_TYPE IN ('FUNCTION','PROCEDURE') ORDER BY OBJECT_TYPE, OBJECT_NAME` }, + { + sql: safeDbName + ? `SELECT OWNER AS schema_name, OBJECT_NAME AS routine_name, OBJECT_TYPE AS routine_type FROM ALL_OBJECTS WHERE OWNER = '${safeDbName.toUpperCase()}' AND OBJECT_TYPE IN ('FUNCTION','PROCEDURE') ORDER BY OBJECT_TYPE, OBJECT_NAME` + : '', + }, + ]); case 'duckdb': return [{ sql: `SELECT schema_name, function_name AS routine_name, 'FUNCTION' AS routine_type FROM duckdb_functions() WHERE internal = false AND lower(function_type) = 'macro' AND COALESCE(macro_definition, '') <> '' ORDER BY schema_name, function_name`, diff --git a/frontend/src/store.ts b/frontend/src/store.ts index 42f3fb6..e3b44f5 100644 --- a/frontend/src/store.ts +++ b/frontend/src/store.ts @@ -231,6 +231,18 @@ const sanitizeConnectionConfig = (value: unknown): ConnectionConfig => { user: toTrimmedString(proxyRaw.user), password: toTrimmedString(proxyRaw.password), }; + const httpTunnelRaw = (raw.httpTunnel && typeof raw.httpTunnel === 'object') + ? raw.httpTunnel as Record + : ((raw.HTTPTunnel && typeof raw.HTTPTunnel === 'object') ? raw.HTTPTunnel as Record : {}); + const httpTunnel = { + host: toTrimmedString(httpTunnelRaw.host ?? raw.httpTunnelHost), + port: normalizePort(httpTunnelRaw.port ?? raw.httpTunnelPort, 8080), + user: toTrimmedString(httpTunnelRaw.user ?? raw.httpTunnelUser), + password: toTrimmedString(httpTunnelRaw.password ?? raw.httpTunnelPassword), + }; + const supportsNetworkTunnel = type !== 'sqlite' && type !== 'duckdb'; + const useHttpTunnel = supportsNetworkTunnel && (raw.useHttpTunnel === true || raw.UseHTTPTunnel === true); + const useProxy = supportsNetworkTunnel && !!raw.useProxy && !useHttpTunnel; const safeConfig: ConnectionConfig & Record = { ...raw, @@ -247,8 +259,10 @@ const sanitizeConnectionConfig = (value: unknown): ConnectionConfig => { sslKeyPath: sslCapable ? toTrimmedString(raw.sslKeyPath) : '', useSSH: !!raw.useSSH, ssh, - useProxy: !!raw.useProxy, + useProxy, proxy, + useHttpTunnel, + httpTunnel, uri: toTrimmedString(raw.uri).slice(0, MAX_URI_LENGTH), hosts: sanitizeAddressList(raw.hosts), topology: raw.topology === 'replica' ? 'replica' : (raw.topology === 'cluster' ? 'cluster' : 'single'), diff --git a/frontend/src/types.ts b/frontend/src/types.ts index 501a854..96ac6da 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -14,6 +14,13 @@ export interface ProxyConfig { password?: string; } +export interface HTTPTunnelConfig { + host: string; + port: number; + user?: string; + password?: string; +} + export interface ConnectionConfig { type: string; host: string; @@ -30,6 +37,8 @@ export interface ConnectionConfig { ssh?: SSHConfig; useProxy?: boolean; proxy?: ProxyConfig; + useHttpTunnel?: boolean; + httpTunnel?: HTTPTunnelConfig; driver?: string; dsn?: string; timeout?: number; diff --git a/frontend/wailsjs/go/models.ts b/frontend/wailsjs/go/models.ts index bca7b39..2de678a 100755 --- a/frontend/wailsjs/go/models.ts +++ b/frontend/wailsjs/go/models.ts @@ -48,6 +48,24 @@ export namespace connection { return a; } } + export class HTTPTunnelConfig { + host: string; + port: number; + user?: string; + password?: string; + + static createFrom(source: any = {}) { + return new HTTPTunnelConfig(source); + } + + constructor(source: any = {}) { + if ('string' === typeof source) source = JSON.parse(source); + this.host = source["host"]; + this.port = source["port"]; + this.user = source["user"]; + this.password = source["password"]; + } + } export class ProxyConfig { type: string; host: string; @@ -104,6 +122,8 @@ export namespace connection { ssh: SSHConfig; useProxy?: boolean; proxy?: ProxyConfig; + useHttpTunnel?: boolean; + httpTunnel?: HTTPTunnelConfig; driver?: string; dsn?: string; timeout?: number; @@ -142,6 +162,8 @@ export namespace connection { this.ssh = this.convertValues(source["ssh"], SSHConfig); this.useProxy = source["useProxy"]; this.proxy = this.convertValues(source["proxy"], ProxyConfig); + this.useHttpTunnel = source["useHttpTunnel"]; + this.httpTunnel = this.convertValues(source["httpTunnel"], HTTPTunnelConfig); this.driver = source["driver"]; this.dsn = source["dsn"]; this.timeout = source["timeout"]; @@ -179,6 +201,7 @@ export namespace connection { } } + export class QueryResult { success: boolean; message: string; diff --git a/internal/app/app.go b/internal/app/app.go index 789f7be..0709a27 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -96,6 +96,9 @@ func normalizeCacheKeyConfig(config connection.ConnectionConfig) connection.Conn if !normalized.UseProxy { normalized.Proxy = connection.ProxyConfig{} } + if !normalized.UseHTTPTunnel { + normalized.HTTPTunnel = connection.HTTPTunnelConfig{} + } if isFileDatabaseType(normalized.Type) { dsn := strings.TrimSpace(normalized.Host) @@ -124,6 +127,8 @@ func normalizeCacheKeyConfig(config connection.ConnectionConfig) connection.Conn normalized.MongoAuthMechanism = "" normalized.MongoReplicaUser = "" normalized.MongoReplicaPassword = "" + normalized.UseHTTPTunnel = false + normalized.HTTPTunnel = connection.HTTPTunnelConfig{} } return normalized @@ -303,6 +308,12 @@ func formatConnSummary(config connection.ConnectionConfig) string { b.WriteString(" 代理认证=已配置") } } + if config.UseHTTPTunnel { + b.WriteString(fmt.Sprintf(" HTTP隧道=%s:%d", strings.TrimSpace(config.HTTPTunnel.Host), config.HTTPTunnel.Port)) + if strings.TrimSpace(config.HTTPTunnel.User) != "" { + b.WriteString(" HTTP隧道认证=已配置") + } + } if config.Type == "custom" { driver := strings.TrimSpace(config.Driver) diff --git a/internal/app/db_proxy.go b/internal/app/db_proxy.go index bdf2311..e3228b6 100644 --- a/internal/app/db_proxy.go +++ b/internal/app/db_proxy.go @@ -12,8 +12,35 @@ import ( func resolveDialConfigWithProxy(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) { config := raw + if config.UseHTTPTunnel { + if config.UseProxy { + return connection.ConnectionConfig{}, fmt.Errorf("HTTP 隧道与普通代理不能同时启用") + } + tunnelHost := strings.TrimSpace(config.HTTPTunnel.Host) + if tunnelHost == "" { + return connection.ConnectionConfig{}, fmt.Errorf("HTTP 隧道主机不能为空") + } + tunnelPort := config.HTTPTunnel.Port + if tunnelPort <= 0 { + tunnelPort = 8080 + } + if tunnelPort > 65535 { + return connection.ConnectionConfig{}, fmt.Errorf("HTTP 隧道端口无效:%d", config.HTTPTunnel.Port) + } + + config.UseProxy = true + config.Proxy = connection.ProxyConfig{ + Type: "http", + Host: tunnelHost, + Port: tunnelPort, + User: strings.TrimSpace(config.HTTPTunnel.User), + Password: config.HTTPTunnel.Password, + } + } if !config.UseProxy { config.Proxy = connection.ProxyConfig{} + config.UseHTTPTunnel = false + config.HTTPTunnel = connection.HTTPTunnelConfig{} return config, nil } @@ -22,6 +49,8 @@ func resolveDialConfigWithProxy(raw connection.ConnectionConfig) (connection.Con return connection.ConnectionConfig{}, err } config.Proxy = normalizedProxy + config.UseHTTPTunnel = false + config.HTTPTunnel = connection.HTTPTunnelConfig{} if config.UseSSH { sshPort := config.SSH.Port diff --git a/internal/app/global_proxy.go b/internal/app/global_proxy.go index 4dc8686..4361782 100644 --- a/internal/app/global_proxy.go +++ b/internal/app/global_proxy.go @@ -110,7 +110,7 @@ func (a *App) GetGlobalProxyConfig() connection.QueryResult { func applyGlobalProxyToConnection(config connection.ConnectionConfig) connection.ConnectionConfig { effective := config - if effective.UseProxy { + if effective.UseProxy || effective.UseHTTPTunnel { return effective } if isFileDatabaseType(effective.Type) { diff --git a/internal/app/methods_db.go b/internal/app/methods_db.go index d1ef4a9..d8529a9 100644 --- a/internal/app/methods_db.go +++ b/internal/app/methods_db.go @@ -547,6 +547,13 @@ func sqlSnippet(query string) string { return q[:max] + "..." } +func ensureNonNilSlice[T any](items []T) []T { + if items == nil { + return make([]T, 0) + } + return items +} + func (a *App) DBGetDatabases(config connection.ConnectionConfig) connection.QueryResult { runConfig := normalizeRunConfig(config, "") dbInst, err := a.getDatabase(runConfig) @@ -571,7 +578,7 @@ func (a *App) DBGetDatabases(config connection.ConnectionConfig) connection.Quer return connection.QueryResult{Success: false, Message: err.Error()} } - var resData []map[string]string + resData := make([]map[string]string, 0, len(dbs)) for _, name := range dbs { resData = append(resData, map[string]string{"Database": name}) } @@ -604,7 +611,7 @@ func (a *App) DBGetTables(config connection.ConnectionConfig, dbName string) con return connection.QueryResult{Success: false, Message: err.Error()} } - var resData []map[string]string + resData := make([]map[string]string, 0, len(tables)) for _, name := range tables { resData = append(resData, map[string]string{"Table": name}) } @@ -786,7 +793,7 @@ func (a *App) DBGetColumns(config connection.ConnectionConfig, dbName string, ta return connection.QueryResult{Success: false, Message: err.Error()} } - return connection.QueryResult{Success: true, Data: columns} + return connection.QueryResult{Success: true, Data: ensureNonNilSlice(columns)} } func (a *App) DBGetIndexes(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult { @@ -803,7 +810,7 @@ func (a *App) DBGetIndexes(config connection.ConnectionConfig, dbName string, ta return connection.QueryResult{Success: false, Message: err.Error()} } - return connection.QueryResult{Success: true, Data: indexes} + return connection.QueryResult{Success: true, Data: ensureNonNilSlice(indexes)} } func (a *App) DBGetForeignKeys(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult { @@ -820,7 +827,7 @@ func (a *App) DBGetForeignKeys(config connection.ConnectionConfig, dbName string return connection.QueryResult{Success: false, Message: err.Error()} } - return connection.QueryResult{Success: true, Data: fks} + return connection.QueryResult{Success: true, Data: ensureNonNilSlice(fks)} } func (a *App) DBGetTriggers(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult { @@ -837,7 +844,7 @@ func (a *App) DBGetTriggers(config connection.ConnectionConfig, dbName string, t return connection.QueryResult{Success: false, Message: err.Error()} } - return connection.QueryResult{Success: true, Data: triggers} + return connection.QueryResult{Success: true, Data: ensureNonNilSlice(triggers)} } func (a *App) DropView(config connection.ConnectionConfig, dbName string, viewName string) connection.QueryResult { @@ -975,5 +982,5 @@ func (a *App) DBGetAllColumns(config connection.ConnectionConfig, dbName string) return connection.QueryResult{Success: false, Message: err.Error()} } - return connection.QueryResult{Success: true, Data: cols} + return connection.QueryResult{Success: true, Data: ensureNonNilSlice(cols)} } diff --git a/internal/app/methods_driver.go b/internal/app/methods_driver.go index 344233e..07a13cc 100644 --- a/internal/app/methods_driver.go +++ b/internal/app/methods_driver.go @@ -2536,6 +2536,9 @@ func installOptionalDriverAgentFromLocalPath(definition driverDefinition, filePa return installedDriverPackage{}, fmt.Errorf("导入本地驱动代理失败:%w", copyErr) } } + if validateErr := db.ValidateOptionalDriverAgentExecutable(driverType, executablePath); validateErr != nil { + return installedDriverPackage{}, validateErr + } hash, hashErr := hashFileSHA256(executablePath) if hashErr != nil { @@ -2793,11 +2796,15 @@ func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, execut info, err := os.Stat(executablePath) if err == nil && !info.IsDir() { - hash, hashErr := hashFileSHA256(executablePath) - if hashErr != nil { - return "", "", fmt.Errorf("读取已安装 %s 驱动代理摘要失败:%w", displayName, hashErr) + if validateErr := db.ValidateOptionalDriverAgentExecutable(driverType, executablePath); validateErr != nil { + _ = os.Remove(executablePath) + } else { + hash, hashErr := hashFileSHA256(executablePath) + if hashErr != nil { + return "", "", fmt.Errorf("读取已安装 %s 驱动代理摘要失败:%w", displayName, hashErr) + } + return fmt.Sprintf("local://existing/%s-driver-agent", driverType), hash, nil } - return fmt.Sprintf("local://existing/%s-driver-agent", driverType), hash, nil } if err == nil && info.IsDir() { return "", "", fmt.Errorf("%s 驱动代理路径被目录占用:%s", displayName, executablePath) @@ -2814,6 +2821,10 @@ func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, execut if copyErr := copyAgentBinary(sourcePath, executablePath); copyErr != nil { return "", "", fmt.Errorf("复制预置 %s 驱动代理失败:%w", displayName, copyErr) } + if validateErr := db.ValidateOptionalDriverAgentExecutable(driverType, executablePath); validateErr != nil { + _ = os.Remove(executablePath) + return "", "", validateErr + } hash, hashErr := hashFileSHA256(executablePath) if hashErr != nil { return "", "", fmt.Errorf("计算预置 %s 驱动代理摘要失败:%w", displayName, hashErr) @@ -2901,6 +2912,10 @@ func downloadOptionalDriverAgentBinary(a *App, definition driverDefinition, urlT if chmodErr := os.Chmod(executablePath, 0o755); chmodErr != nil && stdRuntime.GOOS != "windows" { return "", fmt.Errorf("设置代理权限失败:%w", chmodErr) } + if validateErr := db.ValidateOptionalDriverAgentExecutable(driverType, executablePath); validateErr != nil { + _ = os.Remove(executablePath) + return "", validateErr + } return hash, nil } @@ -3009,6 +3024,10 @@ func downloadOptionalDriverAgentFromBundle(a *App, definition driverDefinition, if chmodErr := os.Chmod(executablePath, 0o755); chmodErr != nil && stdRuntime.GOOS != "windows" { return "", "", fmt.Errorf("设置驱动代理权限失败:%w", chmodErr) } + if validateErr := db.ValidateOptionalDriverAgentExecutable(driverType, executablePath); validateErr != nil { + _ = os.Remove(executablePath) + return "", "", validateErr + } hash, err := hashFileSHA256(executablePath) if err != nil { return "", "", fmt.Errorf("计算驱动代理摘要失败:%w", err) @@ -3334,6 +3353,7 @@ func resolveOptionalDriverAgentDownloadURLs(definition driverDefinition, rawURL } func findExistingOptionalDriverAgentCandidate(definition driverDefinition, targetPath string) (string, bool) { + driverType := normalizeDriverType(definition.Type) targetAbs, _ := filepath.Abs(targetPath) candidates := resolveOptionalDriverAgentCandidatePaths(definition) for _, candidate := range candidates { @@ -3349,9 +3369,13 @@ func findExistingOptionalDriverAgentCandidate(definition driverDefinition, targe continue } info, statErr := os.Stat(absPath) - if statErr == nil && !info.IsDir() { - return absPath, true + if statErr != nil || info.IsDir() { + continue } + if validateErr := db.ValidateOptionalDriverAgentExecutable(driverType, absPath); validateErr != nil { + continue + } + return absPath, true } return "", false } diff --git a/internal/app/methods_redis.go b/internal/app/methods_redis.go index 1b626b0..3bf8956 100644 --- a/internal/app/methods_redis.go +++ b/internal/app/methods_redis.go @@ -23,12 +23,20 @@ var ( // getRedisClient gets or creates a Redis client from cache func (a *App) getRedisClient(config connection.ConnectionConfig) (redis.RedisClient, error) { - key := getRedisClientCacheKey(config) + effectiveConfig := applyGlobalProxyToConnection(config) + connectConfig, proxyErr := resolveDialConfigWithProxy(effectiveConfig) + if proxyErr != nil { + wrapped := wrapConnectError(effectiveConfig, proxyErr) + logger.Error(wrapped, "Redis 代理准备失败:%s", formatRedisConnSummary(effectiveConfig)) + return nil, wrapped + } + + key := getRedisClientCacheKey(connectConfig) shortKey := key if len(shortKey) > 12 { shortKey = shortKey[:12] } - logger.Infof("获取 Redis 连接:%s 缓存Key=%s", formatRedisConnSummary(config), shortKey) + logger.Infof("获取 Redis 连接:%s 缓存Key=%s", formatRedisConnSummary(effectiveConfig), shortKey) redisCacheMu.Lock() defer redisCacheMu.Unlock() @@ -47,21 +55,20 @@ func (a *App) getRedisClient(config connection.ConnectionConfig) (redis.RedisCli logger.Infof("创建 Redis 客户端实例:缓存Key=%s", shortKey) client := redis.NewRedisClient() - if err := client.Connect(config); err != nil { - logger.Error(err, "Redis 连接失败:%s 缓存Key=%s", formatRedisConnSummary(config), shortKey) - return nil, err + if err := client.Connect(connectConfig); err != nil { + wrapped := wrapConnectError(effectiveConfig, err) + logger.Error(wrapped, "Redis 连接失败:%s 缓存Key=%s", formatRedisConnSummary(effectiveConfig), shortKey) + return nil, wrapped } redisCache[key] = client - logger.Infof("Redis 连接成功并写入缓存:%s 缓存Key=%s", formatRedisConnSummary(config), shortKey) + logger.Infof("Redis 连接成功并写入缓存:%s 缓存Key=%s", formatRedisConnSummary(effectiveConfig), shortKey) return client, nil } func getRedisClientCacheKey(config connection.ConnectionConfig) string { - if !config.UseSSH { - config.SSH = connection.SSHConfig{} - } - b, _ := json.Marshal(config) + normalized := normalizeCacheKeyConfig(config) + b, _ := json.Marshal(normalized) sum := sha256.Sum256(b) return hex.EncodeToString(sum[:]) } @@ -91,6 +98,26 @@ func formatRedisConnSummary(config connection.ConnectionConfig) string { b.WriteString(" 用户=") b.WriteString(config.SSH.User) } + if config.UseProxy { + b.WriteString(" 代理=") + b.WriteString(strings.ToLower(strings.TrimSpace(config.Proxy.Type))) + b.WriteString("://") + b.WriteString(config.Proxy.Host) + b.WriteString(":") + b.WriteString(strconv.Itoa(config.Proxy.Port)) + if strings.TrimSpace(config.Proxy.User) != "" { + b.WriteString(" 代理认证=已配置") + } + } + if config.UseHTTPTunnel { + b.WriteString(" HTTP隧道=") + b.WriteString(strings.TrimSpace(config.HTTPTunnel.Host)) + b.WriteString(":") + b.WriteString(strconv.Itoa(config.HTTPTunnel.Port)) + if strings.TrimSpace(config.HTTPTunnel.User) != "" { + b.WriteString(" HTTP隧道认证=已配置") + } + } return b.String() } diff --git a/internal/connection/types.go b/internal/connection/types.go index bc88873..bac9ec7 100644 --- a/internal/connection/types.go +++ b/internal/connection/types.go @@ -18,39 +18,49 @@ type ProxyConfig struct { Password string `json:"password,omitempty"` } +// HTTPTunnelConfig holds independent HTTP CONNECT tunnel details +type HTTPTunnelConfig struct { + Host string `json:"host"` + Port int `json:"port"` + User string `json:"user,omitempty"` + Password string `json:"password,omitempty"` +} + // ConnectionConfig holds database connection details including SSH type ConnectionConfig struct { - Type string `json:"type"` - Host string `json:"host"` - Port int `json:"port"` - User string `json:"user"` - Password string `json:"password"` - SavePassword bool `json:"savePassword,omitempty"` // Persist password in saved connection - Database string `json:"database"` - UseSSL bool `json:"useSSL,omitempty"` // MySQL-like SSL/TLS switch - SSLMode string `json:"sslMode,omitempty"` // preferred | required | skip-verify | disable - SSLCertPath string `json:"sslCertPath,omitempty"` // TLS client certificate path (e.g., Dameng) - SSLKeyPath string `json:"sslKeyPath,omitempty"` // TLS client private key path (e.g., Dameng) - UseSSH bool `json:"useSSH"` - SSH SSHConfig `json:"ssh"` - UseProxy bool `json:"useProxy,omitempty"` - Proxy ProxyConfig `json:"proxy,omitempty"` - Driver string `json:"driver,omitempty"` // For custom connection - DSN string `json:"dsn,omitempty"` // For custom connection - Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds (default: 30) - RedisDB int `json:"redisDB,omitempty"` // Redis database index (0-15) - URI string `json:"uri,omitempty"` // Connection URI for copy/paste - Hosts []string `json:"hosts,omitempty"` // Multi-host addresses: host:port - Topology string `json:"topology,omitempty"` // single | replica | cluster - MySQLReplicaUser string `json:"mysqlReplicaUser,omitempty"` // MySQL replica auth user - MySQLReplicaPassword string `json:"mysqlReplicaPassword,omitempty"` // MySQL replica auth password - ReplicaSet string `json:"replicaSet,omitempty"` // MongoDB replica set name - AuthSource string `json:"authSource,omitempty"` // MongoDB authSource - ReadPreference string `json:"readPreference,omitempty"` // MongoDB readPreference - MongoSRV bool `json:"mongoSrv,omitempty"` // MongoDB use mongodb+srv URI scheme - MongoAuthMechanism string `json:"mongoAuthMechanism,omitempty"` // MongoDB authMechanism - MongoReplicaUser string `json:"mongoReplicaUser,omitempty"` // MongoDB replica auth user - MongoReplicaPassword string `json:"mongoReplicaPassword,omitempty"` // MongoDB replica auth password + Type string `json:"type"` + Host string `json:"host"` + Port int `json:"port"` + User string `json:"user"` + Password string `json:"password"` + SavePassword bool `json:"savePassword,omitempty"` // Persist password in saved connection + Database string `json:"database"` + UseSSL bool `json:"useSSL,omitempty"` // MySQL-like SSL/TLS switch + SSLMode string `json:"sslMode,omitempty"` // preferred | required | skip-verify | disable + SSLCertPath string `json:"sslCertPath,omitempty"` // TLS client certificate path (e.g., Dameng) + SSLKeyPath string `json:"sslKeyPath,omitempty"` // TLS client private key path (e.g., Dameng) + UseSSH bool `json:"useSSH"` + SSH SSHConfig `json:"ssh"` + UseProxy bool `json:"useProxy,omitempty"` + Proxy ProxyConfig `json:"proxy,omitempty"` + UseHTTPTunnel bool `json:"useHttpTunnel,omitempty"` + HTTPTunnel HTTPTunnelConfig `json:"httpTunnel,omitempty"` + Driver string `json:"driver,omitempty"` // For custom connection + DSN string `json:"dsn,omitempty"` // For custom connection + Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds (default: 30) + RedisDB int `json:"redisDB,omitempty"` // Redis database index (0-15) + URI string `json:"uri,omitempty"` // Connection URI for copy/paste + Hosts []string `json:"hosts,omitempty"` // Multi-host addresses: host:port + Topology string `json:"topology,omitempty"` // single | replica | cluster + MySQLReplicaUser string `json:"mysqlReplicaUser,omitempty"` // MySQL replica auth user + MySQLReplicaPassword string `json:"mysqlReplicaPassword,omitempty"` // MySQL replica auth password + ReplicaSet string `json:"replicaSet,omitempty"` // MongoDB replica set name + AuthSource string `json:"authSource,omitempty"` // MongoDB authSource + ReadPreference string `json:"readPreference,omitempty"` // MongoDB readPreference + MongoSRV bool `json:"mongoSrv,omitempty"` // MongoDB use mongodb+srv URI scheme + MongoAuthMechanism string `json:"mongoAuthMechanism,omitempty"` // MongoDB authMechanism + MongoReplicaUser string `json:"mongoReplicaUser,omitempty"` // MongoDB replica auth user + MongoReplicaPassword string `json:"mongoReplicaPassword,omitempty"` // MongoDB replica auth password } // QueryResult is the standard response format for Wails methods diff --git a/internal/db/clickhouse_impl.go b/internal/db/clickhouse_impl.go index dcf18e6..f1d5811 100644 --- a/internal/db/clickhouse_impl.go +++ b/internal/db/clickhouse_impl.go @@ -107,7 +107,9 @@ func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig if readTimeout < minClickHouseReadTimeout { readTimeout = minClickHouseReadTimeout } + protocol := detectClickHouseProtocol(config) opts := &clickhouse.Options{ + Protocol: protocol, Addr: []string{ net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), }, @@ -125,6 +127,46 @@ func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig return opts } +func detectClickHouseProtocol(config connection.ConnectionConfig) clickhouse.Protocol { + uriText := strings.ToLower(strings.TrimSpace(config.URI)) + if strings.HasPrefix(uriText, "http://") || strings.HasPrefix(uriText, "https://") { + return clickhouse.HTTP + } + if config.Port == 8123 || config.Port == 8443 { + return clickhouse.HTTP + } + return clickhouse.Native +} + +func isClickHouseProtocolMismatch(err error) bool { + if err == nil { + return false + } + text := strings.ToLower(strings.TrimSpace(err.Error())) + if text == "" { + return false + } + return strings.Contains(text, "unexpected packet [72]") || + (strings.Contains(text, "unexpected packet") && strings.Contains(text, "handshake")) || + strings.Contains(text, "http response to https client") || + strings.Contains(text, "malformed http response") +} + +func withClickHouseProtocol(config connection.ConnectionConfig, protocol clickhouse.Protocol) connection.ConnectionConfig { + next := config + switch protocol { + case clickhouse.HTTP: + if next.Port == 0 { + next.Port = 8123 + } + default: + if next.Port == 0 { + next.Port = defaultClickHousePort + } + } + return next +} + func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error { if supported, reason := DriverRuntimeSupportStatus("clickhouse"); !supported { if strings.TrimSpace(reason) == "" { @@ -176,23 +218,41 @@ func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error { var failures []string for idx, attempt := range attempts { - c.conn = clickhouse.OpenDB(c.buildClickHouseOptions(attempt)) - if err := c.Ping(); err != nil { - failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err)) - if c.conn != nil { - _ = c.conn.Close() - c.conn = nil + primaryProtocol := detectClickHouseProtocol(attempt) + protocols := []clickhouse.Protocol{primaryProtocol} + if primaryProtocol == clickhouse.Native { + protocols = append(protocols, clickhouse.HTTP) + } else { + protocols = append(protocols, clickhouse.Native) + } + + for pIdx, protocol := range protocols { + protocolConfig := withClickHouseProtocol(attempt, protocol) + c.conn = clickhouse.OpenDB(c.buildClickHouseOptions(protocolConfig)) + if err := c.Ping(); err != nil { + failures = append(failures, fmt.Sprintf("第%d次连接验证失败(protocol=%s): %v", idx+1, protocol.String(), err)) + if c.conn != nil { + _ = c.conn.Close() + c.conn = nil + } + if pIdx == 0 && !isClickHouseProtocolMismatch(err) { + // 首次连接不是协议误配特征,避免无谓重试次协议。 + break + } + continue } - continue + if idx > 0 { + logger.Warnf("ClickHouse SSL 优先连接失败,已回退至明文连接") + } + if pIdx > 0 { + logger.Warnf("ClickHouse 已自动切换连接协议为 %s(常见于 8123/8443 HTTP 端口)", protocol.String()) + } + return nil } - if idx > 0 { - logger.Warnf("ClickHouse SSL 优先连接失败,已回退至明文连接") - } - return nil } _ = c.Close() - return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";")) + return fmt.Errorf("连接建立后验证失败(可检查 ClickHouse 端口与协议是否匹配:Native=9000/9440,HTTP=8123/8443):%s", strings.Join(failures, ";")) } func (c *ClickHouseDB) Close() error { diff --git a/internal/db/dameng_impl.go b/internal/db/dameng_impl.go index 5080540..5cceb0a 100644 --- a/internal/db/dameng_impl.go +++ b/internal/db/dameng_impl.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "net/url" + "sort" "strconv" "strings" "time" @@ -204,24 +205,82 @@ func (d *DamengDB) Exec(query string) (int64, error) { } func (d *DamengDB) GetDatabases() ([]string, error) { - // DM: List Users/Schemas - data, _, err := d.Query("SELECT username FROM dba_users") - if err != nil { - // Fallback if dba_users not accessible - data, _, err = d.Query("SELECT username FROM all_users") + // 达梦将「用户/模式」作为数据库列表来源,不同权限下可见口径不同。 + // 这里采用多查询口径聚合,避免仅依赖单一视图导致“少库”。 + queries := []string{ + "SELECT USERNAME AS DATABASE_NAME FROM SYS.DBA_USERS ORDER BY USERNAME", + "SELECT USERNAME AS DATABASE_NAME FROM DBA_USERS ORDER BY USERNAME", + "SELECT USERNAME AS DATABASE_NAME FROM ALL_USERS ORDER BY USERNAME", + "SELECT USERNAME AS DATABASE_NAME FROM USER_USERS", + "SELECT DISTINCT OWNER AS DATABASE_NAME FROM ALL_TABLES ORDER BY OWNER", + } + + seen := make(map[string]struct{}) + dbs := make([]string, 0, 64) + var lastErr error + success := false + + for _, q := range queries { + data, _, err := d.Query(q) if err != nil { - return nil, err + lastErr = err + continue + } + success = true + for _, row := range data { + name := getDamengRowString(row, "DATABASE_NAME", "USERNAME", "OWNER", "SCHEMA_NAME") + if name == "" { + // 回退到第一列,兼容驱动返回列名差异。 + for _, v := range row { + text := strings.TrimSpace(fmt.Sprintf("%v", v)) + if text == "" || strings.EqualFold(text, "") { + continue + } + name = text + break + } + } + if name == "" { + continue + } + key := strings.ToUpper(name) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + dbs = append(dbs, name) } } - var dbs []string - for _, row := range data { - if val, ok := row["USERNAME"]; ok { - dbs = append(dbs, fmt.Sprintf("%v", val)) - } + + if !success && lastErr != nil { + return nil, lastErr } + + sort.Slice(dbs, func(i, j int) bool { + return strings.ToUpper(dbs[i]) < strings.ToUpper(dbs[j]) + }) return dbs, nil } +func getDamengRowString(row map[string]interface{}, keys ...string) string { + if len(row) == 0 { + return "" + } + for _, key := range keys { + for k, v := range row { + if !strings.EqualFold(strings.TrimSpace(k), strings.TrimSpace(key)) { + continue + } + text := strings.TrimSpace(fmt.Sprintf("%v", v)) + if text == "" || strings.EqualFold(text, "") { + return "" + } + return text + } + } + return "" +} + func (d *DamengDB) GetTables(dbName string) ([]string, error) { query := fmt.Sprintf("SELECT owner, table_name FROM all_tables WHERE owner = '%s' ORDER BY table_name", strings.ToUpper(dbName)) if dbName == "" { diff --git a/internal/db/driver_agent_binary_check.go b/internal/db/driver_agent_binary_check.go new file mode 100644 index 0000000..762c720 --- /dev/null +++ b/internal/db/driver_agent_binary_check.go @@ -0,0 +1,74 @@ +package db + +import ( + "debug/pe" + "fmt" + "runtime" + "strings" +) + +const ( + peMachineI386 uint16 = 0x014c + peMachineAmd64 uint16 = 0x8664 + peMachineArm64 uint16 = 0xaa64 +) + +func windowsMachineLabel(machine uint16) string { + switch machine { + case peMachineI386: + return "windows-386" + case peMachineAmd64: + return "windows-amd64" + case peMachineArm64: + return "windows-arm64" + default: + return fmt.Sprintf("windows-unknown(0x%04x)", machine) + } +} + +func expectedWindowsMachineForGoArch(goarch string) (uint16, string, bool) { + switch strings.ToLower(strings.TrimSpace(goarch)) { + case "386": + return peMachineI386, "windows-386", true + case "amd64": + return peMachineAmd64, "windows-amd64", true + case "arm64": + return peMachineArm64, "windows-arm64", true + default: + return 0, "", false + } +} + +func validateWindowsExecutableMachine(pathText string) error { + file, err := pe.Open(pathText) + if err != nil { + return fmt.Errorf("无法识别为有效的 Windows 可执行文件:%w", err) + } + defer file.Close() + + expectedMachine, expectedLabel, ok := expectedWindowsMachineForGoArch(runtime.GOARCH) + if !ok { + return nil + } + actualMachine := file.FileHeader.Machine + if actualMachine != expectedMachine { + return fmt.Errorf("可执行文件架构不兼容(文件=%s,当前进程=%s)", windowsMachineLabel(actualMachine), expectedLabel) + } + return nil +} + +// ValidateOptionalDriverAgentExecutable 校验可选驱动代理二进制是否可在当前进程中执行。 +// 当前主要用于 Windows 下的 PE 架构兼容性校验,避免升级后复用到错误架构的旧代理。 +func ValidateOptionalDriverAgentExecutable(driverType string, executablePath string) error { + pathText := strings.TrimSpace(executablePath) + if pathText == "" { + return fmt.Errorf("%s 驱动代理路径为空", driverDisplayName(driverType)) + } + if runtime.GOOS != "windows" { + return nil + } + if err := validateWindowsExecutableMachine(pathText); err != nil { + return fmt.Errorf("%s 驱动代理不可用:%w", driverDisplayName(driverType), err) + } + return nil +} diff --git a/internal/db/driver_support.go b/internal/db/driver_support.go index 517a81a..db00717 100644 --- a/internal/db/driver_support.go +++ b/internal/db/driver_support.go @@ -194,6 +194,9 @@ func optionalGoDriverRuntimeReady(driverType string) (bool, string) { if statErr != nil || info.IsDir() { return false, fmt.Sprintf("%s 驱动代理缺失,请在驱动管理中重新安装启用", driverDisplayName(normalized)) } + if validateErr := ValidateOptionalDriverAgentExecutable(normalized, executablePath); validateErr != nil { + return false, fmt.Sprintf("%s;请在驱动管理中重新安装启用", validateErr.Error()) + } return true, "" } diff --git a/internal/db/driver_support_test.go b/internal/db/driver_support_test.go index 8dc5f62..002fba0 100644 --- a/internal/db/driver_support_test.go +++ b/internal/db/driver_support_test.go @@ -65,11 +65,22 @@ func TestManagedDriverRequiresInstallMarker(t *testing.T) { if err != nil { t.Fatalf("解析 mariadb 代理路径失败: %v", err) } - if err := os.WriteFile(executablePath, []byte("placeholder"), 0o755); err != nil { - t.Fatalf("写入 mariadb 代理占位文件失败: %v", err) - } if runtime.GOOS == "windows" { - _ = os.Chmod(executablePath, 0o644) + selfPath, selfErr := os.Executable() + if selfErr != nil { + t.Fatalf("获取测试进程路径失败: %v", selfErr) + } + content, readErr := os.ReadFile(selfPath) + if readErr != nil { + t.Fatalf("读取测试进程失败: %v", readErr) + } + if err := os.WriteFile(executablePath, content, 0o755); err != nil { + t.Fatalf("写入 mariadb 代理占位可执行文件失败: %v", err) + } + } else { + if err := os.WriteFile(executablePath, []byte("placeholder"), 0o755); err != nil { + t.Fatalf("写入 mariadb 代理占位文件失败: %v", err) + } } supported, reason := DriverRuntimeSupportStatus("mariadb") diff --git a/internal/db/kingbase_impl.go b/internal/db/kingbase_impl.go index f1357a8..6dfd2e5 100644 --- a/internal/db/kingbase_impl.go +++ b/internal/db/kingbase_impl.go @@ -623,28 +623,16 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet } defer tx.Rollback() - quoteIdent := func(name string) string { - n := strings.TrimSpace(name) - n = strings.Trim(n, "\"") - n = strings.ReplaceAll(n, "\"", "\"\"") - if n == "" { - return "\"\"" - } - return `"` + n + `"` - } - - schema := "" - table := strings.TrimSpace(tableName) - if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { - schema = strings.TrimSpace(parts[0]) - table = strings.TrimSpace(parts[1]) + schema, table := splitKingbaseQualifiedTable(tableName) + if table == "" { + return fmt.Errorf("table name required") } qualifiedTable := "" if schema != "" { - qualifiedTable = fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table)) + qualifiedTable = fmt.Sprintf("%s.%s", quoteKingbaseIdent(schema), quoteKingbaseIdent(table)) } else { - qualifiedTable = quoteIdent(table) + qualifiedTable = quoteKingbaseIdent(table) } // 1. Deletes @@ -654,7 +642,7 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet idx := 0 for k, v := range pk { idx++ - wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteIdent(k), idx)) + wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx)) args = append(args, v) } if len(wheres) == 0 { @@ -674,7 +662,7 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet for k, v := range update.Values { idx++ - sets = append(sets, fmt.Sprintf("%s = $%d", quoteIdent(k), idx)) + sets = append(sets, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx)) args = append(args, v) } @@ -685,7 +673,7 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet var wheres []string for k, v := range update.Keys { idx++ - wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteIdent(k), idx)) + wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx)) args = append(args, v) } @@ -708,7 +696,7 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet for k, v := range row { idx++ - cols = append(cols, quoteIdent(k)) + cols = append(cols, quoteKingbaseIdent(k)) placeholders = append(placeholders, fmt.Sprintf("$%d", idx)) args = append(args, v) } @@ -726,6 +714,67 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet return tx.Commit() } +func normalizeKingbaseIdentifier(raw string) string { + value := strings.TrimSpace(raw) + if value == "" { + return "" + } + + // 兼容 JSON/字符串转义后传入的标识符:\"schema\" -> "schema" + value = strings.ReplaceAll(value, `\"`, `"`) + value = strings.TrimSpace(value) + + // 兼容异常多重包裹引号(例如 ""schema""、""""schema"""")。 + // strings.Trim 会移除两端连续引号,迭代后可收敛到纯标识符。 + for i := 0; i < 4; i++ { + next := strings.TrimSpace(strings.Trim(value, `"`)) + if next == value { + break + } + value = next + } + + // 兼容其他方言可能残留的引用形式 + if len(value) >= 2 && strings.HasPrefix(value, "`") && strings.HasSuffix(value, "`") { + value = strings.TrimSpace(strings.Trim(value, "`")) + } + if len(value) >= 2 && strings.HasPrefix(value, "[") && strings.HasSuffix(value, "]") { + value = strings.TrimSpace(value[1 : len(value)-1]) + } + + return value +} + +func quoteKingbaseIdent(name string) string { + n := normalizeKingbaseIdentifier(name) + n = strings.ReplaceAll(n, `"`, `""`) + if n == "" { + return "\"\"" + } + return `"` + n + `"` +} + +func splitKingbaseQualifiedTable(tableName string) (schema string, table string) { + raw := strings.TrimSpace(tableName) + if raw == "" { + return "", "" + } + + if parts := strings.SplitN(raw, ".", 2); len(parts) == 2 { + schema = normalizeKingbaseIdentifier(parts[0]) + table = normalizeKingbaseIdentifier(parts[1]) + if table == "" { + return "", normalizeKingbaseIdentifier(raw) + } + if schema == "" { + return "", table + } + return schema, table + } + + return "", normalizeKingbaseIdentifier(raw) +} + func (k *KingbaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { // dbName 在本项目语义里是“数据库”,schema 由 table_schema 决定;这里返回全部用户 schema 的列用于查询提示。 query := ` diff --git a/internal/db/kingbase_impl_test.go b/internal/db/kingbase_impl_test.go new file mode 100644 index 0000000..eca6eaa --- /dev/null +++ b/internal/db/kingbase_impl_test.go @@ -0,0 +1,74 @@ +//go:build gonavi_full_drivers || gonavi_kingbase_driver + +package db + +import "testing" + +func TestNormalizeKingbaseIdentifier(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {name: "plain", in: "ldf_server", want: "ldf_server"}, + {name: "quoted", in: `"ldf_server"`, want: "ldf_server"}, + {name: "double quoted", in: `""ldf_server""`, want: "ldf_server"}, + {name: "quad quoted", in: `""""ldf_server""""`, want: "ldf_server"}, + {name: "escaped quoted", in: `\"ldf_server\"`, want: "ldf_server"}, + {name: "backtick quoted", in: "`ldf_server`", want: "ldf_server"}, + {name: "bracket quoted", in: "[ldf_server]", want: "ldf_server"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := normalizeKingbaseIdentifier(tt.in); got != tt.want { + t.Fatalf("normalizeKingbaseIdentifier(%q) = %q, want %q", tt.in, got, tt.want) + } + }) + } +} + +func TestQuoteKingbaseIdent(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {name: "plain", in: "ldf_server", want: `"ldf_server"`}, + {name: "double quoted", in: `""ldf_server""`, want: `"ldf_server"`}, + {name: "escaped quoted", in: `\"ldf_server\"`, want: `"ldf_server"`}, + {name: "with embedded quote", in: `ab"cd`, want: `"ab""cd"`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := quoteKingbaseIdent(tt.in); got != tt.want { + t.Fatalf("quoteKingbaseIdent(%q) = %q, want %q", tt.in, got, tt.want) + } + }) + } +} + +func TestSplitKingbaseQualifiedTable(t *testing.T) { + tests := []struct { + name string + in string + wantSchema string + wantTable string + }{ + {name: "plain qualified", in: "ldf_server.t_user", wantSchema: "ldf_server", wantTable: "t_user"}, + {name: "double quoted qualified", in: `""ldf_server"".""t_user""`, wantSchema: "ldf_server", wantTable: "t_user"}, + {name: "escaped qualified", in: `\"ldf_server\".\"t_user\"`, wantSchema: "ldf_server", wantTable: "t_user"}, + {name: "bracket qualified", in: "[ldf_server].[t_user]", wantSchema: "ldf_server", wantTable: "t_user"}, + {name: "table only", in: `""t_user""`, wantSchema: "", wantTable: "t_user"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotSchema, gotTable := splitKingbaseQualifiedTable(tt.in) + if gotSchema != tt.wantSchema || gotTable != tt.wantTable { + t.Fatalf("splitKingbaseQualifiedTable(%q) = (%q, %q), want (%q, %q)", tt.in, gotSchema, gotTable, tt.wantSchema, tt.wantTable) + } + }) + } +} diff --git a/internal/db/optional_driver_agent_impl.go b/internal/db/optional_driver_agent_impl.go index 1b83902..2579b7c 100644 --- a/internal/db/optional_driver_agent_impl.go +++ b/internal/db/optional_driver_agent_impl.go @@ -9,8 +9,10 @@ import ( "io" "os" "os/exec" + "runtime" "strings" "sync" + "syscall" "time" "GoNavi-Wails/internal/connection" @@ -94,6 +96,9 @@ func newOptionalDriverAgentClient(driverType string, executablePath string) (*op return nil, fmt.Errorf("创建 %s 驱动代理 stderr 失败:%w", driverDisplayName(driverType), err) } if err := cmd.Start(); err != nil { + if isWindowsExecutableMachineMismatch(err) { + return nil, fmt.Errorf("启动 %s 驱动代理失败:%w(检测到驱动代理与当前系统架构不兼容,请在驱动管理中重新安装启用)", driverDisplayName(driverType), err) + } return nil, fmt.Errorf("启动 %s 驱动代理失败:%w", driverDisplayName(driverType), err) } @@ -107,6 +112,30 @@ func newOptionalDriverAgentClient(driverType string, executablePath string) (*op return client, nil } +func isWindowsExecutableMachineMismatch(err error) bool { + if err == nil || runtime.GOOS != "windows" { + return false + } + var errno syscall.Errno + if errors.As(err, &errno) && errno == syscall.Errno(216) { + return true + } + text := strings.ToLower(strings.TrimSpace(err.Error())) + if text == "" { + return false + } + if strings.Contains(text, "not compatible with the version of windows") { + return true + } + if strings.Contains(text, "win32") && strings.Contains(text, "compatible") { + return true + } + if strings.Contains(text, "不是有效的win32应用程序") || strings.Contains(text, "无法在win32模式下运行") { + return true + } + return false +} + func (c *optionalDriverAgentClient) captureStderr(stderr io.Reader) { scanner := bufio.NewScanner(stderr) buffer := make([]byte, 0, 8<<10) diff --git a/internal/db/query_value.go b/internal/db/query_value.go index 83fdf7f..fa28bd7 100644 --- a/internal/db/query_value.go +++ b/internal/db/query_value.go @@ -8,6 +8,7 @@ import ( "reflect" "strconv" "strings" + "time" "unicode" "unicode/utf8" ) @@ -86,6 +87,16 @@ func normalizeCompositeQueryValue(v interface{}) interface{} { items[i] = normalizeQueryValue(rv.Index(i).Interface()) } return items + case reflect.Struct: + // 部分驱动(如 Kingbase)会返回复杂结构体值,直接透传会导致前端渲染和比较开销激增。 + // 统一降级为可读字符串,避免对象深层序列化触发 UI 卡顿。 + if tm, ok := v.(time.Time); ok { + return tm.Format(time.RFC3339Nano) + } + if stringer, ok := v.(fmt.Stringer); ok { + return stringer.String() + } + return fmt.Sprintf("%v", v) default: return normalizeUnsafeIntegerForJS(rv, v) } diff --git a/internal/db/query_value_test.go b/internal/db/query_value_test.go index b05977e..285344e 100644 --- a/internal/db/query_value_test.go +++ b/internal/db/query_value_test.go @@ -2,7 +2,9 @@ package db import ( "encoding/json" + "fmt" "testing" + "time" ) type duckMapLike map[any]any @@ -165,3 +167,31 @@ func TestNormalizeQueryValueWithDBType_JSONNumber(t *testing.T) { }) } } + +type customStructValue struct { + Name string + Age int +} + +func (v customStructValue) String() string { + return fmt.Sprintf("%s-%d", v.Name, v.Age) +} + +func TestNormalizeQueryValueWithDBType_StructToString(t *testing.T) { + got := normalizeQueryValueWithDBType(customStructValue{Name: "alice", Age: 18}, "") + if got != "alice-18" { + t.Fatalf("结构体应降级为可读字符串,实际=%v(%T)", got, got) + } +} + +func TestNormalizeQueryValueWithDBType_TimeStructToRFC3339(t *testing.T) { + input := time.Date(2026, 3, 5, 18, 30, 15, 123456789, time.UTC) + got := normalizeQueryValueWithDBType(input, "") + text, ok := got.(string) + if !ok { + t.Fatalf("time.Time 应转为字符串,实际=%v(%T)", got, got) + } + if text != "2026-03-05T18:30:15.123456789Z" { + t.Fatalf("time.Time 规整值异常,实际=%s", text) + } +}