From eaa45f17fdce46b7b20e6f6b81054cf377efd9fc Mon Sep 17 00:00:00 2001 From: Syngnat <92659908+Syngnat@users.noreply.github.com> Date: Thu, 12 Mar 2026 17:40:35 +0800 Subject: [PATCH] Release/0.5.7 (#226) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🎨 style(DataGrid): 清理冗余代码与静态分析告警 - 类型重构:通过修正 React Context 的函数签名解决了 void 类型的链式调用错误 - 代码精简:利用 Nullish Coalescing (??) 优化组件配置项降级逻辑,剥离无意义的隐式 undefined 赋值 - 工具链适配:适配 IDE 拼写检查与 Promise strict rules,确保全文件零警 * 🔧 fix(db/kingbase_impl): 修复标识符无条件加双引号导致SQL语法报错 - quoteKingbaseIdent 改为条件引用,仅对大写字母、保留字、特殊字符的标识符添加双引号 - 新增 kingbaseIdentNeedsQuote 判断标识符是否需要引用 - 新增 isKingbaseReservedWord 检测常见SQL保留字 - 补充 TestQuoteKingbaseIdent、TestKingbaseIdentNeedsQuote 单测覆盖各场景 - refs #176 * 🔧 fix(release,db/kingbase_impl): 修复金仓默认 schema 并静默生成 DMG - Kingbase:在 current_schema() 为 public 时探测候选 schema,并通过 DSN search_path 重连,兼容未限定 schema 的查询 - 候选优先级:数据库名/用户名同名 schema(存在性校验),否则仅在“唯一用户 schema 有表”场景兜底 - 避免连接污染:每次 Connect 重置探测结果,重连成功后替换连接并关闭旧连接 - 打包脚本:create-dmg 增加 --sandbox-safe,避免构建时自动弹出/打开挂载窗口 - 产物格式:强制 --format UDZO,并将 rw.*.dmg/UDRW 中间产物转换为可分发 DMG - 校验门禁:增加 hdiutil verify,失败时保留 .app 便于排查,同时修正卷图标探测并补 ad-hoc 签名 * 🐛 fix(connection/redis): 修复 Redis URI 用户名处理导致认证失败 - Redis URI 解析回填 user 字段,兼容 redis://user:pass@... 与 redis://:pass@... - 生成 URI 时按需输出 user/password,避免丢失用户名信息 - Redis 类型默认用户名置空,并在构建配置时清理历史默认 root - 避免 go-redis 触发 ACL AUTH(user, pass) 导致 WRONGPASS - refs #212 * 🔧 fix(release,ssh): 修复 SSH 误判连接成功并纠正 DMG 打包结构 - SSH 缓存 key 纳入认证指纹(password/keyPath),避免改错凭证仍复用旧连接/端口转发 - MySQL/MariaDB/Doris:SSH 隧道建立失败直接返回错误,不再回退直连导致测试误判成功 - 新增最小单测覆盖 SSH cache key 与 UseSSH 异常路径 - build-release.sh:create-dmg 使用 staging 目录作为 source,避免 DMG 根目录变成 Contents - refs #213 * fix: KingBase 连接后自动设置 search_path,修复自定义 schema 下表查询报 relation does not exist 的问题 (#215) * 🔧 fix(driver/kingbase,mongodb): 修复外置驱动事务引用与连接测试链路问题 - 金仓外置驱动链路增加表名与变更字段归一化,修复 ApplyChanges 场景下双引号转义异常导致的 SQL 语法错误 - 新增金仓公共标识符工具并复用到 kingbase_impl 与 optional_driver_agent_impl,统一处理多重转义、schema.table 拆分与引用规范 - 金仓代理连接后自动探测并设置 search_path,降低查询时必须手写 schema 前缀的概率 - MongoDB 连接参数改为显式 host/hosts 优先,避免被 URI 中 localhost 覆盖;代理链路保留目标地址不再改写为本地地址 - 连接测试增加前后端超时收敛与日志增强,避免长时间转圈;连接错误文案在未启用 TLS 时移除误导性的“SSL”前缀 - 统一日志级别为 INFO/WARN/ERROR,默认日志目录收敛到 ~/.GoNavi/Logs,并补充驱动构建脚本 build-driver-agents.sh * 🔧 fix(release/sidebar): 统一跨平台UPX压缩并修复PG函数列表查询兼容性 - 构建脚本新增通用 UPX 压缩函数,覆盖 macOS、Linux、Windows 产物 - 本地打包改为强制压缩策略:未安装 upx、压缩失败或校验失败直接终止 - macOS 打包在签名前压缩 .app 主程序并执行 upx -t 校验 - Linux 打包在生成 tar.gz 前压缩可执行文件并执行 upx -t 校验 - GitHub Release 与测试构建流程补齐 macOS/Linux/Windows 的 upx 安装与压缩步骤 - PostgreSQL/PG-like 函数元数据查询增加多路兼容 SQL,修复函数列表不显示问题 - refs #221 - refs #222 --------- Co-authored-by: Syngnat Co-authored-by: 凌封 <49424247+fengin@users.noreply.github.com> --- .github/workflows/release.yml | 75 ++- .../workflows/test-build-all-platforms.yml | 67 +- build-driver-agents.sh | 228 +++++++ build-release.sh | 351 ++++++++--- frontend/index.html | 17 + frontend/src/App.tsx | 143 +++-- frontend/src/components/ConnectionModal.tsx | 87 ++- frontend/src/components/DataGrid.tsx | 594 +++++++++++++----- frontend/src/components/Sidebar.tsx | 15 +- frontend/src/main.tsx | 30 + frontend/src/store.ts | 95 ++- internal/app/app.go | 79 ++- internal/app/app_connect_error_test.go | 84 +++ internal/app/db_proxy.go | 4 +- internal/app/db_proxy_test.go | 64 ++ internal/app/global_proxy.go | 49 +- internal/app/methods_db.go | 19 +- internal/app/methods_db_timeout_test.go | 31 + internal/app/methods_driver.go | 55 +- internal/db/diros_impl.go | 23 +- internal/db/kingbase_identifier_utils.go | 164 +++++ internal/db/kingbase_identifier_utils_test.go | 52 ++ internal/db/kingbase_impl.go | 161 +++-- internal/db/kingbase_impl_test.go | 49 +- internal/db/mariadb_impl.go | 22 +- internal/db/mongodb_impl.go | 82 ++- internal/db/mongodb_impl_uri_test.go | 39 ++ internal/db/mongodb_impl_v1.go | 82 ++- internal/db/mongodb_impl_v1_uri_test.go | 25 + internal/db/mysql_impl.go | 22 +- internal/db/mysql_ssh_test.go | 26 + internal/db/optional_driver_agent_impl.go | 276 ++++++++ .../db/optional_driver_agent_impl_test.go | 75 ++- internal/logger/logger.go | 56 +- internal/redis/redis_impl.go | 24 + internal/redis/redis_impl_test.go | 81 +++ internal/ssh/ssh.go | 94 ++- internal/ssh/ssh_cache_key_test.go | 46 ++ 38 files changed, 2997 insertions(+), 489 deletions(-) create mode 100755 build-driver-agents.sh create mode 100644 internal/app/app_connect_error_test.go create mode 100644 internal/app/db_proxy_test.go create mode 100644 internal/app/methods_db_timeout_test.go create mode 100644 internal/db/kingbase_identifier_utils.go create mode 100644 internal/db/kingbase_identifier_utils_test.go create mode 100644 internal/db/mongodb_impl_uri_test.go create mode 100644 internal/db/mongodb_impl_v1_uri_test.go create mode 100644 internal/db/mysql_ssh_test.go create mode 100644 internal/redis/redis_impl_test.go create mode 100644 internal/ssh/ssh_cache_key_test.go diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7dd9b87..84f14a5 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -88,6 +88,24 @@ jobs: with: node-version: '20' + - name: Install UPX (macOS) + if: contains(matrix.platform, 'darwin') + run: | + brew install upx + upx --version + + - name: Install UPX (Windows) + if: contains(matrix.platform, 'windows') + shell: pwsh + run: | + choco install upx --no-progress -y + $upxCmd = Get-Command upx -ErrorAction SilentlyContinue + if ($null -eq $upxCmd) { + Write-Error "❌ 未检测到 upx,无法保证 Windows 产物经过压缩" + exit 1 + } + & upx --version + # Linux Dependencies (GTK3, WebKit2GTK required by Wails) - name: Install Linux Dependencies if: contains(matrix.platform, 'linux') @@ -102,6 +120,9 @@ jobs: sudo apt-get install -y libwebkit2gtk-4.0-dev fi + sudo apt-get install -y upx-ucl || sudo apt-get install -y upx + upx --version + # AppImage 运行/打包可能需要 FUSE2。不同发行版/版本包名不同,做兼容兜底。 sudo apt-get install -y libfuse2 || sudo apt-get install -y libfuse2t64 || true @@ -277,6 +298,23 @@ jobs: exit 1 fi APP_NAME=$(basename "$APP_PATH") + + APP_BIN=$(find "$APP_PATH/Contents/MacOS" -maxdepth 1 -type f | head -n 1) + if [ -z "$APP_BIN" ]; then + echo "❌ 未找到 macOS 应用主程序,无法进行 UPX 压缩!" + exit 1 + fi + BEFORE_BYTES=$(wc -c <"$APP_BIN" | tr -d '[:space:]') + echo "🗜️ 正在使用 UPX 压缩 macOS 可执行文件: $APP_BIN ..." + upx --best --lzma --force "$APP_BIN" + upx -t "$APP_BIN" + AFTER_BYTES=$(wc -c <"$APP_BIN" | tr -d '[:space:]') + if [ "$AFTER_BYTES" -lt "$BEFORE_BYTES" ]; then + SAVED_BYTES=$((BEFORE_BYTES - AFTER_BYTES)) + awk -v b="$BEFORE_BYTES" -v a="$AFTER_BYTES" -v s="$SAVED_BYTES" 'BEGIN { printf "✅ macOS UPX 压缩完成:%.2fMB -> %.2fMB,减少 %.2fMB\n", b/1024/1024, a/1024/1024, s/1024/1024 }' + else + awk -v b="$BEFORE_BYTES" -v a="$AFTER_BYTES" 'BEGIN { printf "ℹ️ macOS UPX 压缩完成:%.2fMB -> %.2fMB\n", b/1024/1024, a/1024/1024 }' + fi echo "🔏 正在进行 Ad-hoc 签名..." # 注意:Ad-hoc + hardened runtime(--options runtime)在未配置 entitlements 时, @@ -301,7 +339,7 @@ jobs: mv "$DMG_NAME" "../../$FINAL_NAME" # Windows Packaging - - name: Package Windows Portable Zip + - name: Package Windows EXE if: contains(matrix.platform, 'windows') shell: pwsh run: | @@ -312,7 +350,6 @@ jobs: } $target = "${{ matrix.build_name }}" $finalExeName = "GoNavi-$version-${{ matrix.os_name }}-${{ matrix.arch_name }}${{ matrix.artifact_suffix }}.exe" - $finalZipName = "GoNavi-$version-${{ matrix.os_name }}-${{ matrix.arch_name }}${{ matrix.artifact_suffix }}.zip" if (Test-Path "$target.exe") { $finalExe = "$target.exe" @@ -324,11 +361,25 @@ jobs: exit 1 } - Write-Host "📦 生成 Windows 可执行文件 $finalExeName..." - Copy-Item -LiteralPath $finalExe -Destination "..\\..\\$finalExeName" -Force + $upxCmd = Get-Command upx -ErrorAction SilentlyContinue + if ($null -eq $upxCmd) { + Write-Error "❌ 未找到 upx,无法保证 Windows 产物经过压缩" + exit 1 + } + $beforeBytes = (Get-Item -LiteralPath $finalExe).Length + Write-Host "🗜️ 使用 UPX 压缩 $finalExe ..." + & upx --best --lzma --force $finalExe | Out-Host + & upx -t $finalExe | Out-Host + $afterBytes = (Get-Item -LiteralPath $finalExe).Length + if ($afterBytes -lt $beforeBytes) { + $savedBytes = $beforeBytes - $afterBytes + Write-Host ("✅ UPX 压缩完成:{0:N2}MB -> {1:N2}MB,减少 {2:N2}MB" -f ($beforeBytes / 1MB), ($afterBytes / 1MB), ($savedBytes / 1MB)) + } else { + Write-Host ("ℹ️ UPX 压缩完成:{0:N2}MB -> {1:N2}MB" -f ($beforeBytes / 1MB), ($afterBytes / 1MB)) + } - Write-Host "📦 生成 Windows 压缩包 $finalZipName..." - Compress-Archive -LiteralPath $finalExe -DestinationPath "..\\..\\$finalZipName" -Force + Write-Host "📦 输出 Windows 可执行文件 $finalExeName..." + Copy-Item -LiteralPath $finalExe -Destination "..\\..\\$finalExeName" -Force # Linux Packaging (tar.gz and AppImage) - name: Package Linux @@ -347,6 +398,17 @@ jobs: fi chmod +x "$TARGET" + BEFORE_BYTES=$(wc -c <"$TARGET" | tr -d '[:space:]') + echo "🗜️ 正在使用 UPX 压缩 Linux 可执行文件: $TARGET ..." + upx --best --lzma --force "$TARGET" + upx -t "$TARGET" + AFTER_BYTES=$(wc -c <"$TARGET" | tr -d '[:space:]') + if [ "$AFTER_BYTES" -lt "$BEFORE_BYTES" ]; then + SAVED_BYTES=$((BEFORE_BYTES - AFTER_BYTES)) + awk -v b="$BEFORE_BYTES" -v a="$AFTER_BYTES" -v s="$SAVED_BYTES" 'BEGIN { printf "✅ Linux UPX 压缩完成:%.2fMB -> %.2fMB,减少 %.2fMB\n", b/1024/1024, a/1024/1024, s/1024/1024 }' + else + awk -v b="$BEFORE_BYTES" -v a="$AFTER_BYTES" 'BEGIN { printf "ℹ️ Linux UPX 压缩完成:%.2fMB -> %.2fMB\n", b/1024/1024, a/1024/1024 }' + fi # 1. Create tar.gz echo "📦 正在打包 $TAR_NAME..." @@ -419,7 +481,6 @@ jobs: path: | GoNavi-*.dmg GoNavi-*.exe - GoNavi-*.zip GoNavi-*.tar.gz GoNavi-*.AppImage drivers/** diff --git a/.github/workflows/test-build-all-platforms.yml b/.github/workflows/test-build-all-platforms.yml index 29ffe9d..6646ece 100644 --- a/.github/workflows/test-build-all-platforms.yml +++ b/.github/workflows/test-build-all-platforms.yml @@ -93,6 +93,24 @@ jobs: with: node-version: '20' + - name: Install UPX (macOS) + if: contains(matrix.platform, 'darwin') + run: | + brew install upx + upx --version + + - name: Install UPX (Windows) + if: contains(matrix.platform, 'windows') + shell: pwsh + run: | + choco install upx --no-progress -y + $upxCmd = Get-Command upx -ErrorAction SilentlyContinue + if ($null -eq $upxCmd) { + Write-Error "❌ 未检测到 upx,无法保证 Windows 测试产物经过压缩" + exit 1 + } + & upx --version + - name: Install Linux Dependencies if: contains(matrix.platform, 'linux') run: | @@ -105,6 +123,9 @@ jobs: sudo apt-get install -y libwebkit2gtk-4.0-dev fi + sudo apt-get install -y upx-ucl || sudo apt-get install -y upx + upx --version + sudo apt-get install -y libfuse2 || sudo apt-get install -y libfuse2t64 || true LINUXDEPLOY_URL="https://github.com/linuxdeploy/linuxdeploy/releases/download/continuous/linuxdeploy-x86_64.AppImage" @@ -242,6 +263,22 @@ jobs: exit 1 fi APP_NAME=$(basename "$APP_PATH") + APP_BIN=$(find "$APP_PATH/Contents/MacOS" -maxdepth 1 -type f | head -n 1) + if [ -z "$APP_BIN" ]; then + echo "未找到 macOS 应用主程序,无法进行 UPX 压缩" + exit 1 + fi + BEFORE_BYTES=$(wc -c <"$APP_BIN" | tr -d '[:space:]') + echo "🗜️ 使用 UPX 压缩 macOS 可执行文件: $APP_BIN ..." + upx --best --lzma --force "$APP_BIN" + upx -t "$APP_BIN" + AFTER_BYTES=$(wc -c <"$APP_BIN" | tr -d '[:space:]') + if [ "$AFTER_BYTES" -lt "$BEFORE_BYTES" ]; then + SAVED_BYTES=$((BEFORE_BYTES - AFTER_BYTES)) + awk -v b="$BEFORE_BYTES" -v a="$AFTER_BYTES" -v s="$SAVED_BYTES" 'BEGIN { printf "✅ macOS UPX 压缩完成:%.2fMB -> %.2fMB,减少 %.2fMB\n", b/1024/1024, a/1024/1024, s/1024/1024 }' + else + awk -v b="$BEFORE_BYTES" -v a="$AFTER_BYTES" 'BEGIN { printf "ℹ️ macOS UPX 压缩完成:%.2fMB -> %.2fMB\n", b/1024/1024, a/1024/1024 }' + fi codesign --force --deep --sign - "$APP_NAME" ZIP_NAME="GoNavi-${LABEL}-${{ matrix.os_name }}-${{ matrix.arch_name }}-run${GITHUB_RUN_NUMBER}.zip" DMG_NAME="GoNavi-${LABEL}-${{ matrix.os_name }}-${{ matrix.arch_name }}-run${GITHUB_RUN_NUMBER}.dmg" @@ -270,7 +307,6 @@ jobs: Set-Location build/bin $target = "${{ matrix.build_name }}" $finalExeName = "GoNavi-$label-${{ matrix.os_name }}-${{ matrix.arch_name }}-run$env:GITHUB_RUN_NUMBER.exe" - $finalZipName = "GoNavi-$label-${{ matrix.os_name }}-${{ matrix.arch_name }}-run$env:GITHUB_RUN_NUMBER.zip" if (Test-Path "$target.exe") { $finalExe = "$target.exe" } elseif (Test-Path "$target") { @@ -280,11 +316,25 @@ jobs: Write-Error "未找到构建产物 '$target'" exit 1 } + $upxCmd = Get-Command upx -ErrorAction SilentlyContinue + if ($null -eq $upxCmd) { + Write-Error "❌ 未找到 upx,无法保证 Windows 测试产物经过压缩" + exit 1 + } + $beforeBytes = (Get-Item -LiteralPath $finalExe).Length + Write-Host "🗜️ 使用 UPX 压缩 $finalExe ..." + & upx --best --lzma --force $finalExe | Out-Host + & upx -t $finalExe | Out-Host + $afterBytes = (Get-Item -LiteralPath $finalExe).Length + if ($afterBytes -lt $beforeBytes) { + $savedBytes = $beforeBytes - $afterBytes + Write-Host ("✅ UPX 压缩完成:{0:N2}MB -> {1:N2}MB,减少 {2:N2}MB" -f ($beforeBytes / 1MB), ($afterBytes / 1MB), ($savedBytes / 1MB)) + } else { + Write-Host ("ℹ️ UPX 压缩完成:{0:N2}MB -> {1:N2}MB" -f ($beforeBytes / 1MB), ($afterBytes / 1MB)) + } New-Item -ItemType Directory -Force -Path ..\..\artifacts | Out-Null Copy-Item -LiteralPath $finalExe -Destination "..\..\artifacts\$finalExeName" -Force - Compress-Archive -LiteralPath $finalExe -DestinationPath "..\..\artifacts\$finalZipName" -Force Get-FileHash "..\..\artifacts\$finalExeName" -Algorithm SHA256 | ForEach-Object { "{0} *{1}" -f $_.Hash.ToLower(), (Split-Path $_.Path -Leaf) } | Out-File "..\..\artifacts\$finalExeName.sha256" -Encoding ascii - Get-FileHash "..\..\artifacts\$finalZipName" -Algorithm SHA256 | ForEach-Object { "{0} *{1}" -f $_.Hash.ToLower(), (Split-Path $_.Path -Leaf) } | Out-File "..\..\artifacts\$finalZipName.sha256" -Encoding ascii - name: Package Linux if: contains(matrix.platform, 'linux') @@ -306,6 +356,17 @@ jobs: exit 1 fi chmod +x "$TARGET" + BEFORE_BYTES=$(wc -c <"$TARGET" | tr -d '[:space:]') + echo "🗜️ 使用 UPX 压缩 Linux 可执行文件: $TARGET ..." + upx --best --lzma --force "$TARGET" + upx -t "$TARGET" + AFTER_BYTES=$(wc -c <"$TARGET" | tr -d '[:space:]') + if [ "$AFTER_BYTES" -lt "$BEFORE_BYTES" ]; then + SAVED_BYTES=$((BEFORE_BYTES - AFTER_BYTES)) + awk -v b="$BEFORE_BYTES" -v a="$AFTER_BYTES" -v s="$SAVED_BYTES" 'BEGIN { printf "✅ Linux UPX 压缩完成:%.2fMB -> %.2fMB,减少 %.2fMB\n", b/1024/1024, a/1024/1024, s/1024/1024 }' + else + awk -v b="$BEFORE_BYTES" -v a="$AFTER_BYTES" 'BEGIN { printf "ℹ️ Linux UPX 压缩完成:%.2fMB -> %.2fMB\n", b/1024/1024, a/1024/1024 }' + fi tar -czvf "../../artifacts/$TAR_NAME" "$TARGET" sha256sum "../../artifacts/$TAR_NAME" > "../../artifacts/$TAR_NAME.sha256" diff --git a/build-driver-agents.sh b/build-driver-agents.sh new file mode 100755 index 0000000..e3734d2 --- /dev/null +++ b/build-driver-agents.sh @@ -0,0 +1,228 @@ +#!/usr/bin/env bash + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +DEFAULT_DRIVERS=(mariadb doris sphinx sqlserver sqlite duckdb dameng kingbase highgo vastbase mongodb tdengine clickhouse) + +usage() { + cat <<'EOF' +用法: + ./build-driver-agents.sh [选项] + +选项: + --drivers <列表> 指定驱动列表(逗号分隔),例如:kingbase,mongodb + --platform + 目标平台,默认使用当前 Go 环境(go env GOOS/GOARCH) + --out-dir <目录> 输出目录根路径,默认:dist/driver-agents + --bundle-name <文件名> 驱动总包 zip 名称,默认:GoNavi-DriverAgents.zip + --strict 任一驱动构建失败即中断(默认失败后继续,最后汇总) + -h, --help 显示帮助 + +示例: + ./build-driver-agents.sh + ./build-driver-agents.sh --drivers kingbase + ./build-driver-agents.sh --platform windows/amd64 --drivers kingbase,mongodb +EOF +} + +normalize_driver() { + local name + name="$(echo "${1:-}" | tr '[:upper:]' '[:lower:]' | xargs)" + case "$name" in + doris|diros) echo "doris" ;; + mariadb|sphinx|sqlserver|sqlite|duckdb|dameng|kingbase|highgo|vastbase|mongodb|tdengine|clickhouse) + echo "$name" + ;; + *) + return 1 + ;; + esac +} + +build_driver_name() { + case "$1" in + doris) echo "diros" ;; + *) echo "$1" ;; + esac +} + +platform_dir_name() { + case "$1" in + windows) echo "Windows" ;; + darwin) echo "MacOS" ;; + linux) echo "Linux" ;; + *) echo "Unknown" ;; + esac +} + +driver_csv="" +target_platform="" +out_root="dist/driver-agents" +bundle_name="GoNavi-DriverAgents.zip" +strict_mode="false" + +while [[ $# -gt 0 ]]; do + case "$1" in + --drivers) + driver_csv="${2:-}" + shift 2 + ;; + --platform) + target_platform="${2:-}" + shift 2 + ;; + --out-dir) + out_root="${2:-}" + shift 2 + ;; + --bundle-name) + bundle_name="${2:-}" + shift 2 + ;; + --strict) + strict_mode="true" + shift + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "❌ 未知参数:$1" + usage + exit 1 + ;; + esac +done + +if ! command -v go >/dev/null 2>&1; then + echo "❌ 未找到 Go,请先安装 Go 并确保 go 在 PATH 中。" + exit 1 +fi + +if [[ -z "$target_platform" ]]; then + target_platform="$(go env GOOS)/$(go env GOARCH)" +fi + +if [[ "$target_platform" != */* ]]; then + echo "❌ --platform 参数格式错误,应为 GOOS/GOARCH,例如 darwin/arm64" + exit 1 +fi + +goos="${target_platform%%/*}" +goarch="${target_platform##*/}" +platform_key="${goos}-${goarch}" +platform_dir="$(platform_dir_name "$goos")" + +declare -a drivers=() +if [[ -n "$driver_csv" ]]; then + IFS=',' read -r -a raw_drivers <<<"$driver_csv" + for item in "${raw_drivers[@]}"; do + normalized="$(normalize_driver "$item")" || { + echo "❌ 不支持的驱动:$item" + exit 1 + } + drivers+=("$normalized") + done +else + drivers=("${DEFAULT_DRIVERS[@]}") +fi + +output_dir="${out_root%/}/${platform_key}" +bundle_stage_dir="$(mktemp -d "${TMPDIR:-/tmp}/gonavi-driver-bundle.XXXXXX")" +bundle_platform_dir="$bundle_stage_dir/$platform_dir" + +cleanup() { + rm -rf "$bundle_stage_dir" +} +trap cleanup EXIT + +mkdir -p "$output_dir" "$bundle_platform_dir" +output_dir_abs="$(cd "$output_dir" && pwd)" +bundle_zip_path="$output_dir_abs/$bundle_name" + +declare -a built_assets=() +declare -a failed_drivers=() +declare -a skipped_drivers=() + +echo "🚀 开始构建 optional-driver-agent" +echo " 平台:$goos/$goarch" +echo " 输出目录:$output_dir_abs" +echo " 驱动列表:${drivers[*]}" + +for driver in "${drivers[@]}"; do + if [[ "$driver" == "duckdb" && "$goos" == "windows" && "$goarch" != "amd64" ]]; then + echo "⚠️ 跳过 duckdb(仅支持 windows/amd64)" + skipped_drivers+=("$driver") + continue + fi + + build_driver="$(build_driver_name "$driver")" + tag="gonavi_${build_driver}_driver" + asset_name="${driver}-driver-agent-${goos}-${goarch}" + if [[ "$goos" == "windows" ]]; then + asset_name="${asset_name}.exe" + fi + output_path="$output_dir_abs/$asset_name" + + cgo_enabled=0 + if [[ "$driver" == "duckdb" ]]; then + cgo_enabled=1 + fi + + echo "🔧 构建 $driver -> $asset_name (tag=$tag, CGO_ENABLED=$cgo_enabled)" + set +e + CGO_ENABLED="$cgo_enabled" GOOS="$goos" GOARCH="$goarch" GOTOOLCHAIN=auto \ + go build -tags "$tag" -trimpath -ldflags "-s -w" -o "$output_path" ./cmd/optional-driver-agent + build_exit=$? + set -e + + if [[ $build_exit -ne 0 ]]; then + echo "❌ 构建失败:$driver" + failed_drivers+=("$driver") + if [[ "$strict_mode" == "true" ]]; then + exit $build_exit + fi + continue + fi + + cp "$output_path" "$bundle_platform_dir/$asset_name" + built_assets+=("$asset_name") +done + +if [[ ${#built_assets[@]} -eq 0 ]]; then + echo "❌ 未成功构建任何驱动代理。" + exit 1 +fi + +rm -f "$bundle_zip_path" +if command -v zip >/dev/null 2>&1; then + ( + cd "$bundle_stage_dir" + zip -qry "$bundle_zip_path" "$platform_dir" + ) +elif command -v ditto >/dev/null 2>&1; then + ( + cd "$bundle_stage_dir" + ditto -c -k --sequesterRsrc --keepParent "$platform_dir" "$bundle_zip_path" + ) +else + echo "❌ 未找到 zip/ditto,无法生成驱动总包 zip。" + exit 1 +fi + +echo "" +echo "✅ 构建完成" +echo " 单文件输出目录:$output_dir_abs" +echo " 驱动总包:$bundle_zip_path" +echo " 已构建:${built_assets[*]}" +if [[ ${#skipped_drivers[@]} -gt 0 ]]; then + echo " 已跳过:${skipped_drivers[*]}" +fi +if [[ ${#failed_drivers[@]} -gt 0 ]]; then + echo "⚠️ 构建失败驱动:${failed_drivers[*]}" + exit 2 +fi diff --git a/build-release.sh b/build-release.sh index 4be9a67..22fe7c8 100755 --- a/build-release.sh +++ b/build-release.sh @@ -20,6 +20,75 @@ RED='\033[0;31m' YELLOW='\033[1;33m' NC='\033[0m' +get_file_size_bytes() { + local target="$1" + if [ ! -f "$target" ]; then + echo 0 + return + fi + if stat -f%z "$target" >/dev/null 2>&1; then + stat -f%z "$target" + return + fi + if stat -c%s "$target" >/dev/null 2>&1; then + stat -c%s "$target" + return + fi + wc -c <"$target" | tr -d '[:space:]' +} + +format_size_mb() { + local bytes="${1:-0}" + awk -v b="$bytes" 'BEGIN { printf "%.2fMB", b / 1024 / 1024 }' +} + +try_compress_binary_with_upx() { + local exe_path="$1" + local label="$2" + if [ ! -f "$exe_path" ]; then + echo -e "${RED} ❌ 未找到 ${label} 文件:$exe_path${NC}" + exit 1 + fi + + if ! command -v upx >/dev/null 2>&1; then + echo -e "${RED} ❌ 未找到 upx,${label} 必须进行压缩后才能继续打包。${NC}" + case "$(uname -s)" in + Darwin) + echo " 安装命令: brew install upx" + ;; + Linux) + echo " 安装命令: sudo apt-get install -y upx-ucl (或对应发行版包管理器)" + ;; + esac + exit 1 + fi + + local before_bytes after_bytes + before_bytes=$(get_file_size_bytes "$exe_path") + echo " 🗜️ 正在使用 UPX 压缩 ${label}..." + if upx --best --lzma --force "$exe_path" >/dev/null 2>&1; then + if ! upx -t "$exe_path" >/dev/null 2>&1; then + echo -e "${RED} ❌ UPX 校验失败:${label}${NC}" + exit 1 + fi + after_bytes=$(get_file_size_bytes "$exe_path") + if [ "$after_bytes" -lt "$before_bytes" ]; then + local saved_bytes=$((before_bytes - after_bytes)) + echo " ✅ UPX 压缩完成: $(format_size_mb "$before_bytes") -> $(format_size_mb "$after_bytes"),减少 $(format_size_mb "$saved_bytes")" + else + echo " ℹ️ UPX 压缩完成: $(format_size_mb "$before_bytes") -> $(format_size_mb "$after_bytes")" + fi + else + echo -e "${RED} ❌ UPX 压缩失败:${label}${NC}" + exit 1 + fi +} + +MAC_VOLICON_PATH="build/darwin/icon.icns" +if [ ! -f "$MAC_VOLICON_PATH" ]; then + MAC_VOLICON_PATH="" +fi + echo -e "${GREEN}🚀 开始构建 $APP_NAME $VERSION...${NC}" # 清理并创建输出目录 @@ -36,47 +105,101 @@ if [ $? -eq 0 ]; then # 移动 .app 到 dist mv "$APP_SRC" "$DIST_DIR/$APP_DEST_NAME" + + APP_BIN_PATH=$(find "$DIST_DIR/$APP_DEST_NAME/Contents/MacOS" -maxdepth 1 -type f -print -quit) + if [ -n "$APP_BIN_PATH" ] && [ -f "$APP_BIN_PATH" ]; then + try_compress_binary_with_upx "$APP_BIN_PATH" "macOS arm64 应用主程序" + else + echo -e "${RED} ❌ 未找到 macOS arm64 主程序文件,无法执行 UPX 压缩。${NC}" + exit 1 + fi - # 创建 DMG - if command -v create-dmg &> /dev/null; then - echo " 📦 正在打包 DMG (arm64)..." - # 移除已存在的 DMG (以防万一) - rm -f "$DIST_DIR/$DMG_NAME" - - create-dmg \ - --volname "${APP_NAME} ${VERSION}" \ - --volicon "build/appicon.icns" \ - --window-pos 200 120 \ - --window-size 800 400 \ - --icon-size 100 \ - --icon "$APP_DEST_NAME" 200 190 \ - --hide-extension "$APP_DEST_NAME" \ - --app-drop-link 600 185 \ - "$DIST_DIR/$DMG_NAME" \ - "$DIST_DIR/$APP_DEST_NAME" - - # 检查是否生成了 rw.* 的临时文件并重命名 (create-dmg 有时会有此行为) - if [ ! -f "$DIST_DIR/$DMG_NAME" ]; then - RW_FILE=$(find "$DIST_DIR" -name "rw.*.dmg" -print -quit) - if [ -n "$RW_FILE" ]; then - echo -e "${YELLOW} ⚠️ 检测到临时文件名,正在重命名...${NC}" - mv "$RW_FILE" "$DIST_DIR/$DMG_NAME" - fi + # Ad-hoc 代码签名(无 Apple Developer 账号时防止 Gatekeeper 报已损坏) + echo " 🔏 正在对 .app 进行 ad-hoc 签名 (arm64)..." + codesign --force --deep --sign - "$DIST_DIR/$APP_DEST_NAME" + + # 创建 DMG + if command -v create-dmg &> /dev/null; then + echo " 📦 正在打包 DMG (arm64)..." + # 移除已存在的 DMG (以防万一) + rm -f "$DIST_DIR/$DMG_NAME" + # create-dmg 的 source 需要是“包含 .app 的目录”,不能直接传 .app 路径。 + STAGE_DIR=$(mktemp -d "$DIST_DIR/.dmg-stage-${APP_NAME}-${VERSION}-arm64.XXXXXX") + if [ -z "$STAGE_DIR" ] || [ ! -d "$STAGE_DIR" ]; then + echo -e "${RED} ❌ 创建 DMG 临时目录失败,跳过 DMG 打包。${NC}" + else + if command -v ditto &> /dev/null; then + ditto "$DIST_DIR/$APP_DEST_NAME" "$STAGE_DIR/$APP_DEST_NAME" + else + cp -R "$DIST_DIR/$APP_DEST_NAME" "$STAGE_DIR/$APP_DEST_NAME" + fi + + # --sandbox-safe 会跳过 Finder 的 AppleScript 排版,避免打包过程中弹出/打开挂载窗口(CI/本地静默打包更友好)。 + CREATE_DMG_ARGS=(--volname "${APP_NAME} ${VERSION}" --format UDZO --sandbox-safe) + if [ -n "$MAC_VOLICON_PATH" ]; then + CREATE_DMG_ARGS+=(--volicon "$MAC_VOLICON_PATH") + else + echo -e "${YELLOW} ⚠️ 未找到 macOS 卷图标 (build/darwin/icon.icns),跳过 --volicon。${NC}" fi - # 删除中间的 .app 文件,保持目录整洁 - rm -rf "$DIST_DIR/$APP_DEST_NAME" - - if [ -f "$DIST_DIR/$DMG_NAME" ]; then - echo " ✅ 已生成 $DMG_NAME" - else - echo -e "${RED} ❌ DMG 生成失败,请检查 create-dmg 输出。${NC}" + create-dmg "${CREATE_DMG_ARGS[@]}" \ + --window-pos 200 120 \ + --window-size 800 400 \ + --icon-size 100 \ + --icon "$APP_DEST_NAME" 200 190 \ + --hide-extension "$APP_DEST_NAME" \ + --app-drop-link 600 185 \ + "$DIST_DIR/$DMG_NAME" \ + "$STAGE_DIR" + + CREATE_DMG_EXIT_CODE=$? + rm -rf "$STAGE_DIR" + + if [ $CREATE_DMG_EXIT_CODE -ne 0 ]; then + echo -e "${RED} ❌ create-dmg 执行失败 (exit=$CREATE_DMG_EXIT_CODE),保留 .app 以便排查。${NC}" + else + # create-dmg 可能会在失败时遗留 rw.*.dmg 中间产物;不要直接当作最终产物使用 + if [ ! -f "$DIST_DIR/$DMG_NAME" ]; then + RW_FILE=$(find "$DIST_DIR" -maxdepth 1 -name "rw.*.dmg" -print -quit) + if [ -n "$RW_FILE" ]; then + echo -e "${YELLOW} ⚠️ 检测到 create-dmg 中间产物: $(basename "$RW_FILE"),正在转换为可分发 DMG...${NC}" + hdiutil convert "$RW_FILE" -format UDZO -o "$DIST_DIR/$DMG_NAME" >/dev/null 2>&1 + rm -f "$RW_FILE" + fi + fi + + # 防御性:即使生成了目标文件,也要确保不是 UDRW(UDRW 在 Finder 下可能表现为“已损坏/无法打开”) + if [ -f "$DIST_DIR/$DMG_NAME" ] && command -v hdiutil &> /dev/null; then + DMG_FORMAT=$(hdiutil imageinfo "$DIST_DIR/$DMG_NAME" 2>/dev/null | awk -F': ' '/^Format:/{print $2; exit}') + if [ "$DMG_FORMAT" = "UDRW" ]; then + echo -e "${YELLOW} ⚠️ 检测到 UDRW(可写原始映像),正在转换为 UDZO...${NC}" + TMP_UDZO="$DIST_DIR/.tmp.$DMG_NAME" + rm -f "$TMP_UDZO" + hdiutil convert "$DIST_DIR/$DMG_NAME" -format UDZO -o "$TMP_UDZO" >/dev/null 2>&1 && mv "$TMP_UDZO" "$DIST_DIR/$DMG_NAME" + fi + fi + + if [ -f "$DIST_DIR/$DMG_NAME" ] && command -v hdiutil &> /dev/null; then + hdiutil verify "$DIST_DIR/$DMG_NAME" >/dev/null 2>&1 + if [ $? -ne 0 ]; then + echo -e "${RED} ❌ DMG 校验失败,保留 .app 以便排查。${NC}" + else + # 删除中间的 .app 文件,保持目录整洁 + rm -rf "$DIST_DIR/$APP_DEST_NAME" + echo " ✅ 已生成 $DMG_NAME" + fi + fi fi - else - echo -e "${YELLOW} ⚠️ 未找到 create-dmg 工具,跳过 DMG 打包,仅保留 .app。${NC}" - echo " 安装命令: brew install create-dmg" - fi -else + + if [ ! -f "$DIST_DIR/$DMG_NAME" ]; then + echo -e "${RED} ❌ DMG 生成失败,请检查 create-dmg 输出。${NC}" + fi + fi + else + echo -e "${YELLOW} ⚠️ 未找到 create-dmg 工具,跳过 DMG 打包,仅保留 .app。${NC}" + echo " 安装命令: brew install create-dmg" + fi + else echo -e "${RED} ❌ macOS arm64 构建失败。${NC}" fi @@ -89,44 +212,96 @@ if [ $? -eq 0 ]; then DMG_NAME="${APP_NAME}-${VERSION}-mac-amd64.dmg" mv "$APP_SRC" "$DIST_DIR/$APP_DEST_NAME" - - if command -v create-dmg &> /dev/null; then - echo " 📦 正在打包 DMG (amd64)..." - rm -f "$DIST_DIR/$DMG_NAME" - - create-dmg \ - --volname "${APP_NAME} ${VERSION}" \ - --volicon "build/appicon.icns" \ - --window-pos 200 120 \ - --window-size 800 400 \ - --icon-size 100 \ - --icon "$APP_DEST_NAME" 200 190 \ - --hide-extension "$APP_DEST_NAME" \ - --app-drop-link 600 185 \ - "$DIST_DIR/$DMG_NAME" \ - "$DIST_DIR/$APP_DEST_NAME" - # 检查是否生成了 rw.* 的临时文件并重命名 - if [ ! -f "$DIST_DIR/$DMG_NAME" ]; then - RW_FILE=$(find "$DIST_DIR" -name "rw.*.dmg" -print -quit) - if [ -n "$RW_FILE" ]; then - echo -e "${YELLOW} ⚠️ 检测到临时文件名,正在重命名...${NC}" - mv "$RW_FILE" "$DIST_DIR/$DMG_NAME" - fi - fi - - rm -rf "$DIST_DIR/$APP_DEST_NAME" - - if [ -f "$DIST_DIR/$DMG_NAME" ]; then - echo " ✅ 已生成 $DMG_NAME" - else - echo -e "${RED} ❌ DMG 生成失败。${NC}" - fi + APP_BIN_PATH=$(find "$DIST_DIR/$APP_DEST_NAME/Contents/MacOS" -maxdepth 1 -type f -print -quit) + if [ -n "$APP_BIN_PATH" ] && [ -f "$APP_BIN_PATH" ]; then + try_compress_binary_with_upx "$APP_BIN_PATH" "macOS amd64 应用主程序" else - echo -e "${YELLOW} ⚠️ 未找到 create-dmg 工具。${NC}" + echo -e "${RED} ❌ 未找到 macOS amd64 主程序文件,无法执行 UPX 压缩。${NC}" + exit 1 fi -else - echo -e "${RED} ❌ macOS amd64 构建失败。${NC}" + + # Ad-hoc 代码签名 + echo " 🔏 正在对 .app 进行 ad-hoc 签名 (amd64)..." + codesign --force --deep --sign - "$DIST_DIR/$APP_DEST_NAME" + + if command -v create-dmg &> /dev/null; then + echo " 📦 正在打包 DMG (amd64)..." + rm -f "$DIST_DIR/$DMG_NAME" + # create-dmg 的 source 需要是“包含 .app 的目录”,不能直接传 .app 路径。 + STAGE_DIR=$(mktemp -d "$DIST_DIR/.dmg-stage-${APP_NAME}-${VERSION}-amd64.XXXXXX") + if [ -z "$STAGE_DIR" ] || [ ! -d "$STAGE_DIR" ]; then + echo -e "${RED} ❌ 创建 DMG 临时目录失败,跳过 DMG 打包。${NC}" + else + if command -v ditto &> /dev/null; then + ditto "$DIST_DIR/$APP_DEST_NAME" "$STAGE_DIR/$APP_DEST_NAME" + else + cp -R "$DIST_DIR/$APP_DEST_NAME" "$STAGE_DIR/$APP_DEST_NAME" + fi + + # --sandbox-safe 会跳过 Finder 的 AppleScript 排版,避免打包过程中弹出/打开挂载窗口(CI/本地静默打包更友好)。 + CREATE_DMG_ARGS=(--volname "${APP_NAME} ${VERSION}" --format UDZO --sandbox-safe) + if [ -n "$MAC_VOLICON_PATH" ]; then + CREATE_DMG_ARGS+=(--volicon "$MAC_VOLICON_PATH") + else + echo -e "${YELLOW} ⚠️ 未找到 macOS 卷图标 (build/darwin/icon.icns),跳过 --volicon。${NC}" + fi + + create-dmg "${CREATE_DMG_ARGS[@]}" \ + --window-pos 200 120 \ + --window-size 800 400 \ + --icon-size 100 \ + --icon "$APP_DEST_NAME" 200 190 \ + --hide-extension "$APP_DEST_NAME" \ + --app-drop-link 600 185 \ + "$DIST_DIR/$DMG_NAME" \ + "$STAGE_DIR" + + CREATE_DMG_EXIT_CODE=$? + rm -rf "$STAGE_DIR" + + if [ $CREATE_DMG_EXIT_CODE -ne 0 ]; then + echo -e "${RED} ❌ create-dmg 执行失败 (exit=$CREATE_DMG_EXIT_CODE),保留 .app 以便排查。${NC}" + else + if [ ! -f "$DIST_DIR/$DMG_NAME" ]; then + RW_FILE=$(find "$DIST_DIR" -maxdepth 1 -name "rw.*.dmg" -print -quit) + if [ -n "$RW_FILE" ]; then + echo -e "${YELLOW} ⚠️ 检测到 create-dmg 中间产物: $(basename "$RW_FILE"),正在转换为可分发 DMG...${NC}" + hdiutil convert "$RW_FILE" -format UDZO -o "$DIST_DIR/$DMG_NAME" >/dev/null 2>&1 + rm -f "$RW_FILE" + fi + fi + + if [ -f "$DIST_DIR/$DMG_NAME" ] && command -v hdiutil &> /dev/null; then + DMG_FORMAT=$(hdiutil imageinfo "$DIST_DIR/$DMG_NAME" 2>/dev/null | awk -F': ' '/^Format:/{print $2; exit}') + if [ "$DMG_FORMAT" = "UDRW" ]; then + echo -e "${YELLOW} ⚠️ 检测到 UDRW(可写原始映像),正在转换为 UDZO...${NC}" + TMP_UDZO="$DIST_DIR/.tmp.$DMG_NAME" + rm -f "$TMP_UDZO" + hdiutil convert "$DIST_DIR/$DMG_NAME" -format UDZO -o "$TMP_UDZO" >/dev/null 2>&1 && mv "$TMP_UDZO" "$DIST_DIR/$DMG_NAME" + fi + fi + + if [ -f "$DIST_DIR/$DMG_NAME" ] && command -v hdiutil &> /dev/null; then + hdiutil verify "$DIST_DIR/$DMG_NAME" >/dev/null 2>&1 + if [ $? -ne 0 ]; then + echo -e "${RED} ❌ DMG 校验失败,保留 .app 以便排查。${NC}" + else + rm -rf "$DIST_DIR/$APP_DEST_NAME" + echo " ✅ 已生成 $DMG_NAME" + fi + fi + fi + + if [ ! -f "$DIST_DIR/$DMG_NAME" ]; then + echo -e "${RED} ❌ DMG 生成失败。${NC}" + fi + fi + else + echo -e "${YELLOW} ⚠️ 未找到 create-dmg 工具。${NC}" + fi + else + echo -e "${RED} ❌ macOS amd64 构建失败。${NC}" fi # --- Windows AMD64 构建 --- @@ -134,7 +309,9 @@ echo -e "${GREEN}🪟 正在构建 Windows (amd64)...${NC}" if command -v x86_64-w64-mingw32-gcc &> /dev/null; then wails build -platform windows/amd64 -clean -ldflags "$LDFLAGS" if [ $? -eq 0 ]; then - mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}.exe" "$DIST_DIR/${APP_NAME}-${VERSION}-windows-amd64.exe" + TARGET_EXE="$DIST_DIR/${APP_NAME}-${VERSION}-windows-amd64.exe" + mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}.exe" "$TARGET_EXE" + try_compress_binary_with_upx "$TARGET_EXE" "Windows amd64 可执行文件" echo " ✅ 已生成 ${APP_NAME}-${VERSION}-windows-amd64.exe" else echo -e "${RED} ❌ Windows amd64 构建失败。${NC}" @@ -148,7 +325,9 @@ echo -e "${GREEN}🪟 正在构建 Windows (arm64)...${NC}" if command -v aarch64-w64-mingw32-gcc &> /dev/null; then wails build -platform windows/arm64 -clean -ldflags "$LDFLAGS" if [ $? -eq 0 ]; then - mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}.exe" "$DIST_DIR/${APP_NAME}-${VERSION}-windows-arm64.exe" + TARGET_EXE="$DIST_DIR/${APP_NAME}-${VERSION}-windows-arm64.exe" + mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}.exe" "$TARGET_EXE" + try_compress_binary_with_upx "$TARGET_EXE" "Windows arm64 可执行文件" echo " ✅ 已生成 ${APP_NAME}-${VERSION}-windows-arm64.exe" else echo -e "${RED} ❌ Windows arm64 构建失败。${NC}" @@ -168,8 +347,10 @@ if [ "$CURRENT_OS" = "Linux" ] && [ "$CURRENT_ARCH" = "x86_64" ]; then # 本机 Linux amd64,直接构建 wails build -platform linux/amd64 -clean -ldflags "$LDFLAGS" if [ $? -eq 0 ]; then - mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64" - chmod +x "$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64" + TARGET_LINUX_BIN="$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64" + mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$TARGET_LINUX_BIN" + chmod +x "$TARGET_LINUX_BIN" + try_compress_binary_with_upx "$TARGET_LINUX_BIN" "Linux amd64 可执行文件" # 打包为 tar.gz cd "$DIST_DIR" tar -czvf "${APP_NAME}-${VERSION}-linux-amd64.tar.gz" "${APP_NAME}-${VERSION}-linux-amd64" @@ -186,8 +367,10 @@ elif command -v x86_64-linux-gnu-gcc &> /dev/null; then export CGO_ENABLED=1 wails build -platform linux/amd64 -clean -ldflags "$LDFLAGS" if [ $? -eq 0 ]; then - mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64" - chmod +x "$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64" + TARGET_LINUX_BIN="$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64" + mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$TARGET_LINUX_BIN" + chmod +x "$TARGET_LINUX_BIN" + try_compress_binary_with_upx "$TARGET_LINUX_BIN" "Linux amd64 可执行文件" cd "$DIST_DIR" tar -czvf "${APP_NAME}-${VERSION}-linux-amd64.tar.gz" "${APP_NAME}-${VERSION}-linux-amd64" rm "${APP_NAME}-${VERSION}-linux-amd64" @@ -208,8 +391,10 @@ if [ "$CURRENT_OS" = "Linux" ] && [ "$CURRENT_ARCH" = "aarch64" ]; then # 本机 Linux arm64,直接构建 wails build -platform linux/arm64 -clean -ldflags "$LDFLAGS" if [ $? -eq 0 ]; then - mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64" - chmod +x "$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64" + TARGET_LINUX_BIN="$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64" + mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$TARGET_LINUX_BIN" + chmod +x "$TARGET_LINUX_BIN" + try_compress_binary_with_upx "$TARGET_LINUX_BIN" "Linux arm64 可执行文件" cd "$DIST_DIR" tar -czvf "${APP_NAME}-${VERSION}-linux-arm64.tar.gz" "${APP_NAME}-${VERSION}-linux-arm64" rm "${APP_NAME}-${VERSION}-linux-arm64" @@ -225,8 +410,10 @@ elif command -v aarch64-linux-gnu-gcc &> /dev/null; then export CGO_ENABLED=1 wails build -platform linux/arm64 -clean -ldflags "$LDFLAGS" if [ $? -eq 0 ]; then - mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64" - chmod +x "$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64" + TARGET_LINUX_BIN="$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64" + mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$TARGET_LINUX_BIN" + chmod +x "$TARGET_LINUX_BIN" + try_compress_binary_with_upx "$TARGET_LINUX_BIN" "Linux arm64 可执行文件" cd "$DIST_DIR" tar -czvf "${APP_NAME}-${VERSION}-linux-arm64.tar.gz" "${APP_NAME}-${VERSION}-linux-arm64" rm "${APP_NAME}-${VERSION}-linux-arm64" diff --git a/frontend/index.html b/frontend/index.html index 127af4b..b596c58 100644 --- a/frontend/index.html +++ b/frontend/index.html @@ -5,6 +5,23 @@ GoNavi +
diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index c0c5436..ce1832e 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -93,27 +93,39 @@ function App() { // 同步 macOS 窗口透明度:opacity=1.0 且 blur=0 时关闭 NSVisualEffectView, // 避免 GPU 持续计算窗口背后的模糊合成 useEffect(() => { - void SetWindowTranslucency(resolvedAppearance.opacity, resolvedAppearance.blur).catch(() => undefined); + try { + void SetWindowTranslucency(resolvedAppearance.opacity, resolvedAppearance.blur).catch(() => undefined); + } catch(e) { /* ignore */ } }, [resolvedAppearance.blur, resolvedAppearance.opacity]); useEffect(() => { let cancelled = false; - Environment() - .then((env) => { - if (cancelled) return; - const platform = String(env?.platform || '').toLowerCase(); - setRuntimePlatform(platform); - setIsLinuxRuntime(platform === 'linux'); - }) - .catch(() => { - if (cancelled) return; - const platform = detectNavigatorPlatform(); - const normalized = /linux/i.test(platform) - ? 'linux' - : (/mac/i.test(platform) ? 'darwin' : (/win/i.test(platform) ? 'windows' : '')); - setRuntimePlatform(normalized); - setIsLinuxRuntime(normalized === 'linux'); - }); + try { + Environment() + .then((env) => { + if (cancelled) return; + const platform = String(env?.platform || '').toLowerCase(); + setRuntimePlatform(platform); + setIsLinuxRuntime(platform === 'linux'); + }) + .catch(() => { + if (cancelled) return; + const platform = detectNavigatorPlatform(); + const normalized = /linux/i.test(platform) + ? 'linux' + : (/mac/i.test(platform) ? 'darwin' : (/win/i.test(platform) ? 'windows' : '')); + setRuntimePlatform(normalized); + setIsLinuxRuntime(normalized === 'linux'); + }); + } catch(e) { + if (cancelled) return; + const platform = detectNavigatorPlatform(); + const normalized = /linux/i.test(platform) + ? 'linux' + : (/mac/i.test(platform) ? 'darwin' : (/win/i.test(platform) ? 'windows' : '')); + setRuntimePlatform(normalized); + setIsLinuxRuntime(normalized === 'linux'); + } return () => { cancelled = true; }; @@ -156,32 +168,36 @@ function App() { const enabledForBackend = globalProxy.enabled && !invalidWhenEnabled; let cancelled = false; - ConfigureGlobalProxy(enabledForBackend, { - type: globalProxy.type, - host, - port: portValid ? port : (globalProxy.type === 'http' ? 8080 : 1080), - user: String(globalProxy.user || '').trim(), - password: globalProxy.password || '', - }) - .then((res) => { - if (cancelled || res?.success) { - return; - } - void message.error({ - content: '全局代理配置失败: ' + (res?.message || '未知错误'), - key: 'global-proxy-sync-error', - }); + try { + ConfigureGlobalProxy(enabledForBackend, { + type: globalProxy.type, + host, + port: portValid ? port : (globalProxy.type === 'http' ? 8080 : 1080), + user: String(globalProxy.user || '').trim(), + password: globalProxy.password || '', }) - .catch((err) => { - if (cancelled) { - return; - } - const errMsg = err instanceof Error ? err.message : String(err || '未知错误'); - void message.error({ - content: '全局代理配置失败: ' + errMsg, - key: 'global-proxy-sync-error', + .then((res) => { + if (cancelled || res?.success) { + return; + } + void message.error({ + content: '全局代理配置失败: ' + (res?.message || '未知错误'), + key: 'global-proxy-sync-error', + }); + }) + .catch((err) => { + if (cancelled) { + return; + } + const errMsg = err instanceof Error ? err.message : String(err || '未知错误'); + void message.error({ + content: '全局代理配置失败: ' + errMsg, + key: 'global-proxy-sync-error', + }); }); - }); + } catch (e) { + console.warn("Wails API: ConfigureGlobalProxy unavailable", e); + } return () => { cancelled = true; @@ -238,13 +254,18 @@ function App() { return; } // 优先尝试全屏,若当前平台/时机不生效,后续走最大化兜底。 - await WindowFullscreen(); - await new Promise((resolve) => window.setTimeout(resolve, settleDelayMs)); - if (await checkStartupPreferenceApplied()) { - return; + try { + await WindowFullscreen(); + await new Promise((resolve) => window.setTimeout(resolve, settleDelayMs)); + if (await checkStartupPreferenceApplied()) { + return; + } + await WindowMaximise(); + await new Promise((resolve) => window.setTimeout(resolve, settleDelayMs)); + } catch (e) { + console.warn("Wails Window APIs unavailable", e); } - await WindowMaximise(); - await new Promise((resolve) => window.setTimeout(resolve, settleDelayMs)); + if (await checkStartupPreferenceApplied()) { return; } @@ -315,11 +336,15 @@ function App() { } const nudgedWidth = width > 480 ? width - 1 : width + 1; - WindowSetSize(nudgedWidth, height); - await wait(28); - WindowSetSize(width, height); + try { + WindowSetSize(nudgedWidth, height); + await wait(28); + WindowSetSize(width, height); + } catch(e) {} window.dispatchEvent(new Event('resize')); lastFixAt = Date.now(); + } catch(e) { + console.warn("Wails Window APIs unavailable in fixWindowScaleIfNeeded", e); } finally { inFlight = false; } @@ -649,7 +674,12 @@ function App() { total: info.assetSize || 0, message: '' }); - const res = await (window as any).go.app.App.DownloadUpdate(); + let res: any = null; + try { + res = await (window as any).go.app.App.DownloadUpdate(); + } catch (e) { + console.warn("Wails API: DownloadUpdate unavailable", e); + } updateDownloadInFlightRef.current = false; if (res?.success) { const resultData = (res?.data || {}) as UpdateDownloadResultData; @@ -1050,7 +1080,7 @@ function App() { if (target?.closest('[data-no-titlebar-toggle="true"]')) { return; } - WindowToggleMaximise(); + try { WindowToggleMaximise(); } catch(e) {} }; // Sidebar Resizing @@ -1158,7 +1188,9 @@ function App() { }, [checkForUpdates]); useEffect(() => { - const offDownloadProgress = EventsOn('update:download-progress', (event: UpdateDownloadProgressEvent) => { + let offDownloadProgress: any = null; + try { + offDownloadProgress = EventsOn('update:download-progress', (event: UpdateDownloadProgressEvent) => { if (!event) return; const status = event.status || 'downloading'; const nextStatus: 'idle' | 'start' | 'downloading' | 'done' | 'error' = @@ -1181,8 +1213,11 @@ function App() { message: String(event.message || '') })); }); + } catch (e) { + console.warn("Wails API: EventsOn unavailable", e); + } return () => { - offDownloadProgress(); + if (offDownloadProgress) offDownloadProgress(); }; }, []); diff --git a/frontend/src/components/ConnectionModal.tsx b/frontend/src/components/ConnectionModal.tsx index bf7414b..1f9d9b5 100644 --- a/frontend/src/components/ConnectionModal.tsx +++ b/frontend/src/components/ConnectionModal.tsx @@ -568,6 +568,7 @@ const ConnectionModal: React.FC<{ return { host: primary?.host || 'localhost', port: primary?.port || 6379, + user: parsed.username || '', password: parsed.password || '', useSSL: isRediss, sslMode: isRediss ? (skipVerify ? 'skip-verify' : 'required') : 'disable', @@ -823,8 +824,15 @@ const ConnectionModal: React.FC<{ if (hosts.length > 1 || values.redisTopology === 'cluster') { params.set('topology', 'cluster'); } + const redisUser = String(values.user || '').trim(); const redisPassword = String(values.password || ''); - const redisAuth = redisPassword ? `:${encodeURIComponent(redisPassword)}@` : ''; + let redisAuth = ''; + if (redisUser || redisPassword) { + const encodedPassword = redisPassword ? encodeURIComponent(redisPassword) : ''; + redisAuth = redisUser + ? `${encodeURIComponent(redisUser)}${redisPassword ? `:${encodedPassword}` : ''}@` + : `:${encodedPassword}@`; + } const redisDB = Number.isFinite(Number(values.redisDB)) ? Math.max(0, Math.min(15, Math.trunc(Number(values.redisDB)))) : 0; @@ -1041,6 +1049,12 @@ const ConnectionModal: React.FC<{ useEffect(() => { if (open) { + setLoading(false); + testInFlightRef.current = false; + if (testTimerRef.current !== null) { + window.clearTimeout(testTimerRef.current); + testTimerRef.current = null; + } setTestResult(null); // Reset test result setTestErrorLogOpen(false); setDbList([]); @@ -1232,6 +1246,22 @@ const ConnectionModal: React.FC<{ }, 0); }; + const withClientTimeout = async (promise: Promise, timeoutMs: number, timeoutMessage: string): Promise => { + let timer: number | null = null; + try { + return await Promise.race([ + promise, + new Promise((_, reject) => { + timer = window.setTimeout(() => reject(new Error(timeoutMessage)), timeoutMs); + }), + ]); + } finally { + if (timer !== null) { + window.clearTimeout(timer); + } + } + }; + const buildTestFailureMessage = (reason: unknown, fallback: string) => { const text = String(reason ?? '').trim(); const normalized = text && text !== 'undefined' && text !== 'null' ? text : fallback; @@ -1254,12 +1284,21 @@ const ConnectionModal: React.FC<{ setLoading(true); setTestResult(null); const config = await buildConfig(values, false); + const timeoutSecondsRaw = Number(values.timeout); + const timeoutSeconds = Number.isFinite(timeoutSecondsRaw) && timeoutSecondsRaw > 0 + ? Math.min(timeoutSecondsRaw, MAX_TIMEOUT_SECONDS) + : 30; + const rpcTimeoutMs = (timeoutSeconds + 5) * 1000; // Use different API for Redis const isRedisType = values.type === 'redis'; - const res = isRedisType - ? await RedisConnect(config as any) - : await TestConnection(config as any); + const res = await withClientTimeout( + isRedisType + ? RedisConnect(config as any) + : TestConnection(config as any), + rpcTimeoutMs, + `连接测试超时(>${timeoutSeconds} 秒),请检查网络/代理/SSH配置后重试` + ); if (res.success) { setTestResult({ type: 'success', message: res.message }); @@ -1267,7 +1306,11 @@ const ConnectionModal: React.FC<{ setRedisDbList(Array.from({ length: 16 }, (_, i) => i)); } else { // Other databases: fetch database list - const dbRes = await DBGetDatabases(config as any); + const dbRes = await withClientTimeout( + DBGetDatabases(config as any), + rpcTimeoutMs, + `连接成功但拉取数据库列表超时(>${timeoutSeconds} 秒)` + ); if (dbRes.success) { const dbRows = Array.isArray(dbRes.data) ? dbRes.data : []; const dbs = dbRows @@ -1368,6 +1411,16 @@ const ConnectionModal: React.FC<{ const defaultPort = getDefaultPortByType(type); const isFileDbType = isFileDatabaseType(type); const sslCapableType = supportsSSLForType(type); + + // Redis 默认不展示用户名字段;若 URI 可解析则以 URI 为准覆盖 user, + // 同时清理历史默认值 root,避免 go-redis 发送 ACL AUTH(user, pass) 导致 WRONGPASS。 + if (type === 'redis') { + if (parsedUriValues && Object.prototype.hasOwnProperty.call(parsedUriValues, 'user')) { + mergedValues.user = String((parsedUriValues as any).user || ''); + } else if (String(mergedValues.user || '').trim() === 'root') { + mergedValues.user = ''; + } + } const sslModeRaw = String(mergedValues.sslMode || 'preferred').trim().toLowerCase(); const sslMode: 'preferred' | 'required' | 'skip-verify' | 'disable' = sslModeRaw === 'required' ? 'required' @@ -1554,12 +1607,13 @@ const ConnectionModal: React.FC<{ }; }; - const handleTypeSelect = async (type: string) => { - const unavailableReason = await resolveDriverUnavailableReason(type); - if (unavailableReason) { - const normalized = normalizeDriverType(type); - const driverName = driverStatusMap[normalized]?.name || type; - setTypeSelectWarning({ driverName, reason: unavailableReason }); + const handleTypeSelect = (type: string) => { + const normalized = normalizeDriverType(type); + const snapshot = driverStatusMap[normalized]; + if (snapshot && !snapshot.connectable) { + const driverName = snapshot.name || type; + const reason = snapshot.message || `${driverName} 驱动未安装启用,请先在驱动管理中安装`; + setTypeSelectWarning({ driverName, reason }); return; } setTypeSelectWarning(null); @@ -1618,7 +1672,11 @@ const ConnectionModal: React.FC<{ redisDB: 0, }); } else if (type !== 'custom') { - const defaultUser = type === 'clickhouse' ? 'default' : 'root'; + const defaultUser = type === 'clickhouse' + ? 'default' + : type === 'redis' + ? '' + : 'root'; const sslCapableType = supportsSSLForType(type); setUseSSL(false); setUseHttpTunnel(false); @@ -1657,6 +1715,10 @@ const ConnectionModal: React.FC<{ setMongoMembers([]); setStep(2); + + if (!driverStatusLoaded || !snapshot) { + void refreshDriverStatus(); + } }; const isFileDb = isFileDatabaseType(dbType); @@ -1829,7 +1891,6 @@ const ConnectionModal: React.FC<{ > {isFileDb ? ( diff --git a/frontend/src/components/DataGrid.tsx b/frontend/src/components/DataGrid.tsx index 56264bd..0a35d9f 100644 --- a/frontend/src/components/DataGrid.tsx +++ b/frontend/src/components/DataGrid.tsx @@ -1,14 +1,32 @@ +// cspell:ignore anticon sqls uuidv uuidv4 hscroll import React, { useState, useEffect, useRef, useContext, useMemo, useCallback } from 'react'; import { createPortal } from 'react-dom'; import { Table, message, Input, Button, Dropdown, MenuProps, Form, Pagination, Select, Modal, Checkbox, Segmented, Tooltip, Popover } from 'antd'; -import type { SortOrder } from 'antd/es/table/interface'; +import type { SortOrder, ColumnType } from 'antd/es/table/interface'; import { ReloadOutlined, ImportOutlined, ExportOutlined, DownOutlined, PlusOutlined, DeleteOutlined, SaveOutlined, UndoOutlined, FilterOutlined, CloseOutlined, ConsoleSqlOutlined, FileTextOutlined, CopyOutlined, ClearOutlined, EditOutlined, VerticalAlignBottomOutlined, LeftOutlined, RightOutlined } from '@ant-design/icons'; import Editor from '@monaco-editor/react'; +import { + DndContext, + DragEndEvent, + PointerSensor, + MouseSensor, + TouchSensor, + useSensor, + useSensors, + closestCenter +} from '@dnd-kit/core'; +import { + SortableContext, + useSortable, + horizontalListSortingStrategy, + arrayMove +} from '@dnd-kit/sortable'; +import { CSS } from '@dnd-kit/utilities'; import { ImportData, ExportTable, ExportData, ExportQuery, ApplyChanges, DBGetColumns } from '../../wailsjs/go/app/App'; import ImportPreviewModal from './ImportPreviewModal'; import { useStore } from '../store'; import type { ColumnDefinition } from '../types'; -import { v4 as uuidv4 } from 'uuid'; +import { v4 as generateUuid } from 'uuid'; import 'react-resizable/css/styles.css'; import { buildOrderBySQL, buildPaginatedSelectSQL, buildWhereSQL, escapeLiteral, quoteIdentPart, quoteQualifiedIdent, withSortBufferTuningSQL, type FilterCondition } from '../utils/sql'; import { isMacLikePlatform, normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance'; @@ -227,13 +245,9 @@ const shouldOpenModalEditor = (val: any): boolean => { if (typeof val === 'string') { if (val.length > INLINE_EDIT_MAX_CHARS || val.includes('\n')) return true; const trimmed = val.trimStart(); - if (trimmed.startsWith('{') || trimmed.startsWith('[')) return true; - return false; + return trimmed.startsWith('{') || trimmed.startsWith('['); } - if (typeof val === 'object') { - return true; - } - return false; + return typeof val === 'object'; }; const getCellFieldName = (record: Item, dataIndex: string) => { @@ -323,7 +337,7 @@ const coerceJsonEditorValueForStorage = (currentValue: any, editedValue: any): a }; // --- Resizable Header (Native Implementation) --- -const ResizableTitle = (props: any) => { +const ResizableTitle = React.forwardRef((props, ref) => { const { onResizeStart, width, ...restProps } = props; const nextStyle = { ...(restProps.style || {}) } as React.CSSProperties; @@ -334,11 +348,11 @@ const ResizableTitle = (props: any) => { // 注意:virtual table 模式下,rc-table 会依赖 header cell 的 width 样式来渲染选择列。 // 若这里丢失 width,可能导致左上角“全选”checkbox 不显示。 if (!width || typeof onResizeStart !== 'function') { - return ; + return ; } return ( - + {restProps.children} { /> ); -}; +}); + +// --- Sortable Header Cell --- +interface SortableHeaderCellProps extends React.HTMLAttributes { + id?: string; +} + +// --- Sortable Header Cell --- +interface SortableHeaderCellProps extends React.HTMLAttributes { + id?: string; +} + +// 静态 CSS 移到组件外,强制去除 th 内边距并确保指针穿透 +const sortableHeaderStaticStyles = ` + .gonavi-sortable-header-cell { + padding: 0 !important; + } + .gonavi-sortable-header-cell[data-cursor-grabbing="true"], + .gonavi-sortable-header-cell[data-cursor-grabbing="true"] *, + .gonavi-sortable-header-cell.is-dragging, + .gonavi-sortable-header-cell.is-dragging * { + cursor: grabbing !important; + } + .sortable-header-cell-drag-handle { + display: flex; + align-items: center; + width: 100%; + height: 100%; + min-height: 44px; + padding: 0 10px; + user-select: none; + cursor: inherit; + } +`; + +const SortableHeaderCell: React.FC = React.memo((props) => { + const { id, children, style: propStyle, className: propClassName, ...restProps } = props; + const [isPressed, setIsPressed] = useState(false); + const { + attributes, + listeners, + setNodeRef, + transform, + transition, + isDragging, + } = useSortable({ id: id || '' }); + + const style: React.CSSProperties = { + ...propStyle, + transform: CSS.Transform.toString(transform), + transition, + ...(isDragging ? { + position: 'relative', + zIndex: 9999, + opacity: 0.6, + backgroundColor: 'rgba(24, 144, 255, 0.15)', + boxShadow: '0 4px 12px rgba(0,0,0,0.15)' + } : {}), + touchAction: 'none', + willChange: 'transform', + // 核心修复:将指针直接绑定到 th 级别,并由 isPressed 控制 + cursor: (isDragging || isPressed) ? 'grabbing' : 'pointer', + }; + + useEffect(() => { + const handleGlobalMouseUp = () => setIsPressed(false); + window.addEventListener('mouseup', handleGlobalMouseUp); + return () => window.removeEventListener('mouseup', handleGlobalMouseUp); + }, []); + + if (!id || id === 'GONAVI_SELECTION_COLUMN') { + return {children}; + } + + return ( + { + setIsPressed(true); + if (listeners?.onPointerDown) listeners.onPointerDown(e); + }} + > + +
+
+ {children} +
+
+
+ ); +}); // --- Contexts --- const EditableContext = React.createContext(null); @@ -375,7 +485,7 @@ const DataContext = React.createContext<{ handleCopyInsert: (r: any) => void; handleCopyJson: (r: any) => void; handleCopyCsv: (r: any) => void; - handleExportSelected: (format: string, r: any) => void; + handleExportSelected: (format: string, r: any) => Promise; copyToClipboard: (t: string) => void; tableName?: string; enableRowContextMenu: boolean; @@ -562,11 +672,11 @@ const ContextMenuRow = React.memo(({ children, record, ...props }: any) => { label: '导出选中数据', icon: , children: [ - { key: 'exp-csv', label: 'CSV', onClick: () => handleExportSelected('csv', record) }, - { key: 'exp-xlsx', label: 'Excel', onClick: () => handleExportSelected('xlsx', record) }, - { key: 'exp-json', label: 'JSON', onClick: () => handleExportSelected('json', record) }, - { key: 'exp-md', label: 'Markdown', onClick: () => handleExportSelected('md', record) }, - { key: 'exp-html', label: 'HTML', onClick: () => handleExportSelected('html', record) }, + { key: 'exp-csv', label: 'CSV', onClick: () => handleExportSelected('csv', record).catch(console.error) }, + { key: 'exp-xlsx', label: 'Excel', onClick: () => handleExportSelected('xlsx', record).catch(console.error) }, + { key: 'exp-json', label: 'JSON', onClick: () => handleExportSelected('json', record).catch(console.error) }, + { key: 'exp-md', label: 'Markdown', onClick: () => handleExportSelected('md', record).catch(console.error) }, + { key: 'exp-html', label: 'HTML', onClick: () => handleExportSelected('html', record).catch(console.error) }, ] } ]; @@ -640,13 +750,132 @@ const DataGrid: React.FC = ({ const appearance = useStore(state => state.appearance); const queryOptions = useStore(state => state.queryOptions); const setQueryOptions = useStore(state => state.setQueryOptions); + const tableColumnOrders = useStore(state => state.tableColumnOrders); + const enableColumnOrderMemory = useStore(state => state.enableColumnOrderMemory); + const setTableColumnOrder = useStore(state => state.setTableColumnOrder); + const setEnableColumnOrderMemory = useStore(state => state.setEnableColumnOrderMemory); + const clearTableColumnOrder = useStore(state => state.clearTableColumnOrder); + + const tableHiddenColumns = useStore(state => state.tableHiddenColumns); + const enableHiddenColumnMemory = useStore(state => state.enableHiddenColumnMemory); + const setTableHiddenColumns = useStore(state => state.setTableHiddenColumns); + const setEnableHiddenColumnMemory = useStore(state => state.setEnableHiddenColumnMemory); + const clearTableHiddenColumns = useStore(state => state.clearTableHiddenColumns); + const isMacLike = useMemo(() => isMacLikePlatform(), []); const darkMode = theme === 'dark'; const resolvedAppearance = resolveAppearanceValues(appearance); const opacity = normalizeOpacityForPlatform(resolvedAppearance.opacity); const canModifyData = !readOnly && !!tableName; - const showColumnComment = queryOptions?.showColumnComment !== false; - const showColumnType = queryOptions?.showColumnType !== false; + const showColumnComment = queryOptions?.showColumnComment ?? true; + const showColumnType = queryOptions?.showColumnType ?? true; + + // --- Display Columns Order & Visibility Management --- + const [allOrderedColumnNames, setAllOrderedColumnNames] = useState([]); + const [displayColumnNames, setDisplayColumnNames] = useState([]); + const [localHiddenColumns, setLocalHiddenColumns] = useState([]); + const [columnSearchText, setColumnSearchText] = useState(''); + + // Sync hidden columns from store + useEffect(() => { + if (enableHiddenColumnMemory && connectionId && dbName && tableName) { + const storedHidden = tableHiddenColumns[`${connectionId}-${dbName}-${tableName}`]; + setLocalHiddenColumns(Array.isArray(storedHidden) ? storedHidden : []); + } else { + setLocalHiddenColumns([]); + } + }, [tableHiddenColumns, enableHiddenColumnMemory, connectionId, dbName, tableName]); + + const toggleColumnVisibility = useCallback((col: string, visible: boolean) => { + setLocalHiddenColumns(prev => { + const nextSet = new Set(prev); + if (visible) nextSet.delete(col); + else nextSet.add(col); + const nextArray = Array.from(nextSet); + if (enableHiddenColumnMemory && connectionId && dbName && tableName) { + setTableHiddenColumns(connectionId, dbName, tableName, nextArray); + } + return nextArray; + }); + }, [enableHiddenColumnMemory, connectionId, dbName, tableName, setTableHiddenColumns]); + + const toggleAllColumnsVisibility = useCallback((visible: boolean) => { + setLocalHiddenColumns(() => { + const nextArray = visible ? [] : [...allOrderedColumnNames]; + if (enableHiddenColumnMemory && connectionId && dbName && tableName) { + setTableHiddenColumns(connectionId, dbName, tableName, nextArray); + } + return nextArray; + }); + }, [allOrderedColumnNames, enableHiddenColumnMemory, connectionId, dbName, tableName, setTableHiddenColumns]); + + // Sync display order from incoming prop and store memory + useEffect(() => { + let nextOrder = [...columnNames]; + if (enableColumnOrderMemory && connectionId && dbName && tableName) { + const storedOrder = tableColumnOrders[`${connectionId}-${dbName}-${tableName}`]; + if (Array.isArray(storedOrder) && storedOrder.length > 0) { + // Only layout known columns. Filter out missing or new columns. + const storedSet = new Set(storedOrder); + const incomingSet = new Set(nextOrder); + const validStored = storedOrder.filter(col => incomingSet.has(col)); + const missingNew = nextOrder.filter(col => !storedSet.has(col)); + nextOrder = [...validStored, ...missingNew]; + } + } + setAllOrderedColumnNames(nextOrder); + }, [columnNames, tableColumnOrders, enableColumnOrderMemory, connectionId, dbName, tableName]); + + // Compute final display columns + useEffect(() => { + const hiddenSet = new Set(localHiddenColumns); + setDisplayColumnNames(allOrderedColumnNames.filter(col => !hiddenSet.has(col))); + }, [allOrderedColumnNames, localHiddenColumns]); + + // Handle Dragging + const sensors = useSensors( + useSensor(PointerSensor, { activationConstraint: { distance: 8 } }), + useSensor(MouseSensor, { activationConstraint: { distance: 8 } }), + useSensor(TouchSensor, { activationConstraint: { delay: 200, tolerance: 5 } }), + ); + + const handleDragEnd = (event: DragEndEvent) => { + const { active, over } = event; + if (active.id !== over?.id && over) { + setAllOrderedColumnNames((prevAllOrder) => { + // Calculate the new order of all columns by applying the movement + // We only move the visible columns relative to each other, but the easiest way + // is to map the visible column movement back to the full array. + const hiddenSet = new Set(localHiddenColumns); + const visibleOrder = prevAllOrder.filter(col => !hiddenSet.has(col)); + + const oldVisibleIndex = visibleOrder.indexOf(active.id as string); + const newVisibleIndex = visibleOrder.indexOf(over.id as string); + + if (oldVisibleIndex === -1 || newVisibleIndex === -1) return prevAllOrder; + + const nextVisibleOrder = arrayMove(visibleOrder, oldVisibleIndex, newVisibleIndex); + + // Reconstruct allOrderedColumnNames by inserting hidden columns back to their original relative positions + // Or simpler: just keep hidden columns at the end, but that ruins user's layout. + // Better approach: build a new array + let vIndex = 0; + const nextOrder = prevAllOrder.map(col => { + if (hiddenSet.has(col)) { + return col; // Hidden columns stay at their absolute index in the master list + } else { + return nextVisibleOrder[vIndex++]; + } + }); + + if (enableColumnOrderMemory && connectionId && dbName && tableName) { + setTableColumnOrder(connectionId, dbName, tableName, nextOrder); + } + return nextOrder; + }); + } + }; + const selectionColumnWidth = 46; const currentConnConfig = connections.find(c => c.id === connectionId)?.config; const dataSourceCaps = getDataSourceCapabilities(currentConnConfig); @@ -689,14 +918,10 @@ const DataGrid: React.FC = ({ const panelPaddingX = 12; const toolbarBottomPadding = 6; const filterTopPadding = 2; - const panelBorderColor = darkMode ? 'rgba(255, 255, 255, 0.08)' : 'rgba(0, 0, 0, 0.08)'; const panelFrameColor = darkMode ? 'rgba(0, 0, 0, 0.42)' : 'rgba(0, 0, 0, 0.18)'; const floatingScrollbarGap = 6; const floatingScrollbarInset = 10; const floatingScrollbarHeight = 10; - const floatingScrollbarTrackBg = 'transparent'; - const floatingScrollbarBorderColor = 'transparent'; - const floatingScrollbarShadow = 'none'; const floatingScrollbarThumbBg = darkMode ? 'rgba(255,255,255,0.34)' : 'rgba(0,0,0,0.22)'; const floatingScrollbarThumbBorderColor = darkMode ? 'rgba(255,255,255,0.10)' : 'rgba(255,255,255,0.32)'; const floatingScrollbarThumbShadow = darkMode ? '0 4px 12px rgba(0,0,0,0.28)' : '0 4px 10px rgba(0,0,0,0.12)'; @@ -740,7 +965,7 @@ const DataGrid: React.FC = ({ const [form] = Form.useForm(); const [modal, contextHolder] = Modal.useModal(); - const gridId = useMemo(() => `grid-${uuidv4()}`, []); + const gridId = useMemo(() => `grid-${generateUuid()}`, []); const [viewMode, setViewMode] = useState('table'); const [textRecordIndex, setTextRecordIndex] = useState(0); const [cellEditorOpen, setCellEditorOpen] = useState(false); @@ -776,7 +1001,7 @@ const DataGrid: React.FC = ({ const containerRef = useRef(null); const tableContainerRef = useRef(null); const tableScrollTargetsRef = useRef([]); - const externalHScrollRef = useRef(null); + const externalHorizontalScrollRef = useRef(null); const horizontalSyncSourceRef = useRef<'table' | 'external' | ''>(''); const lastTableScrollLeftRef = useRef(0); const lastExternalScrollLeftRef = useRef(0); @@ -837,7 +1062,7 @@ const DataGrid: React.FC = ({ const showCellContextMenu = useCallback((e: React.MouseEvent, record: Item, dataIndex: string, title: React.ReactNode) => { e.preventDefault(); e.stopPropagation(); - const titleText = typeof title === 'string' ? title : (typeof title === 'number' ? String(title) : String(dataIndex)); + const titleText = typeof (title as any) === 'string' ? (title as string) : (typeof (title as any) === 'number' ? String(title) : String(dataIndex)); setCellContextMenu({ visible: true, x: e.clientX, @@ -854,14 +1079,14 @@ const DataGrid: React.FC = ({ try { const cleanRows = rows.map(({ [GONAVI_ROW_KEY]: _rowKey, ...rest }) => rest); // Pass tableName (or 'export') as default filename - const res = await ExportData(cleanRows, columnNames, tableName || 'export', format); + const res = await ExportData(cleanRows, displayColumnNames, tableName || 'export', format); if (res.success) { - message.success("导出成功"); + void message.success("导出成功"); } else if (res.message !== "Cancelled") { - message.error("导出失败: " + res.message); + void message.error("导出失败: " + res.message); } } catch (e: any) { - message.error("导出失败: " + (e?.message || String(e))); + void message.error("导出失败: " + (e?.message || String(e))); } finally { hide(); } @@ -1054,7 +1279,7 @@ const DataGrid: React.FC = ({ const raw = record?.[dataIndex]; const text = toEditableText(raw); const isJson = looksLikeJsonText(text); - const titleText = typeof title === 'string' ? title : (typeof title === 'number' ? String(title) : String(dataIndex)); + const titleText = typeof (title as any) === 'string' ? (title as string) : (typeof (title as any) === 'number' ? String(title) : String(dataIndex)); setCellEditorMeta({ record, dataIndex, title: titleText }); setCellEditorValue(text); @@ -1142,13 +1367,13 @@ const DataGrid: React.FC = ({ id: nextId, enabled: cond?.enabled !== false, logic: normalizeFilterLogic(cond?.logic), - column: rawColumn || (op === 'CUSTOM' ? '' : String(columnNames[0] || '')), + column: rawColumn || (op === 'CUSTOM' ? '' : String(displayColumnNames[0] || '')), op, value: String(cond?.value ?? ''), value2: String(cond?.value2 ?? ''), }; }); - }, [columnNames, normalizeFilterLogic]); + }, [displayColumnNames, normalizeFilterLogic]); // Filter State const [filterConditions, setFilterConditions] = useState([]); @@ -1196,9 +1421,9 @@ const DataGrid: React.FC = ({ const columnIndexMap = useMemo(() => { const map = new Map(); - columnNames.forEach((name, idx) => map.set(name, idx)); + displayColumnNames.forEach((name: string, idx: number) => map.set(name, idx)); return map; - }, [columnNames]); + }, [displayColumnNames]); // 直接操作 DOM 更新选中效果,避免 React 重渲染 const updateCellSelection = useCallback((newSelection: Set) => { @@ -1225,7 +1450,7 @@ const DataGrid: React.FC = ({ const handleBatchFillCells = useCallback(() => { const cellsToFill = currentSelectionRef.current; if (cellsToFill.size === 0) { - message.info('请先选择要填充的单元格'); + void message.info('请先选择要填充的单元格'); return; } @@ -1255,7 +1480,7 @@ const DataGrid: React.FC = ({ const existing = modifiedRows[rowKey]; const baseRow = baseRowMap.get(rowKey); - let currentVal: any = undefined; + let currentVal: any; const addedRow = addedRowMap.get(rowKey); if (addedRow) { @@ -1278,7 +1503,7 @@ const DataGrid: React.FC = ({ }); if (updatedCount === 0) { - message.info('选中的单元格无需更新'); + void message.info('选中的单元格无需更新'); return; } @@ -1306,7 +1531,7 @@ const DataGrid: React.FC = ({ return next || prev; }); - message.success(`已填充 ${updatedCount} 个单元格`); + void message.success(`已填充 ${updatedCount} 个单元格`); setBatchEditModalOpen(false); // 清除选中状态 @@ -1377,7 +1602,7 @@ const DataGrid: React.FC = ({ const row = currentData[i]; const rKey = String(row?.[GONAVI_ROW_KEY]); for (let j = minColIndex; j <= maxColIndex; j++) { - newSelectedCells.add(makeCellKey(rKey, columnNames[j])); + newSelectedCells.add(makeCellKey(rKey, displayColumnNames[j])); } } @@ -1548,7 +1773,7 @@ const DataGrid: React.FC = ({ cellSelectionPointerRef.current = null; isDraggingRef.current = false; }; - }, [cellEditMode, columnNames, columnIndexMap, updateCellSelection]); + }, [cellEditMode, displayColumnNames, columnIndexMap, updateCellSelection]); // 批量填充到选中行 const handleBatchFillToSelected = useCallback((sourceRecord: Item, dataIndex: string) => { @@ -1556,7 +1781,7 @@ const DataGrid: React.FC = ({ const selKeys = selectedRowKeysRef.current; if (selKeys.length === 0) { - message.info('请先选择要填充的行'); + void message.info('请先选择要填充的行'); return; } @@ -1565,7 +1790,7 @@ const DataGrid: React.FC = ({ const targetKeys = selKeys.filter(k => k !== sourceKey); if (targetKeys.length === 0) { - message.info('没有其他选中的行可以填充'); + void message.info('没有其他选中的行可以填充'); return; } @@ -1604,7 +1829,7 @@ const DataGrid: React.FC = ({ return next || prev; }); - message.success(`已填充 ${updatedCount} 行`); + void message.success(`已填充 ${updatedCount} 行`); setCellContextMenu(prev => ({ ...prev, visible: false })); }, [addedRows, rowKeyStr]); @@ -1639,7 +1864,7 @@ const DataGrid: React.FC = ({ return ''; }, [addedRowKeySet, modifiedRowKeySet, deletedRowKeys, rowKeyStr]); - const handleTableChange = useCallback((pag: any, filtersArg: any, sorter: any) => { + const handleTableChange = useCallback((_pag: any, _filtersArg: any, sorter: any) => { if (isResizingRef.current) return; // Block sort if resizing if (sorter.field) { const field = String(sorter.field); @@ -1809,7 +2034,7 @@ const DataGrid: React.FC = ({ const obj = JSON.parse(cellEditorValue); setCellEditorValue(JSON.stringify(obj, null, 2)); } catch (e: any) { - message.error("JSON 格式无效:" + (e?.message || String(e))); + void message.error("JSON 格式无效:" + (e?.message || String(e))); } }, [cellEditorIsJson, cellEditorValue]); @@ -1887,12 +2112,12 @@ const DataGrid: React.FC = ({ const openRowEditorByKey = useCallback((keyStr?: string) => { if (!canModifyData) return; if (!keyStr) { - message.info('请先定位到要编辑的记录'); + void message.info('请先定位到要编辑的记录'); return; } const displayRow = mergedDisplayData.find(r => rowKeyStr(r?.[GONAVI_ROW_KEY]) === keyStr); if (!displayRow) { - message.error('未找到目标行,请刷新后重试'); + void message.error('未找到目标行,请刷新后重试'); return; } @@ -1922,17 +2147,17 @@ const DataGrid: React.FC = ({ rowEditorForm.setFieldsValue(formMap); setRowEditorRowKey(keyStr); setRowEditorOpen(true); - }, [canModifyData, mergedDisplayData, data, addedRows, columnNames, rowEditorForm, rowKeyStr]); + }, [canModifyData, mergedDisplayData, data, addedRows, displayColumnNames, rowEditorForm, rowKeyStr]); const openRowEditor = useCallback(() => { if (!canModifyData) return; if (selectedRowKeys.length > 1) { - message.info('一次只能编辑一行,请仅选择一行'); + void message.info('一次只能编辑一行,请仅选择一行'); return; } const keyStr = selectedRowKeys.length === 1 ? rowKeyStr(selectedRowKeys[0]) : undefined; if (!keyStr) { - message.info('请先选择一行(勾选复选框)'); + void message.info('请先选择一行(勾选复选框)'); return; } openRowEditorByKey(keyStr); @@ -1943,7 +2168,7 @@ const DataGrid: React.FC = ({ const currentRow = mergedDisplayData[textRecordIndex]; const rowKey = currentRow?.[GONAVI_ROW_KEY]; if (rowKey === undefined || rowKey === null) { - message.info('当前记录不可编辑'); + void message.info('当前记录不可编辑'); return; } openRowEditorByKey(rowKeyStr(rowKey)); @@ -1960,7 +2185,7 @@ const DataGrid: React.FC = ({ const parsed = JSON.parse(jsonEditorValue); setJsonEditorValue(JSON.stringify(parsed, null, 2)); } catch (e: any) { - message.error("JSON 格式无效:" + (e?.message || String(e))); + void message.error("JSON 格式无效:" + (e?.message || String(e))); } }, [jsonEditorValue]); @@ -1970,16 +2195,16 @@ const DataGrid: React.FC = ({ try { parsed = JSON.parse(jsonEditorValue); } catch (e: any) { - message.error("JSON 解析失败:" + (e?.message || String(e))); + void message.error("JSON 解析失败:" + (e?.message || String(e))); return; } if (!Array.isArray(parsed)) { - message.error("JSON 视图必须是数组格式(每项对应一条记录)"); + void message.error("JSON 视图必须是数组格式(每项对应一条记录)"); return; } if (parsed.length !== mergedDisplayData.length) { - message.error(`记录条数不一致:当前 ${mergedDisplayData.length} 条,JSON 中 ${parsed.length} 条。请勿在此模式增删记录。`); + void message.error(`记录条数不一致:当前 ${mergedDisplayData.length} 条,JSON 中 ${parsed.length} 条。请勿在此模式增删记录。`); return; } @@ -2003,14 +2228,14 @@ const DataGrid: React.FC = ({ for (let idx = 0; idx < parsed.length; idx += 1) { const nextItem = parsed[idx]; if (!isPlainObject(nextItem)) { - message.error(`第 ${idx + 1} 条记录不是对象,无法应用`); + void message.error(`第 ${idx + 1} 条记录不是对象,无法应用`); return; } const currentRow = mergedDisplayData[idx]; const rowKey = currentRow?.[GONAVI_ROW_KEY]; if (rowKey === undefined || rowKey === null) { - message.error(`第 ${idx + 1} 条记录缺少行标识,无法应用`); + void message.error(`第 ${idx + 1} 条记录缺少行标识,无法应用`); return; } const keyStr = rowKeyStr(rowKey); @@ -2061,8 +2286,8 @@ const DataGrid: React.FC = ({ }); setJsonEditorOpen(false); - message.success("JSON 修改已应用到当前结果集,可继续“提交事务”"); - }, [canModifyData, jsonEditorValue, mergedDisplayData, addedRows, rowKeyStr, data, columnNames]); + void message.success("JSON 修改已应用到当前结果集,可继续“提交事务”"); + }, [canModifyData, jsonEditorValue, mergedDisplayData, addedRows, rowKeyStr, data, displayColumnNames]); const openRowEditorFieldEditor = useCallback((dataIndex: string) => { if (!dataIndex) return; @@ -2089,7 +2314,7 @@ const DataGrid: React.FC = ({ const baseRawMap = rowEditorBaseRawRef.current || {}; const patch: Record = {}; - columnNames.forEach((col) => { + displayColumnNames.forEach((col) => { const nextVal = values[col]; const baseVal = baseRawMap[col]; if (!isCellValueEqualForDiff(baseVal, nextVal)) patch[col] = nextVal; @@ -2103,14 +2328,14 @@ const DataGrid: React.FC = ({ }); closeRowEditor(); - }, [rowEditorRowKey, rowEditorForm, addedRows, columnNames, rowKeyStr, closeRowEditor]); + }, [rowEditorRowKey, rowEditorForm, addedRows, displayColumnNames, rowKeyStr, closeRowEditor]); const enableVirtual = viewMode === 'table'; const enableInlineEditableCell = canModifyData; - const columns = useMemo(() => { - return columnNames.map(key => ({ + const columns: (ColumnType & { editable?: boolean })[] = useMemo(() => { + return displayColumnNames.map(key => ({ title: renderColumnTitle(key), dataIndex: key, key: key, @@ -2130,7 +2355,9 @@ const DataGrid: React.FC = ({ return !isCellValueEqualForRender(record?.[key], prevRecord?.[key]); }, onHeaderCell: (column: any) => ({ + id: key, width: column.width, + className: 'gonavi-sortable-header-cell', onResizeStart: handleResizeStart(key), // Only need start onClickCapture: (event: React.MouseEvent) => { if (!onSort) return; @@ -2154,10 +2381,10 @@ const DataGrid: React.FC = ({ }, }), })); - }, [columnNames, columnWidths, sortInfo, handleResizeStart, canModifyData, onSort, renderColumnTitle]); + }, [displayColumnNames, columnWidths, sortInfo, handleResizeStart, canModifyData, onSort, renderColumnTitle]); - const mergedColumns = useMemo(() => columns.map(col => { - if (!col.editable) return col; + const mergedColumns = useMemo(() => columns.map((col): ColumnType => { + if (!col.editable) return col as ColumnType; const dataIndex = String(col.dataIndex); return { ...col, @@ -2191,7 +2418,7 @@ const DataGrid: React.FC = ({ return ( = ({ pendingScrollToBottomRef.current = true; setAddedRows(prev => [...prev, newRow]); }; - const handleDeleteSelected = () => { setDeletedRowKeys(prev => { const newDeleted = new Set(prev); @@ -2284,7 +2510,7 @@ const DataGrid: React.FC = ({ if (!hasRowKey) { values = { ...(newRow as any) }; } else { - columnNames.forEach((col) => { + displayColumnNames.forEach((col) => { const nextVal = (newRow as any)?.[col]; const prevVal = (originalRow as any)?.[col]; if (!isCellValueEqualForDiff(prevVal, nextVal)) values[col] = nextVal; @@ -2304,7 +2530,7 @@ const DataGrid: React.FC = ({ }); if (inserts.length === 0 && updates.length === 0 && deletes.length === 0) { - message.info("No changes to commit"); + void message.info("No changes to commit"); return; } @@ -2337,7 +2563,7 @@ const DataGrid: React.FC = ({ message: res.message, dbName }); - message.success("事务提交成功"); + void message.success("事务提交成功"); setAddedRows([]); setModifiedRows({}); setDeletedRowKeys(new Set()); @@ -2352,13 +2578,13 @@ const DataGrid: React.FC = ({ message: res.message, dbName }); - message.error("提交失败: " + res.message); + void message.error("提交失败: " + res.message); } }; const copyToClipboard = useCallback((text: string) => { - navigator.clipboard.writeText(text); - message.success("Copied to clipboard"); + navigator.clipboard.writeText(text).catch(console.error); + void message.success("Copied to clipboard"); }, []); const getTargets = useCallback((clickedRecord: any) => { @@ -2373,19 +2599,18 @@ const DataGrid: React.FC = ({ const handleCopyInsert = useCallback((record: any) => { if (!supportsCopyInsert) { - message.warning("当前数据源不支持复制为 INSERT,请使用 JSON/CSV/Markdown 复制。"); + void message.warning("当前数据源不支持复制为 INSERT,请使用 JSON/CSV/Markdown 复制。"); return; } const records = getTargets(record); - const sqls = records.map((r: any) => { + const sqlList = records.map((r: any) => { const { [GONAVI_ROW_KEY]: _rowKey, ...vals } = r; const cols = Object.keys(vals); - const values = Object.values(vals).map(v => v === null ? 'NULL' : `'${v}'`); + const values = Object.values(vals).map(v => v === null ? 'NULL' : `'${v}'`); const targetTable = tableName || 'table'; return `INSERT INTO \`${targetTable}\` (${cols.map(c => `\`${c}\``).join(', ')}) VALUES (${values.join(', ')});`; }); - copyToClipboard(sqls.join('\n')); - }, [supportsCopyInsert, tableName, getTargets, copyToClipboard]); + copyToClipboard(sqlList.join('\n')); }, [supportsCopyInsert, tableName, getTargets, copyToClipboard]); const handleCopyJson = useCallback((record: any) => { const records = getTargets(record); @@ -2427,12 +2652,12 @@ const DataGrid: React.FC = ({ try { const res = await ExportQuery(config as any, dbName || '', sql, defaultName || 'export', format); if (res.success) { - message.success("导出成功"); + void message.success("导出成功"); } else if (res.message !== "Cancelled") { - message.error("导出失败: " + res.message); + void message.error("导出失败: " + res.message); } } catch (e: any) { - message.error("导出失败: " + (e?.message || String(e))); + void message.error("导出失败: " + (e?.message || String(e))); } finally { hide(); } @@ -2489,7 +2714,7 @@ const DataGrid: React.FC = ({ // 有未提交修改时,优先按界面数据导出,避免与数据库不一致。 if (hasChanges) { - message.warning("当前存在未提交修改,导出将按界面数据生成;如需完整长字段建议先提交后再导出。"); + void message.warning("当前存在未提交修改,导出将按界面数据生成;如需完整长字段建议先提交后再导出。"); await exportData(records, format); return; } @@ -2545,12 +2770,12 @@ const DataGrid: React.FC = ({ try { const res = await ExportTable(config as any, dbName || '', tableName, format); if (res.success) { - message.success("导出成功"); + void message.success("导出成功"); } else if (res.message !== "Cancelled") { - message.error("导出失败: " + res.message); + void message.error("导出失败: " + res.message); } } catch (e: any) { - message.error("导出失败: " + (e?.message || String(e))); + void message.error("导出失败: " + (e?.message || String(e))); } finally { hide(); } @@ -2558,7 +2783,7 @@ const DataGrid: React.FC = ({ const handlePage = async () => { instance.destroy(); if (hasChanges) { - message.warning("当前存在未提交修改,导出将按界面数据生成;如需完整长字段建议先提交后再导出。"); + void message.warning("当前存在未提交修改,导出将按界面数据生成;如需完整长字段建议先提交后再导出。"); await exportData(displayData, format); return; } @@ -2599,15 +2824,15 @@ const DataGrid: React.FC = ({ const handleExportFilteredAll = async (format: string) => { if (!connectionId || !tableName) return; if (!filteredExportSql) { - message.warning('当前未应用筛选条件'); + void message.warning('当前未应用筛选条件'); return; } if (!supportsSqlQueryExport) { - message.error('当前数据源不支持按筛选结果导出'); + void message.error('当前数据源不支持按筛选结果导出'); return; } if (hasChanges) { - message.warning("当前存在未提交修改,筛选结果导出基于数据库已提交数据。"); + void message.warning("当前存在未提交修改,筛选结果导出基于数据库已提交数据。"); } await exportByQuery(filteredExportSql, format, `${tableName || 'export'}_filtered`); @@ -2623,14 +2848,14 @@ const DataGrid: React.FC = ({ setImportFilePath(res.data.filePath); setImportPreviewVisible(true); } else if (res.message !== "Cancelled") { - message.error("选择文件失败: " + res.message); + void message.error("选择文件失败: " + res.message); } }; const handleImportSuccess = () => { setImportPreviewVisible(false); setImportFilePath(''); - message.success('导入完成'); + void message.success('导入完成'); if (onReload) onReload(); }; @@ -2676,7 +2901,7 @@ const DataGrid: React.FC = ({ id: nextFilterId, enabled: true, logic: 'AND', - column: columnNames[0] || '', + column: displayColumnNames[0] || '', op: '=', value: '', value2: '', @@ -2734,19 +2959,93 @@ const DataGrid: React.FC = ({ ]; const columnInfoSettingContent = ( -
+
+
显示设置
setQueryOptions({ showColumnComment: e.target.checked })} > - 下方显示备注 + 表头显示备注 setQueryOptions({ showColumnType: e.target.checked })} > - 下方显示类型 + 表头显示类型 +
+ + + setColumnSearchText(e.target.value)} + allowClear + /> +
+ {allOrderedColumnNames.filter(col => !columnSearchText || col.toLowerCase().includes(columnSearchText.toLowerCase())).map(col => ( + toggleColumnVisibility(col, e.target.checked)} + style={{ marginLeft: 0 }} + > + {col} + + ))} +
+ +
+ setEnableColumnOrderMemory(e.target.checked)} + > + 记忆自定义列序 + + setEnableHiddenColumnMemory(e.target.checked)} + > + 记忆隐藏列配置 + +
+ + +
); @@ -2776,7 +3075,7 @@ const DataGrid: React.FC = ({ const rowPropsFactory = useCallback((record: any) => ({ record } as any), []); - const totalWidth = columns.reduce((sum, col) => sum + (Number(col.width) || 200), 0) + selectionColumnWidth; + const totalWidth = columns.reduce((sum: number, col: any) => sum + (Number(col.width) || 200), 0) + selectionColumnWidth; const useContextMenuRow = false; const tableScrollX = useMemo(() => { const baseWidth = Math.max(totalWidth, 1000); @@ -2796,8 +3095,8 @@ const DataGrid: React.FC = ({ body.row = ContextMenuRow; } return Object.keys(body).length > 0 - ? { body, header: { cell: ResizableTitle } } - : { header: { cell: ResizableTitle } }; + ? { body, header: { cell: SortableHeaderCell } } + : { header: { cell: SortableHeaderCell } }; }, [enableInlineEditableCell, useContextMenuRow]); const tableOnRow = useMemo(() => (useContextMenuRow ? rowPropsFactory : undefined), [useContextMenuRow, rowPropsFactory]); @@ -2821,7 +3120,7 @@ const DataGrid: React.FC = ({ }, []); const syncExternalScrollFromTargets = useCallback((targets?: HTMLElement[], source?: HTMLElement | null) => { - const externalScroll = externalHScrollRef.current; + const externalScroll = externalHorizontalScrollRef.current; if (!(externalScroll instanceof HTMLDivElement) || horizontalSyncSourceRef.current === 'external') { return; } @@ -2845,7 +3144,7 @@ const DataGrid: React.FC = ({ }, []); const applyExternalScrollToTableTargets = useCallback(() => { - const externalScroll = externalHScrollRef.current; + const externalScroll = externalHorizontalScrollRef.current; if (!(externalScroll instanceof HTMLDivElement)) { return; } @@ -2878,7 +3177,7 @@ const DataGrid: React.FC = ({ // 非虚拟模式:外部水平滚动条的 wheel 处理(通过原生事件绑定,确保 preventDefault 生效) useEffect(() => { - const externalScroll = externalHScrollRef.current; + const externalScroll = externalHorizontalScrollRef.current; if (!externalScroll || !horizontalScrollVisible) return; const handleExternalWheel = (e: WheelEvent) => { @@ -2892,8 +3191,7 @@ const DataGrid: React.FC = ({ const maxScrollLeft = Math.max(0, externalScroll.scrollWidth - externalScroll.clientWidth); if (maxScrollLeft <= 0) return; - const nextScrollLeft = Math.max(0, Math.min(maxScrollLeft, externalScroll.scrollLeft + dominantDelta)); - externalScroll.scrollLeft = nextScrollLeft; + externalScroll.scrollLeft = Math.max(0, Math.min(maxScrollLeft, externalScroll.scrollLeft + dominantDelta)); }; externalScroll.addEventListener('wheel', handleExternalWheel, { passive: false, capture: true }); @@ -2922,7 +3220,7 @@ const DataGrid: React.FC = ({ const isTableDataAreaTarget = (target: EventTarget | null) => { const element = target instanceof HTMLElement ? target : null; if (!element) return false; - if (element.closest('.data-grid-external-hscroll')) return false; + if (element.closest('.data-grid-external-horizontal-scroll')) return false; return !!element.closest('.ant-table-body, .ant-table-content, .ant-table-cell, .ant-table-row, .ant-table-tbody'); }; @@ -2948,7 +3246,7 @@ const DataGrid: React.FC = ({ activeTarget.scrollLeft = nextScrollLeft; lastTableScrollLeftRef.current = nextScrollLeft; - const externalScroll = externalHScrollRef.current; + const externalScroll = externalHorizontalScrollRef.current; if (externalScroll && Math.abs(externalScroll.scrollLeft - nextScrollLeft) > 1) { externalScroll.scrollLeft = nextScrollLeft; lastExternalScrollLeftRef.current = nextScrollLeft; @@ -2976,7 +3274,7 @@ const DataGrid: React.FC = ({ let rafId: number | null = null; let boundVerticalTarget: HTMLElement | null = null; let boundHorizontalTargets: HTMLElement[] = []; - const externalScroll = externalHScrollRef.current; + const externalScroll = externalHorizontalScrollRef.current; const hasStoredScroll = !!scrollSnapshot && (Math.abs(scrollSnapshot.top) > 0.5 || Math.abs(scrollSnapshot.left) > 0.5); const emitSnapshot = () => { @@ -3043,7 +3341,7 @@ const DataGrid: React.FC = ({ target.scrollLeft = nextLeft; } }); - const externalScroll = externalHScrollRef.current; + const externalScroll = externalHorizontalScrollRef.current; if (externalScroll && Math.abs(externalScroll.scrollLeft - nextLeft) > 1) { externalScroll.scrollLeft = nextLeft; } @@ -3126,7 +3424,7 @@ const DataGrid: React.FC = ({ useEffect(() => { if (viewMode !== 'table') return; const tableContainer = tableContainerRef.current; - const externalScroll = externalHScrollRef.current; + const externalScroll = externalHorizontalScrollRef.current; if (!(tableContainer instanceof HTMLElement) || !(externalScroll instanceof HTMLDivElement)) return; let rafId: number | null = null; @@ -3270,7 +3568,7 @@ const DataGrid: React.FC = ({ } updateCellSelection(new Set()); if (!next) setBatchEditModalOpen(false); - message.info(next ? '已进入单元格编辑模式,可拖拽选择多个单元格' : '已退出单元格编辑模式'); + void message.info(next ? '已进入单元格编辑模式,可拖拽选择多个单元格' : '已退出单元格编辑模式').then(); }} > 单元格编辑器 @@ -3412,7 +3710,7 @@ const DataGrid: React.FC = ({ style={{ width: 180 }} value={cond.column} onChange={v => updateFilter(cond.id, 'column', v)} - options={columnNames.map(c => ({ value: c, label: c }))} + options={displayColumnNames.map(c => ({ value: c, label: c }))} showSearch optionFilterProp="label" filterOption={(input, option) => @@ -3508,7 +3806,7 @@ const DataGrid: React.FC = ({
- {columnNames.map((col) => { + {displayColumnNames.map((col: string) => { const sample = rowEditorDisplayRef.current?.[col] ?? ''; const placeholder = rowEditorNullColsRef.current?.has(col) ? '(NULL)' : undefined; const isJson = looksLikeJsonText(sample); @@ -3645,32 +3943,36 @@ const DataGrid: React.FC = ({ - + + +
+ +
= ({ }} >
@@ -3734,7 +4036,7 @@ const DataGrid: React.FC = ({ )}
- {currentTextRow ? columnNames.map((col) => ( + {currentTextRow ? displayColumnNames.map((col) => (
{col} : @@ -3885,7 +4187,7 @@ const DataGrid: React.FC = ({ onMouseEnter={(e) => e.currentTarget.style.background = darkMode ? '#303030' : '#f5f5f5'} onMouseLeave={(e) => e.currentTarget.style.background = 'transparent'} onClick={() => { - if (cellContextMenu.record) handleExportSelected('csv', cellContextMenu.record); + if (cellContextMenu.record) handleExportSelected('csv', cellContextMenu.record).catch(console.error); setCellContextMenu(prev => ({ ...prev, visible: false })); }} > @@ -3900,7 +4202,7 @@ const DataGrid: React.FC = ({ onMouseEnter={(e) => e.currentTarget.style.background = darkMode ? '#303030' : '#f5f5f5'} onMouseLeave={(e) => e.currentTarget.style.background = 'transparent'} onClick={() => { - if (cellContextMenu.record) handleExportSelected('xlsx', cellContextMenu.record); + if (cellContextMenu.record) handleExportSelected('xlsx', cellContextMenu.record).catch(console.error); setCellContextMenu(prev => ({ ...prev, visible: false })); }} > @@ -3915,7 +4217,7 @@ const DataGrid: React.FC = ({ onMouseEnter={(e) => e.currentTarget.style.background = darkMode ? '#303030' : '#f5f5f5'} onMouseLeave={(e) => e.currentTarget.style.background = 'transparent'} onClick={() => { - if (cellContextMenu.record) handleExportSelected('json', cellContextMenu.record); + if (cellContextMenu.record) handleExportSelected('json', cellContextMenu.record).catch(console.error); setCellContextMenu(prev => ({ ...prev, visible: false })); }} > @@ -3930,7 +4232,7 @@ const DataGrid: React.FC = ({ onMouseEnter={(e) => e.currentTarget.style.background = darkMode ? '#303030' : '#f5f5f5'} onMouseLeave={(e) => e.currentTarget.style.background = 'transparent'} onClick={() => { - if (cellContextMenu.record) handleExportSelected('html', cellContextMenu.record); + if (cellContextMenu.record) handleExportSelected('html', cellContextMenu.record).catch(console.error); setCellContextMenu(prev => ({ ...prev, visible: false })); }} > @@ -4124,7 +4426,7 @@ const DataGrid: React.FC = ({ border-radius: 999px; box-shadow: ${floatingScrollbarThumbShadow}; } - .${gridId} .data-grid-external-hscroll { + .${gridId} .data-grid-external-horizontal-scroll { position: absolute; left: ${floatingScrollbarInset}px; right: ${floatingScrollbarInset}px; @@ -4135,22 +4437,22 @@ const DataGrid: React.FC = ({ background: transparent; z-index: 24; } - .${gridId} .data-grid-external-hscroll::-webkit-scrollbar { + .${gridId} .data-grid-external-horizontal-scroll::-webkit-scrollbar { height: ${floatingScrollbarHeight}px; } - .${gridId} .data-grid-external-hscroll::-webkit-scrollbar-track { + .${gridId} .data-grid-external-horizontal-scroll::-webkit-scrollbar-track { background: ${horizontalScrollbarTrackBg}; border: 1px solid ${horizontalScrollbarTrackBorderColor}; border-radius: 999px; box-shadow: ${horizontalScrollbarTrackShadow}; } - .${gridId} .data-grid-external-hscroll::-webkit-scrollbar-thumb { + .${gridId} .data-grid-external-horizontal-scroll::-webkit-scrollbar-thumb { background: ${horizontalScrollbarThumbBg}; border: 1px solid ${horizontalScrollbarThumbBorderColor}; border-radius: 999px; box-shadow: ${horizontalScrollbarThumbShadow}; } - .${gridId} .data-grid-external-hscroll-inner { + .${gridId} .data-grid-external-horizontal-scroll-inner { height: 1px; } .${gridId} .data-grid-pagination-shell { diff --git a/frontend/src/components/Sidebar.tsx b/frontend/src/components/Sidebar.tsx index 3a31be4..9fc732b 100644 --- a/frontend/src/components/Sidebar.tsx +++ b/frontend/src/components/Sidebar.tsx @@ -792,7 +792,20 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> case 'kingbase': case 'highgo': case 'vastbase': - return [{ sql: `SELECT n.nspname AS schema_name, p.proname AS routine_name, CASE WHEN p.prokind = 'p' THEN 'PROCEDURE' ELSE 'FUNCTION' END AS routine_type FROM pg_proc p JOIN pg_namespace n ON p.pronamespace = n.oid WHERE n.nspname NOT IN ('pg_catalog', 'information_schema') AND n.nspname NOT LIKE 'pg_%' ORDER BY n.nspname, routine_type, p.proname` }]; + return normalizeMetadataQuerySpecs([ + { + // PostgreSQL 11+ / 部分 PG-like:通过 prokind 区分 FUNCTION/PROCEDURE + sql: `SELECT n.nspname AS schema_name, p.proname AS routine_name, CASE WHEN p.prokind = 'p' THEN 'PROCEDURE' ELSE 'FUNCTION' END AS routine_type FROM pg_proc p JOIN pg_namespace n ON p.pronamespace = n.oid WHERE n.nspname NOT IN ('pg_catalog', 'information_schema') AND n.nspname NOT LIKE 'pg_%' ORDER BY n.nspname, routine_type, p.proname`, + }, + { + // PostgreSQL 10 / 不支持 prokind 的兼容路径 + sql: `SELECT r.routine_schema AS schema_name, r.routine_name AS routine_name, COALESCE(NULLIF(UPPER(r.routine_type), ''), 'FUNCTION') AS routine_type FROM information_schema.routines r WHERE r.routine_schema NOT IN ('pg_catalog', 'information_schema') AND r.routine_schema NOT LIKE 'pg_%' ORDER BY r.routine_schema, routine_type, r.routine_name`, + }, + { + // 最后兜底:仅函数列表,确保 prokind/routines 视图异常时仍可展示 + sql: `SELECT n.nspname AS schema_name, p.proname AS routine_name, 'FUNCTION' AS routine_type FROM pg_proc p JOIN pg_namespace n ON p.pronamespace = n.oid WHERE n.nspname NOT IN ('pg_catalog', 'information_schema') AND n.nspname NOT LIKE 'pg_%' ORDER BY n.nspname, p.proname`, + }, + ]); case 'sqlserver': { const safeDb = quoteSqlServerIdentifier(dbName || 'master'); return [{ sql: `SELECT s.name AS schema_name, o.name AS routine_name, CASE o.type WHEN 'P' THEN 'PROCEDURE' WHEN 'FN' THEN 'FUNCTION' WHEN 'IF' THEN 'FUNCTION' WHEN 'TF' THEN 'FUNCTION' END AS routine_type FROM ${safeDb}.sys.objects o JOIN ${safeDb}.sys.schemas s ON o.schema_id = s.schema_id WHERE o.type IN ('P','FN','IF','TF') ORDER BY o.type, s.name, o.name` }]; diff --git a/frontend/src/main.tsx b/frontend/src/main.tsx index 9457771..7ab4fee 100644 --- a/frontend/src/main.tsx +++ b/frontend/src/main.tsx @@ -9,6 +9,36 @@ import { loader } from '@monaco-editor/react' import * as monaco from 'monaco-editor' loader.config({ monaco }) +if (typeof window !== 'undefined' && !(window as any).go) { + (window as any).go = { + app: { + App: { + CheckUpdate: async () => ({ success: false }), + DownloadUpdate: async () => ({ success: false }), + GetSavedConnections: async () => [], + SaveConnection: async () => null, + DeleteConnection: async () => null, + OpenConnection: async () => null, + CloseConnection: async () => null, + GetDatabases: async () => [], + GetTables: async () => [], + GetTableData: async () => ({ columns: [], rows: [], total: 0 }), + GetTableColumns: async () => [], + ExecuteQuery: async () => ({ columns: [], rows: [], time: 0 }), + GetSavedQueries: async () => [], + SaveQuery: async () => null, + DeleteQuery: async () => null, + GetAppInfo: async () => ({}), + CheckForUpdates: async () => ({ success: false }), + OpenDownloadedUpdateDirectory: async () => ({ success: false }), + InstallUpdateAndRestart: async () => ({ success: false }), + ImportConfigFile: async () => ({ success: false }), + ExportData: async () => ({ success: false }), + } + } + }; +} + // 全局注册透明主题,避免每个 Editor 组件 beforeMount 中重复定义 monaco.editor.defineTheme('transparent-dark', { base: 'vs-dark', inherit: true, rules: [], diff --git a/frontend/src/store.ts b/frontend/src/store.ts index 8d67849..172099d 100644 --- a/frontend/src/store.ts +++ b/frontend/src/store.ts @@ -416,6 +416,10 @@ interface AppState { sqlLogs: SqlLog[]; tableAccessCount: Record; tableSortPreference: Record; + tableColumnOrders: Record; + enableColumnOrderMemory: boolean; + tableHiddenColumns: Record; + enableHiddenColumnMemory: boolean; addConnection: (conn: SavedConnection) => void; updateConnection: (conn: SavedConnection) => void; @@ -458,6 +462,13 @@ interface AppState { recordTableAccess: (connectionId: string, dbName: string, tableName: string) => void; setTableSortPreference: (connectionId: string, dbName: string, sortBy: 'name' | 'frequency') => void; + setTableColumnOrder: (connectionId: string, dbName: string, tableName: string, order: string[]) => void; + setEnableColumnOrderMemory: (enabled: boolean) => void; + clearTableColumnOrder: (connectionId: string, dbName: string, tableName: string) => void; + + setTableHiddenColumns: (connectionId: string, dbName: string, tableName: string, hiddenColumns: string[]) => void; + setEnableHiddenColumnMemory: (enabled: boolean) => void; + clearTableHiddenColumns: (connectionId: string, dbName: string, tableName: string) => void; } const sanitizeSavedQueries = (value: unknown): SavedQuery[] => { @@ -521,6 +532,28 @@ const sanitizeTableSortPreference = (value: unknown): Record => { + const raw = (value && typeof value === 'object') ? value as Record : {}; + const result: Record = {}; + Object.entries(raw).forEach(([key, orderArray]) => { + if (Array.isArray(orderArray)) { + result[key] = orderArray.map(col => String(col)); + } + }); + return result; +}; + +const sanitizeTableHiddenColumns = (value: unknown): Record => { + const raw = (value && typeof value === 'object') ? value as Record : {}; + const result: Record = {}; + Object.entries(raw).forEach(([key, hiddenArray]) => { + if (Array.isArray(hiddenArray)) { + result[key] = hiddenArray.map(col => String(col)); + } + }); + return result; +}; + const sanitizeAppearance = ( appearance: Partial<{ enabled: boolean; opacity: number; blur: number }> | undefined, version: number @@ -598,6 +631,10 @@ export const useStore = create()( sqlLogs: [], tableAccessCount: {}, tableSortPreference: {}, + tableColumnOrders: {}, + enableColumnOrderMemory: true, + tableHiddenColumns: {}, + enableHiddenColumnMemory: true, addConnection: (conn) => set((state) => ({ connections: [...state.connections, conn] })), updateConnection: (conn) => set((state) => ({ @@ -800,6 +837,44 @@ export const useStore = create()( } }; }), + + setTableColumnOrder: (connectionId, dbName, tableName, order) => set((state) => { + const key = `${connectionId}-${dbName}-${tableName}`; + return { + tableColumnOrders: { + ...state.tableColumnOrders, + [key]: order + } + }; + }), + + clearTableColumnOrder: (connectionId, dbName, tableName) => set((state) => { + const key = `${connectionId}-${dbName}-${tableName}`; + const newOrders = { ...state.tableColumnOrders }; + delete newOrders[key]; + return { tableColumnOrders: newOrders }; + }), + + setEnableColumnOrderMemory: (enabled) => set({ enableColumnOrderMemory: !!enabled }), + + setTableHiddenColumns: (connectionId, dbName, tableName, hiddenColumns) => set((state) => { + const key = `${connectionId}-${dbName}-${tableName}`; + return { + tableHiddenColumns: { + ...state.tableHiddenColumns, + [key]: hiddenColumns + } + }; + }), + + clearTableHiddenColumns: (connectionId, dbName, tableName) => set((state) => { + const key = `${connectionId}-${dbName}-${tableName}`; + const newHidden = { ...state.tableHiddenColumns }; + delete newHidden[key]; + return { tableHiddenColumns: newHidden }; + }), + + setEnableHiddenColumnMemory: (enabled) => set({ enableHiddenColumnMemory: !!enabled }), }), { name: 'lite-db-storage', // name of the item in the storage (must be unique) @@ -825,6 +900,13 @@ export const useStore = create()( nextState.shortcutOptions = sanitizeShortcutOptions(state.shortcutOptions); nextState.tableAccessCount = sanitizeTableAccessCount(state.tableAccessCount); nextState.tableSortPreference = sanitizeTableSortPreference(state.tableSortPreference); + // 新增的列排序记忆状态不需要做版本特殊兼容,直接做基本的类型保护 + const safeOrders = sanitizeTableColumnOrders(state.tableColumnOrders); + nextState.tableColumnOrders = safeOrders; + nextState.enableColumnOrderMemory = state.enableColumnOrderMemory !== false; + const safeHidden = sanitizeTableHiddenColumns(state.tableHiddenColumns); + nextState.tableHiddenColumns = safeHidden; + nextState.enableHiddenColumnMemory = state.enableHiddenColumnMemory !== false; return nextState as AppState; }, merge: (persistedState, currentState) => { @@ -841,11 +923,16 @@ export const useStore = create()( fontSize: sanitizeFontSize(state.fontSize), startupFullscreen: sanitizeStartupFullscreen(state.startupFullscreen), globalProxy: sanitizeGlobalProxy(state.globalProxy), + tableSortPreference: sanitizeTableSortPreference(state.tableSortPreference), + tableColumnOrders: sanitizeTableColumnOrders(state.tableColumnOrders), + enableColumnOrderMemory: state.enableColumnOrderMemory !== false, + tableHiddenColumns: sanitizeTableHiddenColumns(state.tableHiddenColumns), + enableHiddenColumnMemory: state.enableHiddenColumnMemory !== false, + sqlFormatOptions: sanitizeSqlFormatOptions(state.sqlFormatOptions), queryOptions: sanitizeQueryOptions(state.queryOptions), shortcutOptions: sanitizeShortcutOptions(state.shortcutOptions), tableAccessCount: sanitizeTableAccessCount(state.tableAccessCount), - tableSortPreference: sanitizeTableSortPreference(state.tableSortPreference), }; }, partialize: (state) => ({ @@ -862,7 +949,11 @@ export const useStore = create()( queryOptions: state.queryOptions, shortcutOptions: state.shortcutOptions, tableAccessCount: state.tableAccessCount, - tableSortPreference: state.tableSortPreference + tableSortPreference: state.tableSortPreference, + tableColumnOrders: state.tableColumnOrders, + enableColumnOrderMemory: state.enableColumnOrderMemory, + tableHiddenColumns: state.tableHiddenColumns, + enableHiddenColumnMemory: state.enableHiddenColumnMemory }), // Don't persist logs } ) diff --git a/internal/app/app.go b/internal/app/app.go index 0709a27..4a0aff9 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -8,6 +8,8 @@ import ( "errors" "fmt" "net" + "net/url" + "os" "strings" "sync" "time" @@ -218,6 +220,7 @@ func wrapConnectError(config connection.ConnectionConfig, err error) error { if err == nil { return nil } + err = sanitizeMongoConnectErrorLabel(config, err) var netErr net.Error if errors.Is(err, context.DeadlineExceeded) || (errors.As(err, &netErr) && netErr.Timeout()) { @@ -231,6 +234,73 @@ func wrapConnectError(config connection.ConnectionConfig, err error) error { return withLogHint{err: err, logPath: logger.Path()} } +type errorMessageOverride struct { + message string + cause error +} + +func (e errorMessageOverride) Error() string { + return e.message +} + +func (e errorMessageOverride) Unwrap() error { + return e.cause +} + +func sanitizeMongoConnectErrorLabel(config connection.ConnectionConfig, err error) error { + if err == nil { + return nil + } + if strings.ToLower(strings.TrimSpace(config.Type)) != "mongodb" { + return err + } + if mongoConnectUsesTLS(config) { + return err + } + original := err.Error() + rewritten := strings.ReplaceAll(original, "SSL 主库凭据", "主库凭据") + rewritten = strings.ReplaceAll(rewritten, "SSL 从库凭据", "从库凭据") + if rewritten == original { + return err + } + return errorMessageOverride{ + message: rewritten, + cause: err, + } +} + +func mongoConnectUsesTLS(config connection.ConnectionConfig) bool { + if config.UseSSL { + return true + } + uriText := strings.TrimSpace(config.URI) + if uriText == "" { + return false + } + parsed, err := url.Parse(uriText) + if err != nil { + return false + } + for _, key := range []string{"tls", "ssl"} { + if enabled, known := parseMongoBool(parsed.Query().Get(key)); known { + return enabled + } + } + return strings.EqualFold(strings.TrimSpace(parsed.Scheme), "mongodb+srv") +} + +func parseMongoBool(raw string) (enabled bool, known bool) { + value := strings.ToLower(strings.TrimSpace(raw)) + switch value { + case "1", "true", "t", "yes", "y", "on", "required": + return true, true + case "0", "false", "f", "no", "n", "off", "disable", "disabled": + return false, true + default: + return false, false + } +} + type withLogHint struct { err error logPath string @@ -238,10 +308,15 @@ type withLogHint struct { func (e withLogHint) Error() string { message := normalizeErrorMessage(e.err) - if strings.TrimSpace(e.logPath) == "" { + path := strings.TrimSpace(e.logPath) + if path == "" { return message } - return fmt.Sprintf("%s(详细日志:%s)", message, e.logPath) + info, statErr := os.Stat(path) + if statErr != nil || info.IsDir() || info.Size() <= 0 { + return message + } + return fmt.Sprintf("%s(详细日志:%s)", message, path) } func (e withLogHint) Unwrap() error { diff --git a/internal/app/app_connect_error_test.go b/internal/app/app_connect_error_test.go new file mode 100644 index 0000000..36bb99e --- /dev/null +++ b/internal/app/app_connect_error_test.go @@ -0,0 +1,84 @@ +package app + +import ( + "errors" + "os" + "path/filepath" + "strings" + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestWrapConnectError_MongoNoSSL_RemovesMisleadingSSLLabel(t *testing.T) { + config := connection.ConnectionConfig{ + Type: "mongodb", + UseSSL: false, + } + sourceErr := errors.New("MongoDB 连接失败:SSL 主库凭据验证失败: mock error") + + wrapped := wrapConnectError(config, sourceErr) + text := wrapped.Error() + if strings.Contains(text, "SSL 主库凭据") { + t.Fatalf("expected ssl label to be removed when TLS disabled, got: %s", text) + } + if !strings.Contains(text, "主库凭据验证失败") { + t.Fatalf("expected auth label to remain, got: %s", text) + } +} + +func TestWrapConnectError_MongoURIForcesTLS_KeepsSSLLabel(t *testing.T) { + config := connection.ConnectionConfig{ + Type: "mongodb", + UseSSL: false, + URI: "mongodb://user:pass@127.0.0.1:27017/admin?tls=true", + } + sourceErr := errors.New("MongoDB 连接失败:SSL 主库凭据验证失败: mock error") + + wrapped := wrapConnectError(config, sourceErr) + text := wrapped.Error() + if !strings.Contains(text, "SSL 主库凭据") { + t.Fatalf("expected ssl label to remain when URI enables TLS, got: %s", text) + } +} + +func TestWrapConnectError_MongoSRVDefaultTLS_KeepsSSLLabel(t *testing.T) { + config := connection.ConnectionConfig{ + Type: "mongodb", + UseSSL: false, + URI: "mongodb+srv://user:pass@cluster0.example.com/admin", + } + sourceErr := errors.New("MongoDB 连接失败:SSL 主库凭据验证失败: mock error") + + wrapped := wrapConnectError(config, sourceErr) + text := wrapped.Error() + if !strings.Contains(text, "SSL 主库凭据") { + t.Fatalf("expected ssl label to remain for mongodb+srv default TLS, got: %s", text) + } +} + +func TestWithLogHintError_OmitEmptyLogPath(t *testing.T) { + dir := t.TempDir() + logPath := filepath.Join(dir, "gonavi.log") + if err := os.WriteFile(logPath, nil, 0o644); err != nil { + t.Fatalf("write empty log failed: %v", err) + } + err := withLogHint{err: errors.New("连接失败"), logPath: logPath} + text := err.Error() + if strings.Contains(text, "详细日志:") { + t.Fatalf("expected no log hint for empty file, got: %s", text) + } +} + +func TestWithLogHintError_IncludeNonEmptyLogPath(t *testing.T) { + dir := t.TempDir() + logPath := filepath.Join(dir, "gonavi.log") + if err := os.WriteFile(logPath, []byte("log entry\n"), 0o644); err != nil { + t.Fatalf("write log failed: %v", err) + } + err := withLogHint{err: errors.New("连接失败"), logPath: logPath} + text := err.Error() + if !strings.Contains(text, "详细日志:"+logPath) { + t.Fatalf("expected log hint with path, got: %s", text) + } +} diff --git a/internal/app/db_proxy.go b/internal/app/db_proxy.go index e3228b6..14af069 100644 --- a/internal/app/db_proxy.go +++ b/internal/app/db_proxy.go @@ -73,8 +73,8 @@ func resolveDialConfigWithProxy(raw connection.ConnectionConfig) (connection.Con // 文件型/自定义 DSN 类型不走标准 host:port,不在此层改写。 return config, nil } - if normalizedType == "mongodb" && config.MongoSRV { - // Mongo SRV 由驱动侧 Dialer 处理代理,避免破坏 DNS SRV 拓扑发现。 + if normalizedType == "mongodb" { + // MongoDB 统一由驱动侧 Dialer 处理代理,保留原始目标地址,避免将连接目标改写为本地转发地址。 return config, nil } diff --git a/internal/app/db_proxy_test.go b/internal/app/db_proxy_test.go new file mode 100644 index 0000000..5d44170 --- /dev/null +++ b/internal/app/db_proxy_test.go @@ -0,0 +1,64 @@ +package app + +import ( + "reflect" + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestResolveDialConfigWithProxy_MongoKeepsTargetAddress(t *testing.T) { + hosts := []string{"10.20.30.40:27017", "10.20.30.41:27017"} + raw := connection.ConnectionConfig{ + Type: "mongodb", + Host: "10.20.30.40", + Port: 27017, + UseProxy: true, + Proxy: connection.ProxyConfig{ + Type: "socks5", + Host: "127.0.0.1", + Port: 1080, + }, + Hosts: hosts, + } + + got, err := resolveDialConfigWithProxy(raw) + if err != nil { + t.Fatalf("resolveDialConfigWithProxy returned error: %v", err) + } + if got.Host != raw.Host || got.Port != raw.Port { + t.Fatalf("mongo target address should be kept, got=%s:%d want=%s:%d", got.Host, got.Port, raw.Host, raw.Port) + } + if !got.UseProxy { + t.Fatalf("mongo should keep UseProxy=true for driver-level dialer") + } + if !reflect.DeepEqual(got.Hosts, hosts) { + t.Fatalf("mongo hosts should be kept, got=%v want=%v", got.Hosts, hosts) + } +} + +func TestResolveDialConfigWithProxy_MongoSRVKeepsTargetAddress(t *testing.T) { + raw := connection.ConnectionConfig{ + Type: "mongodb", + Host: "cluster0.example.com", + Port: 27017, + MongoSRV: true, + UseProxy: true, + Proxy: connection.ProxyConfig{ + Type: "http", + Host: "127.0.0.1", + Port: 7890, + }, + } + + got, err := resolveDialConfigWithProxy(raw) + if err != nil { + t.Fatalf("resolveDialConfigWithProxy returned error: %v", err) + } + if got.Host != raw.Host || got.Port != raw.Port { + t.Fatalf("mongo SRV target address should be kept, got=%s:%d want=%s:%d", got.Host, got.Port, raw.Host, raw.Port) + } + if !got.UseProxy { + t.Fatalf("mongo SRV should keep UseProxy=true for driver-level dialer") + } +} diff --git a/internal/app/global_proxy.go b/internal/app/global_proxy.go index 4361782..016fb26 100644 --- a/internal/app/global_proxy.go +++ b/internal/app/global_proxy.go @@ -72,25 +72,30 @@ func setGlobalProxyConfig(enabled bool, proxyConfig connection.ProxyConfig) (glo } func (a *App) ConfigureGlobalProxy(enabled bool, proxyConfig connection.ProxyConfig) connection.QueryResult { + before := currentGlobalProxyConfig() snapshot, err := setGlobalProxyConfig(enabled, proxyConfig) if err != nil { return connection.QueryResult{Success: false, Message: err.Error()} } - if snapshot.Enabled { - authState := "" - if strings.TrimSpace(snapshot.Proxy.User) != "" { - authState = "(认证:已配置)" + // 前端可能在同一配置下重复触发同步(例如严格模式或状态回放), + // 这里做幂等日志,避免重复刷屏。 + if !globalProxySnapshotEqual(before, snapshot) { + if snapshot.Enabled { + authState := "" + if strings.TrimSpace(snapshot.Proxy.User) != "" { + authState = "(认证:已配置)" + } + logger.Infof( + "全局代理已启用:%s://%s:%d%s", + strings.ToLower(strings.TrimSpace(snapshot.Proxy.Type)), + strings.TrimSpace(snapshot.Proxy.Host), + snapshot.Proxy.Port, + authState, + ) + } else { + logger.Infof("全局代理已关闭") } - logger.Infof( - "全局代理已启用:%s://%s:%d%s", - strings.ToLower(strings.TrimSpace(snapshot.Proxy.Type)), - strings.TrimSpace(snapshot.Proxy.Host), - snapshot.Proxy.Port, - authState, - ) - } else { - logger.Infof("全局代理已关闭") } return connection.QueryResult{ @@ -100,6 +105,24 @@ func (a *App) ConfigureGlobalProxy(enabled bool, proxyConfig connection.ProxyCon } } +func globalProxySnapshotEqual(a, b globalProxySnapshot) bool { + if a.Enabled != b.Enabled { + return false + } + if !a.Enabled { + return true + } + return proxyConfigEqual(a.Proxy, b.Proxy) +} + +func proxyConfigEqual(a, b connection.ProxyConfig) bool { + return strings.EqualFold(strings.TrimSpace(a.Type), strings.TrimSpace(b.Type)) && + strings.TrimSpace(a.Host) == strings.TrimSpace(b.Host) && + a.Port == b.Port && + strings.TrimSpace(a.User) == strings.TrimSpace(b.User) && + a.Password == b.Password +} + func (a *App) GetGlobalProxyConfig() connection.QueryResult { return connection.QueryResult{ Success: true, diff --git a/internal/app/methods_db.go b/internal/app/methods_db.go index b28109f..f411653 100644 --- a/internal/app/methods_db.go +++ b/internal/app/methods_db.go @@ -13,6 +13,16 @@ import ( "GoNavi-Wails/internal/utils" ) +const testConnectionTimeoutUpperBoundSeconds = 12 + +func normalizeTestConnectionConfig(config connection.ConnectionConfig) connection.ConnectionConfig { + normalized := config + if normalized.Timeout <= 0 || normalized.Timeout > testConnectionTimeoutUpperBoundSeconds { + normalized.Timeout = testConnectionTimeoutUpperBoundSeconds + } + return normalized +} + // Generic DB Methods func (a *App) DBConnect(config connection.ConnectionConfig) connection.QueryResult { @@ -28,13 +38,16 @@ func (a *App) DBConnect(config connection.ConnectionConfig) connection.QueryResu } func (a *App) TestConnection(config connection.ConnectionConfig) connection.QueryResult { - _, err := a.getDatabaseForcePing(config) + testConfig := normalizeTestConnectionConfig(config) + started := time.Now() + logger.Infof("TestConnection 开始:%s", formatConnSummary(testConfig)) + _, err := a.getDatabaseForcePing(testConfig) if err != nil { - logger.Error(err, "TestConnection 连接测试失败:%s", formatConnSummary(config)) + logger.Error(err, "TestConnection 连接测试失败:耗时=%s %s", time.Since(started).Round(time.Millisecond), formatConnSummary(testConfig)) return connection.QueryResult{Success: false, Message: err.Error()} } - logger.Infof("TestConnection 连接测试成功:%s", formatConnSummary(config)) + logger.Infof("TestConnection 连接测试成功:耗时=%s %s", time.Since(started).Round(time.Millisecond), formatConnSummary(testConfig)) return connection.QueryResult{Success: true, Message: "连接成功"} } diff --git a/internal/app/methods_db_timeout_test.go b/internal/app/methods_db_timeout_test.go new file mode 100644 index 0000000..d6cf867 --- /dev/null +++ b/internal/app/methods_db_timeout_test.go @@ -0,0 +1,31 @@ +package app + +import ( + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestNormalizeTestConnectionConfig_DefaultToUpperBound(t *testing.T) { + config := connection.ConnectionConfig{Type: "mongodb", Timeout: 0} + got := normalizeTestConnectionConfig(config) + if got.Timeout != testConnectionTimeoutUpperBoundSeconds { + t.Fatalf("expected timeout=%d, got=%d", testConnectionTimeoutUpperBoundSeconds, got.Timeout) + } +} + +func TestNormalizeTestConnectionConfig_KeepSmallerTimeout(t *testing.T) { + config := connection.ConnectionConfig{Type: "mongodb", Timeout: 6} + got := normalizeTestConnectionConfig(config) + if got.Timeout != 6 { + t.Fatalf("expected timeout=6, got=%d", got.Timeout) + } +} + +func TestNormalizeTestConnectionConfig_ClampLargeTimeout(t *testing.T) { + config := connection.ConnectionConfig{Type: "mongodb", Timeout: 60} + got := normalizeTestConnectionConfig(config) + if got.Timeout != testConnectionTimeoutUpperBoundSeconds { + t.Fatalf("expected timeout=%d, got=%d", testConnectionTimeoutUpperBoundSeconds, got.Timeout) + } +} diff --git a/internal/app/methods_driver.go b/internal/app/methods_driver.go index 07a13cc..ca7ce8c 100644 --- a/internal/app/methods_driver.go +++ b/internal/app/methods_driver.go @@ -2792,6 +2792,7 @@ func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, execut driverType := normalizeDriverType(definition.Type) displayName := resolveDriverDisplayName(definition) forceSourceBuild := shouldForceSourceBuildForVersion(driverType, selectedVersion) + preferSourceBuildBeforeDownload := shouldPreferSourceBuildBeforeDownload(driverType, selectedVersion) skipReuseCandidate := shouldSkipReusableAgentCandidate(driverType, selectedVersion) info, err := os.Stat(executablePath) @@ -2799,11 +2800,10 @@ func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, execut if validateErr := db.ValidateOptionalDriverAgentExecutable(driverType, executablePath); validateErr != nil { _ = os.Remove(executablePath) } else { - hash, hashErr := hashFileSHA256(executablePath) - if hashErr != nil { - return "", "", fmt.Errorf("读取已安装 %s 驱动代理摘要失败:%w", displayName, hashErr) + // 用户点击“安装/重装”时应强制刷新驱动代理,避免沿用旧二进制导致修复不生效。 + if removeErr := os.Remove(executablePath); removeErr != nil { + return "", "", fmt.Errorf("清理已安装 %s 驱动代理失败:%w", displayName, removeErr) } - return fmt.Sprintf("local://existing/%s-driver-agent", driverType), hash, nil } } if err == nil && info.IsDir() { @@ -2834,6 +2834,22 @@ func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, execut } var downloadErrs []string + var sourceBuildAttempted bool + var sourceBuildErr error + + if !forceSourceBuild && preferSourceBuildBeforeDownload { + sourceBuildAttempted = true + if a != nil { + a.emitDriverDownloadProgress(driverType, "downloading", 16, 100, fmt.Sprintf("优先使用本地源码构建 %s 驱动代理", displayName)) + } + hash, buildErr := buildOptionalDriverAgentFromSource(definition, executablePath, selectedVersion) + if buildErr == nil { + return fmt.Sprintf("local://go-build/%s-driver-agent", driverType), hash, nil + } + sourceBuildErr = buildErr + logger.Warnf("预先本地构建 %s 驱动代理失败,将继续尝试下载预编译包:%v", displayName, buildErr) + } + if !forceSourceBuild { downloadURLs := resolveOptionalDriverAgentDownloadURLs(definition, downloadURL) if len(downloadURLs) > 0 { @@ -2866,9 +2882,15 @@ func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, execut a.emitDriverDownloadProgress(driverType, "downloading", 92, 100, "未命中预编译包,尝试开发态本地构建") } - hash, buildErr := buildOptionalDriverAgentFromSource(definition, executablePath, selectedVersion) - if buildErr == nil { - return fmt.Sprintf("local://go-build/%s-driver-agent", driverType), hash, nil + var buildErr error + if sourceBuildAttempted { + buildErr = sourceBuildErr + } else { + hash, runErr := buildOptionalDriverAgentFromSource(definition, executablePath, selectedVersion) + buildErr = runErr + if buildErr == nil { + return fmt.Sprintf("local://go-build/%s-driver-agent", driverType), hash, nil + } } var parts []string @@ -3086,12 +3108,25 @@ func shouldForceSourceBuildForVersion(driverType string, selectedVersion string) return resolveMongoDriverMajorFromVersion(selectedVersion) == 1 } -func shouldSkipReusableAgentCandidate(driverType string, selectedVersion string) bool { - if normalizeDriverType(driverType) != "mongodb" { +func shouldPreferSourceBuildBeforeDownload(driverType string, selectedVersion string) bool { + _ = selectedVersion + switch normalizeDriverType(driverType) { + case "kingbase": + // 金仓迭代期优先本地源码构建,避免下载到旧版本预编译代理导致修复不生效。 + return true + default: return false } +} + +func shouldSkipReusableAgentCandidate(driverType string, selectedVersion string) bool { _ = selectedVersion - return true + switch normalizeDriverType(driverType) { + case "mongodb", "kingbase": + return true + default: + return false + } } func optionalDriverBuildTag(driverType string, selectedVersion string) (string, error) { diff --git a/internal/db/diros_impl.go b/internal/db/diros_impl.go index 07bed73..773b7fa 100644 --- a/internal/db/diros_impl.go +++ b/internal/db/diros_impl.go @@ -9,7 +9,6 @@ import ( "strings" "GoNavi-Wails/internal/connection" - "GoNavi-Wails/internal/logger" "GoNavi-Wails/internal/ssh" "GoNavi-Wails/internal/utils" @@ -135,26 +134,26 @@ func collectDirosAddresses(config connection.ConnectionConfig) []string { return result } -func (d *DirosDB) getDSN(config connection.ConnectionConfig) string { +func (d *DirosDB) getDSN(config connection.ConnectionConfig) (string, error) { database := config.Database protocol := "tcp" address := normalizeMySQLAddress(config.Host, config.Port) if config.UseSSH { netName, err := ssh.RegisterSSHNetwork(config.SSH) - if err == nil { - protocol = netName - address = normalizeMySQLAddress(config.Host, config.Port) - } else { - logger.Warnf("注册 Doris SSH 网络失败,将尝试直连:地址=%s:%d 用户=%s,原因:%v", config.Host, config.Port, config.User, err) + if err != nil { + return "", fmt.Errorf("创建 SSH 隧道失败:%w", err) } + protocol = netName } timeout := getConnectTimeoutSeconds(config) tlsMode := resolveMySQLTLSMode(config) - return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s", - config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode)) + return fmt.Sprintf( + "%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s", + config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode), + ), nil } func resolveDirosCredential(config connection.ConnectionConfig, addressIndex int) (string, string) { @@ -192,7 +191,11 @@ func (d *DirosDB) Connect(config connection.ConnectionConfig) error { candidateConfig.Port = port candidateConfig.User, candidateConfig.Password = resolveDirosCredential(runConfig, index) - dsn := d.getDSN(candidateConfig) + dsn, err := d.getDSN(candidateConfig) + if err != nil { + errorDetails = append(errorDetails, fmt.Sprintf("%s 生成连接串失败: %v", address, err)) + continue + } db, err := sql.Open(dirosDriverName, dsn) if err != nil { errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败: %v", address, err)) diff --git a/internal/db/kingbase_identifier_utils.go b/internal/db/kingbase_identifier_utils.go new file mode 100644 index 0000000..f3412ac --- /dev/null +++ b/internal/db/kingbase_identifier_utils.go @@ -0,0 +1,164 @@ +package db + +import "strings" + +func normalizeKingbaseIdentCommon(raw string) string { + value := strings.TrimSpace(raw) + if value == "" { + return "" + } + + // 兼容被多次 JSON 序列化后的转义引号: + // \\\"schema\\\" -> \"schema\" -> "schema" + for i := 0; i < 8; i++ { + next := strings.TrimSpace(value) + next = strings.ReplaceAll(next, `\\\"`, `\"`) + next = strings.ReplaceAll(next, `\"`, `"`) + if next == value { + break + } + value = next + } + value = strings.TrimSpace(value) + + stripWrapperOnce := func(text string) string { + t := strings.TrimSpace(text) + if strings.HasPrefix(t, `\`) && len(t) > 1 { + t = strings.TrimSpace(strings.TrimPrefix(t, `\`)) + } + if strings.HasSuffix(t, `\`) && len(t) > 1 { + t = strings.TrimSpace(strings.TrimSuffix(t, `\`)) + } + if len(t) >= 4 && strings.HasPrefix(t, `\"`) && strings.HasSuffix(t, `\"`) { + return strings.TrimSpace(t[2 : len(t)-2]) + } + if len(t) >= 2 && strings.HasPrefix(t, `"`) && strings.HasSuffix(t, `"`) { + return strings.TrimSpace(t[1 : len(t)-1]) + } + if len(t) >= 2 && strings.HasPrefix(t, "`") && strings.HasSuffix(t, "`") { + return strings.TrimSpace(t[1 : len(t)-1]) + } + if len(t) >= 2 && strings.HasPrefix(t, "[") && strings.HasSuffix(t, "]") { + return strings.TrimSpace(t[1 : len(t)-1]) + } + return t + } + + for i := 0; i < 8; i++ { + next := stripWrapperOnce(value) + if next == value { + break + } + value = next + } + value = strings.TrimSpace(value) + + // 兼容错误的二次引用与残留反斜杠。 + value = strings.ReplaceAll(value, `\"`, `"`) + value = strings.ReplaceAll(value, `""`, "") + value = strings.TrimSpace(value) + + for i := 0; i < 8; i++ { + next := strings.TrimSpace(value) + changed := false + if strings.HasPrefix(next, `\`) && len(next) > 1 { + next = strings.TrimSpace(strings.TrimPrefix(next, `\`)) + changed = true + } + if strings.HasSuffix(next, `\`) && len(next) > 1 { + next = strings.TrimSpace(strings.TrimSuffix(next, `\`)) + changed = true + } + if !changed || next == value { + break + } + value = next + } + + return strings.TrimSpace(value) +} + +func splitKingbaseQualifiedNameCommon(raw string) (schema string, table string) { + text := strings.TrimSpace(raw) + if text == "" { + return "", "" + } + + sep := findKingbaseQualifiedSeparator(text) + if sep < 0 { + return "", normalizeKingbaseIdentCommon(text) + } + + schemaPart := normalizeKingbaseIdentCommon(text[:sep]) + tablePart := normalizeKingbaseIdentCommon(text[sep+1:]) + + if tablePart == "" { + if schemaPart == "" { + return "", normalizeKingbaseIdentCommon(text) + } + return "", schemaPart + } + if schemaPart == "" { + return "", tablePart + } + return schemaPart, tablePart +} + +func findKingbaseQualifiedSeparator(raw string) int { + inDouble := false + inBacktick := false + inBracket := false + escaped := false + + for i := 0; i < len(raw); i++ { + ch := raw[i] + if escaped { + escaped = false + continue + } + + if ch == '\\' { + escaped = true + continue + } + + if inDouble { + if ch == '"' { + // SQL 双引号转义:"" 代表字面量 " + if i+1 < len(raw) && raw[i+1] == '"' { + i++ + continue + } + inDouble = false + } + continue + } + + if inBacktick { + if ch == '`' { + inBacktick = false + } + continue + } + + if inBracket { + if ch == ']' { + inBracket = false + } + continue + } + + switch ch { + case '"': + inDouble = true + case '`': + inBacktick = true + case '[': + inBracket = true + case '.': + return i + } + } + + return -1 +} diff --git a/internal/db/kingbase_identifier_utils_test.go b/internal/db/kingbase_identifier_utils_test.go new file mode 100644 index 0000000..69e2b2e --- /dev/null +++ b/internal/db/kingbase_identifier_utils_test.go @@ -0,0 +1,52 @@ +package db + +import "testing" + +func TestNormalizeKingbaseIdentCommon(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {name: "plain", in: "ldf_server", want: "ldf_server"}, + {name: "quoted", in: `"ldf_server"`, want: "ldf_server"}, + {name: "escaped quoted", in: `\"ldf_server\"`, want: "ldf_server"}, + {name: "double escaped quoted", in: `\\\"ldf_server\\\"`, want: "ldf_server"}, + {name: "double quoted", in: `""ldf_server""`, want: "ldf_server"}, + {name: "backtick quoted", in: "`ldf_server`", want: "ldf_server"}, + {name: "bracket quoted", in: "[ldf_server]", want: "ldf_server"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := normalizeKingbaseIdentCommon(tt.in); got != tt.want { + t.Fatalf("normalizeKingbaseIdentCommon(%q)=%q,want=%q", tt.in, got, tt.want) + } + }) + } +} + +func TestSplitKingbaseQualifiedNameCommon(t *testing.T) { + tests := []struct { + name string + in string + wantSchema string + wantTable string + }{ + {name: "plain", in: "ldf_server.andon_events", wantSchema: "ldf_server", wantTable: "andon_events"}, + {name: "quoted", in: `"ldf_server"."andon_events"`, wantSchema: "ldf_server", wantTable: "andon_events"}, + {name: "escaped quoted", in: `\"ldf_server\".\"andon_events\"`, wantSchema: "ldf_server", wantTable: "andon_events"}, + {name: "double escaped quoted", in: `\\\"ldf_server\\\".\\\"andon_events\\\"`, wantSchema: "ldf_server", wantTable: "andon_events"}, + {name: "space around dot", in: ` "ldf_server" . "andon_events" `, wantSchema: "ldf_server", wantTable: "andon_events"}, + {name: "table only", in: "andon_events", wantSchema: "", wantTable: "andon_events"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotSchema, gotTable := splitKingbaseQualifiedNameCommon(tt.in) + if gotSchema != tt.wantSchema || gotTable != tt.wantTable { + t.Fatalf("splitKingbaseQualifiedNameCommon(%q)=(%q,%q),want=(%q,%q)", tt.in, gotSchema, gotTable, tt.wantSchema, tt.wantTable) + } + }) + } +} diff --git a/internal/db/kingbase_impl.go b/internal/db/kingbase_impl.go index 619455d..d4eda20 100644 --- a/internal/db/kingbase_impl.go +++ b/internal/db/kingbase_impl.go @@ -7,6 +7,7 @@ import ( "database/sql" "fmt" "net" + "regexp" "strconv" "strings" "time" @@ -136,11 +137,88 @@ func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error { if idx > 0 { logger.Warnf("人大金仓 SSL 优先连接失败,已回退至明文连接") } + + // 获取 schema 列表以重构带有 search_path 的连接池 + searchPathStr := k.getSearchPathStr() + if searchPathStr != "" { + // 将 search_path 参数拼入 DSN + finalDSN := dsn + " search_path=" + quoteConnValue(searchPathStr) + if finalDB, err := sql.Open("kingbase", finalDSN); err == nil { + k.pingTimeout = getConnectTimeout(attempt) + finalDB.SetConnMaxLifetime(5 * time.Minute) + + // 临时将 k.conn 指向 finalDB 来做 ping 测试 + oldConn := k.conn + k.conn = finalDB + if err := k.Ping(); err == nil { + // 成功使用带 search_path 的连接池 + _ = oldConn.Close() + logger.Infof("人大金仓已配置连接级 search_path:%s", searchPathStr) + } else { + _ = finalDB.Close() + k.conn = oldConn + } + } + } + if searchPathStr != "" { + timeout := k.pingTimeout + if timeout <= 0 { + timeout = 5 * time.Second + } + ctx, cancel := utils.ContextWithTimeout(timeout) + defer cancel() + if _, err := k.conn.ExecContext(ctx, fmt.Sprintf("SET search_path TO %s", searchPathStr)); err != nil { + logger.Warnf("人大金仓显式设置 search_path 失败:%v", err) + } else { + logger.Infof("人大金仓已设置默认 search_path:%s", searchPathStr) + } + } + return nil } return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";")) } +// getSearchPathStr 查询当前数据库中所有用户 schema,配置 DSN 的 search_path。 +// KingBase 默认 search_path 为 "$user", public,对于自定义 schema 下的表不可见。 +func (k *KingbaseDB) getSearchPathStr() string { + if k.conn == nil { + return "" + } + + query := `SELECT nspname FROM pg_namespace + WHERE nspname NOT IN ('pg_catalog', 'information_schema') + AND nspname NOT LIKE 'pg_%' + ORDER BY nspname` + + rows, err := k.conn.Query(query) + if err != nil { + logger.Warnf("人大金仓查询用户 schema 失败,跳过 search_path 设置:%v", err) + return "" + } + defer rows.Close() + + var schemas []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + continue + } + name = strings.TrimSpace(name) + if name != "" { + // 使用 SQL 标准的双引号包裹标识符 + escaped := strings.ReplaceAll(name, `"`, `""`) + schemas = append(schemas, `"`+escaped+`"`) + } + } + + if len(schemas) == 0 { + return "" + } + + return strings.Join(schemas, ", ") +} + func (k *KingbaseDB) Close() error { // Close SSH forwarder first if exists if k.forwarder != nil { @@ -775,64 +853,63 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet } func normalizeKingbaseIdentifier(raw string) string { - value := strings.TrimSpace(raw) - if value == "" { - return "" + return normalizeKingbaseIdentCommon(raw) +} + +// kingbaseIdentNeedsQuote 判断标识符是否需要双引号包裹。 +// 与前端 sql.ts 中 needsQuote 逻辑保持一致。 +func kingbaseIdentNeedsQuote(ident string) bool { + if ident == "" { + return false } - - // 兼容 JSON/字符串转义后传入的标识符:\"schema\" -> "schema" - value = strings.ReplaceAll(value, `\"`, `"`) - value = strings.TrimSpace(value) - - // 兼容异常多重包裹引号(例如 ""schema""、""""schema"""")。 - // strings.Trim 会移除两端连续引号,迭代后可收敛到纯标识符。 - for i := 0; i < 4; i++ { - next := strings.TrimSpace(strings.Trim(value, `"`)) - if next == value { - break + // 不是合法裸标识符格式(必须以字母或下划线开头,仅含字母、数字、下划线) + if matched, _ := regexp.MatchString(`^[a-zA-Z_][a-zA-Z0-9_]*$`, ident); !matched { + return true + } + // 包含大写字母时需要引号保护(KingbaseES/PostgreSQL 默认将未加引号的标识符折叠为小写) + for _, r := range ident { + if r >= 'A' && r <= 'Z' { + return true } - value = next } + // 是 SQL 保留字 + return isKingbaseReservedWord(ident) +} - // 兼容其他方言可能残留的引用形式 - if len(value) >= 2 && strings.HasPrefix(value, "`") && strings.HasSuffix(value, "`") { - value = strings.TrimSpace(strings.Trim(value, "`")) +// isKingbaseReservedWord 检查是否为常见 SQL 保留字(简化版,与前端保持一致)。 +func isKingbaseReservedWord(ident string) bool { + switch strings.ToLower(ident) { + case "select", "from", "where", "table", "index", "user", "order", "group", "by", + "limit", "offset", "and", "or", "not", "null", "true", "false", "key", + "primary", "foreign", "references", "default", "constraint", + "create", "drop", "alter", "insert", "update", "delete", "set", "values", "into", + "join", "left", "right", "inner", "outer", "on", "as", "is", "in", "like", + "between", "case", "when", "then", "else", "end", "having", "distinct", + "all", "any", "exists", "union", "except", "intersect", + "column", "check", "unique", "with", "grant", "revoke", "trigger", + "begin", "commit", "rollback", "schema", "database", "view", "function", + "procedure", "sequence", "type", "domain", "role", "session", "current", + "authorization", "cross", "full", "natural", "some", "cast", "fetch", + "for", "to", "do", "if", "return", "returns", "declare", "cursor", "server", "owner": + return true } - if len(value) >= 2 && strings.HasPrefix(value, "[") && strings.HasSuffix(value, "]") { - value = strings.TrimSpace(value[1 : len(value)-1]) - } - - return value + return false } func quoteKingbaseIdent(name string) string { n := normalizeKingbaseIdentifier(name) - n = strings.ReplaceAll(n, `"`, `""`) if n == "" { return "\"\"" } + if !kingbaseIdentNeedsQuote(n) { + return n + } + n = strings.ReplaceAll(n, `"`, `""`) return `"` + n + `"` } func splitKingbaseQualifiedTable(tableName string) (schema string, table string) { - raw := strings.TrimSpace(tableName) - if raw == "" { - return "", "" - } - - if parts := strings.SplitN(raw, ".", 2); len(parts) == 2 { - schema = normalizeKingbaseIdentifier(parts[0]) - table = normalizeKingbaseIdentifier(parts[1]) - if table == "" { - return "", normalizeKingbaseIdentifier(raw) - } - if schema == "" { - return "", table - } - return schema, table - } - - return "", normalizeKingbaseIdentifier(raw) + return splitKingbaseQualifiedNameCommon(tableName) } func (k *KingbaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { diff --git a/internal/db/kingbase_impl_test.go b/internal/db/kingbase_impl_test.go index eca6eaa..8b0d6f5 100644 --- a/internal/db/kingbase_impl_test.go +++ b/internal/db/kingbase_impl_test.go @@ -15,8 +15,10 @@ func TestNormalizeKingbaseIdentifier(t *testing.T) { {name: "double quoted", in: `""ldf_server""`, want: "ldf_server"}, {name: "quad quoted", in: `""""ldf_server""""`, want: "ldf_server"}, {name: "escaped quoted", in: `\"ldf_server\"`, want: "ldf_server"}, + {name: "double escaped quoted", in: `\\\"ldf_server\\\"`, want: "ldf_server"}, {name: "backtick quoted", in: "`ldf_server`", want: "ldf_server"}, {name: "bracket quoted", in: "[ldf_server]", want: "ldf_server"}, + {name: "embedded double quotes", in: `ldf""server`, want: "ldfserver"}, } for _, tt := range tests { @@ -34,10 +36,25 @@ func TestQuoteKingbaseIdent(t *testing.T) { in string want string }{ - {name: "plain", in: "ldf_server", want: `"ldf_server"`}, - {name: "double quoted", in: `""ldf_server""`, want: `"ldf_server"`}, - {name: "escaped quoted", in: `\"ldf_server\"`, want: `"ldf_server"`}, + // 纯小写+下划线:不加引号 + {name: "plain lowercase", in: "ldf_server", want: "ldf_server"}, + {name: "plain lowercase 2", in: "bcs_barcode", want: "bcs_barcode"}, + {name: "double quoted input", in: `""ldf_server""`, want: "ldf_server"}, + {name: "escaped quoted input", in: `\"ldf_server\"`, want: "ldf_server"}, + // 含大写字母:加引号 + {name: "uppercase", in: "LDF_Server", want: `"LDF_Server"`}, + {name: "mixed case", in: "myTable", want: `"myTable"`}, + // SQL 保留字:加引号 + {name: "reserved word order", in: "order", want: `"order"`}, + {name: "reserved word user", in: "user", want: `"user"`}, + {name: "reserved word table", in: "table", want: `"table"`}, + {name: "reserved word select", in: "select", want: `"select"`}, + // 含特殊字符:加引号 + {name: "with hyphen", in: "my-table", want: `"my-table"`}, + {name: "with space", in: "my table", want: `"my table"`}, {name: "with embedded quote", in: `ab"cd`, want: `"ab""cd"`}, + // 空值 + {name: "empty", in: "", want: `""`}, } for _, tt := range tests { @@ -49,6 +66,31 @@ func TestQuoteKingbaseIdent(t *testing.T) { } } +func TestKingbaseIdentNeedsQuote(t *testing.T) { + tests := []struct { + name string + in string + want bool + }{ + {name: "plain lowercase", in: "ldf_server", want: false}, + {name: "starts with underscore", in: "_col", want: false}, + {name: "with digits", in: "col123", want: false}, + {name: "uppercase", in: "MyTable", want: true}, + {name: "reserved word", in: "order", want: true}, + {name: "with hyphen", in: "my-col", want: true}, + {name: "starts with digit", in: "123col", want: true}, + {name: "empty", in: "", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := kingbaseIdentNeedsQuote(tt.in); got != tt.want { + t.Fatalf("kingbaseIdentNeedsQuote(%q) = %v, want %v", tt.in, got, tt.want) + } + }) + } +} + func TestSplitKingbaseQualifiedTable(t *testing.T) { tests := []struct { name string @@ -59,6 +101,7 @@ func TestSplitKingbaseQualifiedTable(t *testing.T) { {name: "plain qualified", in: "ldf_server.t_user", wantSchema: "ldf_server", wantTable: "t_user"}, {name: "double quoted qualified", in: `""ldf_server"".""t_user""`, wantSchema: "ldf_server", wantTable: "t_user"}, {name: "escaped qualified", in: `\"ldf_server\".\"t_user\"`, wantSchema: "ldf_server", wantTable: "t_user"}, + {name: "double escaped qualified", in: `\\\"ldf_server\\\".\\\"t_user\\\"`, wantSchema: "ldf_server", wantTable: "t_user"}, {name: "bracket qualified", in: "[ldf_server].[t_user]", wantSchema: "ldf_server", wantTable: "t_user"}, {name: "table only", in: `""t_user""`, wantSchema: "", wantTable: "t_user"}, } diff --git a/internal/db/mariadb_impl.go b/internal/db/mariadb_impl.go index 6a36400..65b9cc3 100644 --- a/internal/db/mariadb_impl.go +++ b/internal/db/mariadb_impl.go @@ -11,7 +11,6 @@ import ( "time" "GoNavi-Wails/internal/connection" - "GoNavi-Wails/internal/logger" "GoNavi-Wails/internal/ssh" "GoNavi-Wails/internal/utils" @@ -25,30 +24,33 @@ type MariaDB struct { pingTimeout time.Duration } -func (m *MariaDB) getDSN(config connection.ConnectionConfig) string { +func (m *MariaDB) getDSN(config connection.ConnectionConfig) (string, error) { database := config.Database protocol := "tcp" address := fmt.Sprintf("%s:%d", config.Host, config.Port) if config.UseSSH { netName, err := ssh.RegisterSSHNetwork(config.SSH) - if err == nil { - protocol = netName - address = fmt.Sprintf("%s:%d", config.Host, config.Port) - } else { - logger.Warnf("注册 SSH 网络失败,将尝试直连:地址=%s:%d 用户=%s,原因:%v", config.Host, config.Port, config.User, err) + if err != nil { + return "", fmt.Errorf("创建 SSH 隧道失败:%w", err) } + protocol = netName } timeout := getConnectTimeoutSeconds(config) tlsMode := resolveMySQLTLSMode(config) - return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s", - config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode)) + return fmt.Sprintf( + "%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s", + config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode), + ), nil } func (m *MariaDB) Connect(config connection.ConnectionConfig) error { - dsn := m.getDSN(config) + dsn, err := m.getDSN(config) + if err != nil { + return err + } db, err := sql.Open("mysql", dsn) if err != nil { return fmt.Errorf("打开数据库连接失败:%w", err) diff --git a/internal/db/mongodb_impl.go b/internal/db/mongodb_impl.go index 27ac0c7..dff4644 100644 --- a/internal/db/mongodb_impl.go +++ b/internal/db/mongodb_impl.go @@ -151,10 +151,14 @@ func applyMongoURI(config connection.ConnectionConfig) connection.ConnectionConf } } - if len(config.Hosts) == 0 && len(hostsFromURI) > 0 { + explicitHost := strings.TrimSpace(config.Host) != "" + explicitHosts := len(config.Hosts) > 0 + + // 显式填写的 host/hosts 优先级高于 URI,避免表单 host 被 URI 中的 localhost 覆盖。 + if !explicitHost && !explicitHosts && len(hostsFromURI) > 0 { config.Hosts = hostsFromURI } - if strings.TrimSpace(config.Host) == "" && len(hostsFromURI) > 0 { + if !explicitHost && !explicitHosts && len(hostsFromURI) > 0 { host, port, ok := parseHostPortWithDefault(hostsFromURI[0], defaultPort) if ok { config.Host = host @@ -281,9 +285,44 @@ func buildMongoAuthAttempts(config connection.ConnectionConfig) []connection.Con return attempts } +func mongoURIForcesTLS(uriText string) bool { + trimmed := strings.TrimSpace(uriText) + if trimmed == "" { + return false + } + parsed, err := url.Parse(trimmed) + if err != nil { + return false + } + query := parsed.Query() + for _, key := range []string{"tls", "ssl"} { + value := strings.ToLower(strings.TrimSpace(query.Get(key))) + switch value { + case "1", "true", "t", "yes", "y", "required": + return true + } + } + return false +} + +func mongoAttemptSSLLabel(config connection.ConnectionConfig, fallbackToPlain bool) string { + if fallbackToPlain { + return "明文回退" + } + if mongoURIForcesTLS(config.URI) { + return "SSL" + } + enabled, _ := resolveMongoTLSSettings(config) + if enabled { + return "SSL" + } + return "明文" +} + func (m *MongoDB) Connect(config connection.ConnectionConfig) error { runConfig := applyMongoURI(config) connectConfig := runConfig + sshRouteHint := "" if runConfig.UseSSH && runConfig.MongoSRV { return fmt.Errorf("MongoDB SRV 记录模式暂不支持 SSH 隧道") @@ -324,6 +363,7 @@ func (m *MongoDB) Connect(config connection.ConnectionConfig) error { localConfig.URI = "" localConfig.Hosts = []string{normalizeMongoAddress(host, port)} connectConfig = localConfig + sshRouteHint = fmt.Sprintf("SSH隧道 %s -> %s:%d", forwarder.LocalAddr, targetHost, targetPort) logger.Infof("MongoDB 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, targetHost, targetPort) } @@ -337,20 +377,32 @@ func (m *MongoDB) Connect(config connection.ConnectionConfig) error { if shouldTrySSLPreferredFallback(connectConfig) { sslAttempts = append(sslAttempts, withSSLDisabled(connectConfig)) } + totalAttempts := 0 + for _, attemptConfig := range sslAttempts { + totalAttempts += len(buildMongoAuthAttempts(attemptConfig)) + } + attemptNo := 0 var errorDetails []string for sslIndex, sslConfig := range sslAttempts { - sslLabel := "SSL" - if sslIndex > 0 { - sslLabel = "明文回退" - } + sslLabel := mongoAttemptSSLLabel(sslConfig, sslIndex > 0) attemptConfigs := buildMongoAuthAttempts(sslConfig) for index, attemptConfig := range attemptConfigs { + attemptNo++ authLabel := "主库凭据" if index > 0 { authLabel = "从库凭据" } + targets := collectMongoSeeds(attemptConfig) + if len(targets) == 0 { + targets = append(targets, normalizeMongoAddress(attemptConfig.Host, attemptConfig.Port)) + } + attemptStarted := time.Now() + logger.Infof( + "MongoDB 连接尝试:%d/%d 模式=%s 凭据=%s 目标=%s 代理=%t", + attemptNo, totalAttempts, sslLabel, authLabel, strings.Join(targets, ","), attemptConfig.UseProxy, + ) if sslIndex > 0 { attemptConfig.URI = "" @@ -369,7 +421,13 @@ func (m *MongoDB) Connect(config connection.ConnectionConfig) error { } client, err := mongo.Connect(clientOpts) if err != nil { - errorDetails = append(errorDetails, fmt.Sprintf("%s %s连接失败: %v", sslLabel, authLabel, err)) + logger.Warnf("MongoDB 连接尝试失败:%d/%d 模式=%s 凭据=%s 耗时=%s 错误=%v", + attemptNo, totalAttempts, sslLabel, authLabel, time.Since(attemptStarted).Round(time.Millisecond), err) + detail := fmt.Sprintf("%s %s连接失败: %v", sslLabel, authLabel, err) + if sshRouteHint != "" { + detail = fmt.Sprintf("%s(%s)", detail, sshRouteHint) + } + errorDetails = append(errorDetails, detail) continue } @@ -379,9 +437,17 @@ func (m *MongoDB) Connect(config connection.ConnectionConfig) error { _ = client.Disconnect(ctx) cancel() m.client = nil - errorDetails = append(errorDetails, fmt.Sprintf("%s %s验证失败: %v", sslLabel, authLabel, err)) + logger.Warnf("MongoDB 连接尝试验证失败:%d/%d 模式=%s 凭据=%s 耗时=%s 错误=%v", + attemptNo, totalAttempts, sslLabel, authLabel, time.Since(attemptStarted).Round(time.Millisecond), err) + detail := fmt.Sprintf("%s %s验证失败: %v", sslLabel, authLabel, err) + if sshRouteHint != "" { + detail = fmt.Sprintf("%s(%s)", detail, sshRouteHint) + } + errorDetails = append(errorDetails, detail) continue } + logger.Infof("MongoDB 连接尝试成功:%d/%d 模式=%s 凭据=%s 耗时=%s", + attemptNo, totalAttempts, sslLabel, authLabel, time.Since(attemptStarted).Round(time.Millisecond)) if sslIndex > 0 { logger.Warnf("MongoDB SSL 优先连接失败,已回退至明文连接") } diff --git a/internal/db/mongodb_impl_uri_test.go b/internal/db/mongodb_impl_uri_test.go new file mode 100644 index 0000000..020b293 --- /dev/null +++ b/internal/db/mongodb_impl_uri_test.go @@ -0,0 +1,39 @@ +//go:build gonavi_full_drivers || gonavi_mongodb_driver + +package db + +import ( + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestApplyMongoURI_ExplicitHostDoesNotAdoptURIHosts(t *testing.T) { + config := connection.ConnectionConfig{ + Host: "10.10.10.10", + Port: 27017, + URI: "mongodb://localhost:27017/admin", + } + + got := applyMongoURI(config) + if got.Host != "10.10.10.10" { + t.Fatalf("expected host to remain explicit, got %q", got.Host) + } + if len(got.Hosts) != 0 { + t.Fatalf("expected hosts to remain empty when explicit host exists, got %v", got.Hosts) + } +} + +func TestApplyMongoURI_ExplicitHostsDoesNotAdoptURIHosts(t *testing.T) { + config := connection.ConnectionConfig{ + Host: "10.10.10.10", + Port: 27017, + Hosts: []string{"10.10.10.10:27017", "10.10.10.11:27017"}, + URI: "mongodb://localhost:27017,localhost:27018/admin?replicaSet=rs0", + } + + got := applyMongoURI(config) + if len(got.Hosts) != 2 || got.Hosts[0] != "10.10.10.10:27017" { + t.Fatalf("expected explicit hosts to stay untouched, got %v", got.Hosts) + } +} diff --git a/internal/db/mongodb_impl_v1.go b/internal/db/mongodb_impl_v1.go index e3aa5b4..60d4fb2 100644 --- a/internal/db/mongodb_impl_v1.go +++ b/internal/db/mongodb_impl_v1.go @@ -152,10 +152,14 @@ func applyMongoURI(config connection.ConnectionConfig) connection.ConnectionConf } } - if len(config.Hosts) == 0 && len(hostsFromURI) > 0 { + explicitHost := strings.TrimSpace(config.Host) != "" + explicitHosts := len(config.Hosts) > 0 + + // 显式填写的 host/hosts 优先级高于 URI,避免表单 host 被 URI 中的 localhost 覆盖。 + if !explicitHost && !explicitHosts && len(hostsFromURI) > 0 { config.Hosts = hostsFromURI } - if strings.TrimSpace(config.Host) == "" && len(hostsFromURI) > 0 { + if !explicitHost && !explicitHosts && len(hostsFromURI) > 0 { host, port, ok := parseHostPortWithDefault(hostsFromURI[0], defaultPort) if ok { config.Host = host @@ -282,9 +286,44 @@ func buildMongoAuthAttempts(config connection.ConnectionConfig) []connection.Con return attempts } +func mongoURIForcesTLS(uriText string) bool { + trimmed := strings.TrimSpace(uriText) + if trimmed == "" { + return false + } + parsed, err := url.Parse(trimmed) + if err != nil { + return false + } + query := parsed.Query() + for _, key := range []string{"tls", "ssl"} { + value := strings.ToLower(strings.TrimSpace(query.Get(key))) + switch value { + case "1", "true", "t", "yes", "y", "required": + return true + } + } + return false +} + +func mongoAttemptSSLLabel(config connection.ConnectionConfig, fallbackToPlain bool) string { + if fallbackToPlain { + return "明文回退" + } + if mongoURIForcesTLS(config.URI) { + return "SSL" + } + enabled, _ := resolveMongoTLSSettings(config) + if enabled { + return "SSL" + } + return "明文" +} + func (m *MongoDBV1) Connect(config connection.ConnectionConfig) error { runConfig := applyMongoURI(config) connectConfig := runConfig + sshRouteHint := "" if runConfig.UseSSH && runConfig.MongoSRV { return fmt.Errorf("MongoDB SRV 记录模式暂不支持 SSH 隧道") @@ -325,6 +364,7 @@ func (m *MongoDBV1) Connect(config connection.ConnectionConfig) error { localConfig.URI = "" localConfig.Hosts = []string{normalizeMongoAddress(host, port)} connectConfig = localConfig + sshRouteHint = fmt.Sprintf("SSH隧道 %s -> %s:%d", forwarder.LocalAddr, targetHost, targetPort) logger.Infof("MongoDB 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, targetHost, targetPort) } @@ -338,20 +378,32 @@ func (m *MongoDBV1) Connect(config connection.ConnectionConfig) error { if shouldTrySSLPreferredFallback(connectConfig) { sslAttempts = append(sslAttempts, withSSLDisabled(connectConfig)) } + totalAttempts := 0 + for _, attemptConfig := range sslAttempts { + totalAttempts += len(buildMongoAuthAttempts(attemptConfig)) + } + attemptNo := 0 var errorDetails []string for sslIndex, sslConfig := range sslAttempts { - sslLabel := "SSL" - if sslIndex > 0 { - sslLabel = "明文回退" - } + sslLabel := mongoAttemptSSLLabel(sslConfig, sslIndex > 0) attemptConfigs := buildMongoAuthAttempts(sslConfig) for index, attemptConfig := range attemptConfigs { + attemptNo++ authLabel := "主库凭据" if index > 0 { authLabel = "从库凭据" } + targets := collectMongoSeeds(attemptConfig) + if len(targets) == 0 { + targets = append(targets, normalizeMongoAddress(attemptConfig.Host, attemptConfig.Port)) + } + attemptStarted := time.Now() + logger.Infof( + "MongoDB(v1) 连接尝试:%d/%d 模式=%s 凭据=%s 目标=%s 代理=%t", + attemptNo, totalAttempts, sslLabel, authLabel, strings.Join(targets, ","), attemptConfig.UseProxy, + ) if sslIndex > 0 { attemptConfig.URI = "" @@ -372,7 +424,13 @@ func (m *MongoDBV1) Connect(config connection.ConnectionConfig) error { client, err := mongo.Connect(connectCtx, clientOpts) connectCancel() if err != nil { - errorDetails = append(errorDetails, fmt.Sprintf("%s %s连接失败: %v", sslLabel, authLabel, err)) + logger.Warnf("MongoDB(v1) 连接尝试失败:%d/%d 模式=%s 凭据=%s 耗时=%s 错误=%v", + attemptNo, totalAttempts, sslLabel, authLabel, time.Since(attemptStarted).Round(time.Millisecond), err) + detail := fmt.Sprintf("%s %s连接失败: %v", sslLabel, authLabel, err) + if sshRouteHint != "" { + detail = fmt.Sprintf("%s(%s)", detail, sshRouteHint) + } + errorDetails = append(errorDetails, detail) continue } @@ -382,9 +440,17 @@ func (m *MongoDBV1) Connect(config connection.ConnectionConfig) error { _ = client.Disconnect(ctx) cancel() m.client = nil - errorDetails = append(errorDetails, fmt.Sprintf("%s %s验证失败: %v", sslLabel, authLabel, err)) + logger.Warnf("MongoDB(v1) 连接尝试验证失败:%d/%d 模式=%s 凭据=%s 耗时=%s 错误=%v", + attemptNo, totalAttempts, sslLabel, authLabel, time.Since(attemptStarted).Round(time.Millisecond), err) + detail := fmt.Sprintf("%s %s验证失败: %v", sslLabel, authLabel, err) + if sshRouteHint != "" { + detail = fmt.Sprintf("%s(%s)", detail, sshRouteHint) + } + errorDetails = append(errorDetails, detail) continue } + logger.Infof("MongoDB(v1) 连接尝试成功:%d/%d 模式=%s 凭据=%s 耗时=%s", + attemptNo, totalAttempts, sslLabel, authLabel, time.Since(attemptStarted).Round(time.Millisecond)) if sslIndex > 0 { logger.Warnf("MongoDB(v1) SSL 优先连接失败,已回退至明文连接") } diff --git a/internal/db/mongodb_impl_v1_uri_test.go b/internal/db/mongodb_impl_v1_uri_test.go new file mode 100644 index 0000000..8860db2 --- /dev/null +++ b/internal/db/mongodb_impl_v1_uri_test.go @@ -0,0 +1,25 @@ +//go:build gonavi_mongodb_driver_v1 + +package db + +import ( + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestApplyMongoURIV1_ExplicitHostDoesNotAdoptURIHosts(t *testing.T) { + config := connection.ConnectionConfig{ + Host: "10.10.10.10", + Port: 27017, + URI: "mongodb://localhost:27017/admin", + } + + got := applyMongoURI(config) + if got.Host != "10.10.10.10" { + t.Fatalf("expected host to remain explicit, got %q", got.Host) + } + if len(got.Hosts) != 0 { + t.Fatalf("expected hosts to remain empty when explicit host exists, got %v", got.Hosts) + } +} diff --git a/internal/db/mysql_impl.go b/internal/db/mysql_impl.go index 5095f1c..32b63cc 100644 --- a/internal/db/mysql_impl.go +++ b/internal/db/mysql_impl.go @@ -169,26 +169,26 @@ func collectMySQLAddresses(config connection.ConnectionConfig) []string { return result } -func (m *MySQLDB) getDSN(config connection.ConnectionConfig) string { +func (m *MySQLDB) getDSN(config connection.ConnectionConfig) (string, error) { database := config.Database protocol := "tcp" address := normalizeMySQLAddress(config.Host, config.Port) if config.UseSSH { netName, err := ssh.RegisterSSHNetwork(config.SSH) - if err == nil { - protocol = netName - address = normalizeMySQLAddress(config.Host, config.Port) - } else { - logger.Warnf("注册 SSH 网络失败,将尝试直连:地址=%s:%d 用户=%s,原因:%v", config.Host, config.Port, config.User, err) + if err != nil { + return "", fmt.Errorf("创建 SSH 隧道失败:%w", err) } + protocol = netName } timeout := getConnectTimeoutSeconds(config) tlsMode := resolveMySQLTLSMode(config) - return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s", - config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode)) + return fmt.Sprintf( + "%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s", + config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode), + ), nil } func resolveMySQLCredential(config connection.ConnectionConfig, addressIndex int) (string, string) { @@ -226,7 +226,11 @@ func (m *MySQLDB) Connect(config connection.ConnectionConfig) error { candidateConfig.Port = port candidateConfig.User, candidateConfig.Password = resolveMySQLCredential(runConfig, index) - dsn := m.getDSN(candidateConfig) + dsn, err := m.getDSN(candidateConfig) + if err != nil { + errorDetails = append(errorDetails, fmt.Sprintf("%s 生成连接串失败: %v", address, err)) + continue + } db, err := sql.Open("mysql", dsn) if err != nil { errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败: %v", address, err)) diff --git a/internal/db/mysql_ssh_test.go b/internal/db/mysql_ssh_test.go new file mode 100644 index 0000000..673639c --- /dev/null +++ b/internal/db/mysql_ssh_test.go @@ -0,0 +1,26 @@ +package db + +import ( + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestMySQLDSN_UseSSH_ShouldFailWhenSSHInvalid(t *testing.T) { + m := &MySQLDB{} + _, err := m.getDSN(connection.ConnectionConfig{ + Host: "127.0.0.1", + Port: 3306, + User: "root", + UseSSH: true, + SSH: connection.SSHConfig{ + Host: "127.0.0.1", + Port: 0, // invalid port, should fail immediately + User: "bad", + Password: "bad", + }, + }) + if err == nil { + t.Fatalf("expected error when UseSSH=true and SSH config invalid") + } +} diff --git a/internal/db/optional_driver_agent_impl.go b/internal/db/optional_driver_agent_impl.go index 2579b7c..07fd7d3 100644 --- a/internal/db/optional_driver_agent_impl.go +++ b/internal/db/optional_driver_agent_impl.go @@ -9,6 +9,7 @@ import ( "io" "os" "os/exec" + "reflect" "runtime" "strings" "sync" @@ -145,6 +146,7 @@ func (c *optionalDriverAgentClient) captureStderr(stderr io.Reader) { if line == "" { continue } + logger.Warnf("%s 驱动代理 stderr: %s", driverDisplayName(c.driver), line) c.stderrMu.Lock() if c.stderr.Len() > 0 { c.stderr.WriteString(" | ") @@ -268,6 +270,7 @@ func (d *OptionalDriverAgentDB) Connect(config connection.ConnectionConfig) erro return err } d.client = client + d.ensureKingbaseSearchPath(config) return nil } @@ -488,6 +491,16 @@ func (d *OptionalDriverAgentDB) ApplyChanges(tableName string, changes connectio if err != nil { return err } + if strings.EqualFold(d.driverType, "kingbase") { + if normalized := normalizeKingbaseAgentTableName(tableName); normalized != "" { + tableName = normalized + } + if normalized, normErr := d.normalizeKingbaseAgentChangeSet(tableName, changes); normErr == nil { + changes = normalized + } else { + logger.Warnf("Kingbase ApplyChanges 字段名规范化失败:%v", normErr) + } + } return client.call(optionalAgentRequest{ Method: optionalAgentMethodApplyChanges, TableName: tableName, @@ -502,6 +515,269 @@ func (d *OptionalDriverAgentDB) requireClient() (*optionalDriverAgentClient, err return d.client, nil } +func (d *OptionalDriverAgentDB) ensureKingbaseSearchPath(config connection.ConnectionConfig) { + if !strings.EqualFold(d.driverType, "kingbase") { + return + } + client, err := d.requireClient() + if err != nil || client == nil { + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + schemas, err := d.listKingbaseSchemas(ctx) + if err != nil || len(schemas) == 0 { + if err != nil { + logger.Warnf("人大金仓驱动代理探测 schema 失败:%v", err) + } + return + } + + searchPath := buildKingbaseSearchPathFromSchemas(schemas) + if strings.TrimSpace(searchPath) == "" { + return + } + + if _, err := d.ExecContext(ctx, fmt.Sprintf("SET search_path TO %s", searchPath)); err != nil { + logger.Warnf("人大金仓驱动代理设置 search_path 失败:%v", err) + return + } + logger.Infof("人大金仓驱动代理已设置默认 search_path:%s", searchPath) +} + +func (d *OptionalDriverAgentDB) listKingbaseSchemas(ctx context.Context) ([]string, error) { + query := `SELECT nspname FROM pg_namespace + WHERE nspname NOT IN ('pg_catalog', 'information_schema') + AND nspname NOT LIKE 'pg_%' + ORDER BY nspname` + rows, _, err := d.QueryContext(ctx, query) + if err != nil { + return nil, err + } + + schemas := make([]string, 0, len(rows)) + for _, row := range rows { + for key, val := range row { + if strings.EqualFold(key, "nspname") || strings.EqualFold(key, "schema") { + name := strings.TrimSpace(fmt.Sprintf("%v", val)) + if name != "" { + schemas = append(schemas, name) + } + break + } + } + if len(row) == 1 { + for _, val := range row { + name := strings.TrimSpace(fmt.Sprintf("%v", val)) + if name != "" { + schemas = append(schemas, name) + } + break + } + } + } + return schemas, nil +} + +func buildKingbaseSearchPathFromSchemas(schemas []string) string { + if len(schemas) == 0 { + return "" + } + seen := make(map[string]struct{}, len(schemas)+1) + parts := make([]string, 0, len(schemas)+1) + for _, name := range schemas { + trimmed := normalizeKingbaseAgentIdent(name) + if trimmed == "" { + continue + } + key := strings.ToLower(trimmed) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + parts = append(parts, quoteKingbaseAgentIdent(trimmed)) + } + if _, ok := seen["public"]; !ok { + parts = append(parts, "public") + } + return strings.Join(parts, ", ") +} + +func quoteKingbaseAgentIdent(name string) string { + n := normalizeKingbaseAgentIdent(name) + if n == "" { + return "\"\"" + } + n = strings.ReplaceAll(n, `"`, `""`) + return `"` + n + `"` +} + +func normalizeKingbaseAgentTableName(raw string) string { + schema, table := splitKingbaseQualifiedNameCommon(raw) + if table == "" { + return "" + } + if schema == "" { + return table + } + return schema + "." + table +} + +func normalizeKingbaseAgentIdent(raw string) string { + return normalizeKingbaseIdentCommon(raw) +} + +type kingbaseAgentColumnIndex struct { + exact map[string]string + compact map[string]string +} + +func buildKingbaseAgentColumnIndex(columns []string) kingbaseAgentColumnIndex { + exact := make(map[string]string, len(columns)) + compact := make(map[string]string, len(columns)) + compactSeen := make(map[string]string, len(columns)) + compactDup := make(map[string]struct{}, len(columns)) + + for _, col := range columns { + name := normalizeKingbaseAgentIdent(col) + if name == "" { + continue + } + lower := strings.ToLower(name) + if _, ok := exact[lower]; !ok { + exact[lower] = name + } + key := normalizeKingbaseAgentCompactKey(name) + if key == "" { + continue + } + if prev, ok := compactSeen[key]; ok && !strings.EqualFold(prev, name) { + compactDup[key] = struct{}{} + continue + } + compactSeen[key] = name + } + + if len(compactDup) > 0 { + for key := range compactDup { + delete(compactSeen, key) + } + } + for key, value := range compactSeen { + compact[key] = value + } + return kingbaseAgentColumnIndex{exact: exact, compact: compact} +} + +func normalizeKingbaseAgentCompactKey(raw string) string { + name := normalizeKingbaseAgentIdent(raw) + if name == "" { + return "" + } + name = strings.ToLower(strings.TrimSpace(name)) + name = strings.Join(strings.Fields(name), "") + name = strings.ReplaceAll(name, "_", "") + return name +} + +func resolveKingbaseAgentColumnName(name string, index kingbaseAgentColumnIndex) string { + cleaned := normalizeKingbaseAgentIdent(name) + if cleaned == "" { + return name + } + lower := strings.ToLower(cleaned) + if actual, ok := index.exact[lower]; ok { + return actual + } + compact := normalizeKingbaseAgentCompactKey(cleaned) + if actual, ok := index.compact[compact]; ok { + return actual + } + return cleaned +} + +func normalizeKingbaseAgentChangeSetByColumns(changes connection.ChangeSet, columns []string) (connection.ChangeSet, error) { + index := buildKingbaseAgentColumnIndex(columns) + if len(index.exact) == 0 && len(index.compact) == 0 { + return changes, nil + } + + mapRow := func(row map[string]interface{}) (map[string]interface{}, error) { + if row == nil { + return row, nil + } + out := make(map[string]interface{}, len(row)) + for key, value := range row { + nextKey := resolveKingbaseAgentColumnName(key, index) + if existing, ok := out[nextKey]; ok && !reflect.DeepEqual(existing, value) { + return nil, fmt.Errorf("duplicate mapped column %q", nextKey) + } + out[nextKey] = value + } + return out, nil + } + + next := connection.ChangeSet{ + Inserts: make([]map[string]interface{}, 0, len(changes.Inserts)), + Updates: make([]connection.UpdateRow, 0, len(changes.Updates)), + Deletes: make([]map[string]interface{}, 0, len(changes.Deletes)), + } + + for _, row := range changes.Inserts { + mapped, err := mapRow(row) + if err != nil { + return changes, err + } + next.Inserts = append(next.Inserts, mapped) + } + + for _, upd := range changes.Updates { + keys, err := mapRow(upd.Keys) + if err != nil { + return changes, err + } + values, err := mapRow(upd.Values) + if err != nil { + return changes, err + } + next.Updates = append(next.Updates, connection.UpdateRow{ + Keys: keys, + Values: values, + }) + } + + for _, row := range changes.Deletes { + mapped, err := mapRow(row) + if err != nil { + return changes, err + } + next.Deletes = append(next.Deletes, mapped) + } + + return next, nil +} + +func (d *OptionalDriverAgentDB) normalizeKingbaseAgentChangeSet(tableName string, changes connection.ChangeSet) (connection.ChangeSet, error) { + columns, err := d.GetColumns("", tableName) + if err != nil { + return changes, err + } + if len(columns) == 0 { + return changes, nil + } + names := make([]string, 0, len(columns)) + for _, col := range columns { + name := strings.TrimSpace(col.Name) + if name == "" { + continue + } + names = append(names, name) + } + return normalizeKingbaseAgentChangeSetByColumns(changes, names) +} + func timeoutMsFromContext(ctx context.Context) int64 { deadline, ok := ctx.Deadline() if !ok { diff --git a/internal/db/optional_driver_agent_impl_test.go b/internal/db/optional_driver_agent_impl_test.go index 2273a06..a79b03d 100644 --- a/internal/db/optional_driver_agent_impl_test.go +++ b/internal/db/optional_driver_agent_impl_test.go @@ -1,32 +1,67 @@ package db import ( - "context" "testing" - "time" + + "GoNavi-Wails/internal/connection" ) -func TestTimeoutMsFromContext_NoDeadline(t *testing.T) { - if got := timeoutMsFromContext(context.Background()); got != 0 { - t.Fatalf("无 deadline 时应返回 0,got=%d", got) +func TestNormalizeKingbaseAgentTableName(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {name: "plain", in: "ldf_server.andon_events", want: "ldf_server.andon_events"}, + {name: "quoted", in: `"ldf_server"."andon_events"`, want: "ldf_server.andon_events"}, + {name: "double quoted", in: `""ldf_server"".""andon_events""`, want: "ldf_server.andon_events"}, + {name: "escaped", in: `\"ldf_server\".\"andon_events\"`, want: "ldf_server.andon_events"}, + {name: "double escaped", in: `\\\"ldf_server\\\".\\\"andon_events\\\"`, want: "ldf_server.andon_events"}, + {name: "space around dot", in: ` "ldf_server" . "andon_events" `, want: "ldf_server.andon_events"}, + {name: "table only", in: `bcs_barcode`, want: "bcs_barcode"}, + {name: "table only quoted", in: `"bcs_barcode"`, want: "bcs_barcode"}, + {name: "table only double quoted", in: `""bcs_barcode""`, want: "bcs_barcode"}, + {name: "table only double escaped", in: `\\\"bcs_barcode\\\"`, want: "bcs_barcode"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := normalizeKingbaseAgentTableName(tt.in); got != tt.want { + t.Fatalf("normalizeKingbaseAgentTableName(%q) = %q, want %q", tt.in, got, tt.want) + } + }) } } -func TestTimeoutMsFromContext_WithDeadline(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() +func TestNormalizeKingbaseAgentChangeSetByColumns(t *testing.T) { + columns := []string{"andon_events_id", "event_name", "event_code"} + input := connection.ChangeSet{ + Inserts: []map[string]interface{}{ + {"event name": "物料1", "event_code": "EV-0001", "andon_events_id": 1}, + }, + Updates: []connection.UpdateRow{ + {Keys: map[string]interface{}{"andon_events_id": 1}, Values: map[string]interface{}{"event name": "物料2"}}, + }, + Deletes: []map[string]interface{}{ + {"andon_events_id": 1}, + }, + } - got := timeoutMsFromContext(ctx) - if got <= 0 { - t.Fatalf("有 deadline 时应返回正值,got=%d", got) - } -} - -func TestTimeoutMsFromContext_ExpiredDeadline(t *testing.T) { - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) - defer cancel() - - if got := timeoutMsFromContext(ctx); got != 1 { - t.Fatalf("过期 deadline 应返回 1,got=%d", got) + out, err := normalizeKingbaseAgentChangeSetByColumns(input, columns) + if err != nil { + t.Fatalf("normalizeKingbaseAgentChangeSetByColumns error: %v", err) + } + + if _, ok := out.Inserts[0]["event_name"]; !ok { + t.Fatalf("expected insert to map \"event name\" -> \"event_name\"") + } + if _, ok := out.Inserts[0]["event name"]; ok { + t.Fatalf("unexpected insert key \"event name\" after normalization") + } + if _, ok := out.Updates[0].Values["event_name"]; !ok { + t.Fatalf("expected update values to map \"event name\" -> \"event_name\"") + } + if _, ok := out.Updates[0].Values["event name"]; ok { + t.Fatalf("unexpected update value key \"event name\" after normalization") } } diff --git a/internal/logger/logger.go b/internal/logger/logger.go index e224608..56cf583 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -14,8 +14,9 @@ import ( ) const ( - envLogDir = "GONAVI_LOG_DIR" - appDirName = "GoNavi" + envLogDir = "GONAVI_LOG_DIR" + appHiddenDir = ".GoNavi" + appLogDirName = "Logs" logFileName = "gonavi.log" logRotateMaxBytes = 10 * 1024 * 1024 // 10MB @@ -37,7 +38,7 @@ func Init() { defer logMu.Unlock() logPath = path logInst = log.New(out, "", log.Ldate|log.Ltime|log.Lmicroseconds) - logInst.Printf("[信息] 日志初始化完成,日志文件:%s", logPath) + logInst.Printf("[INFO] 日志初始化完成,日志文件:%s", logPath) }) } @@ -62,15 +63,15 @@ func Close() { } func Infof(format string, args ...any) { - printf("信息", format, args...) + printf("INFO", format, args...) } func Warnf(format string, args ...any) { - printf("警告", format, args...) + printf("WARN", format, args...) } func Errorf(format string, args ...any) { - printf("错误", format, args...) + printf("ERROR", format, args...) } func Error(err error, format string, args ...any) { @@ -115,37 +116,58 @@ func ErrorChain(err error) string { func printf(level string, format string, args ...any) { Init() logMu.Lock() + defer logMu.Unlock() inst := logInst - logMu.Unlock() if inst == nil { return } inst.Printf("[%s] %s", level, fmt.Sprintf(format, args...)) + if logFile != nil { + _ = logFile.Sync() + } } func initOutput() (string, io.Writer) { dir := strings.TrimSpace(os.Getenv(envLogDir)) if dir == "" { - base, err := os.UserConfigDir() - if err != nil || strings.TrimSpace(base) == "" { - base = os.TempDir() - } - dir = filepath.Join(base, appDirName, "logs") + dir = defaultLogDir() } + if path, writer, ok := openLogFile(dir); ok { + return path, writer + } + + fallbackDir := filepath.Join(os.TempDir(), appHiddenDir, appLogDirName) + if path, writer, ok := openLogFile(fallbackDir); ok { + return path, writer + } + + return "", os.Stderr +} + +func defaultLogDir() string { + home, err := os.UserHomeDir() + if err != nil || strings.TrimSpace(home) == "" { + return filepath.Join(os.TempDir(), appHiddenDir, appLogDirName) + } + return filepath.Join(home, appHiddenDir, appLogDirName) +} + +func openLogFile(dir string) (string, io.Writer, bool) { + if strings.TrimSpace(dir) == "" { + return "", nil, false + } if err := os.MkdirAll(dir, 0o755); err != nil { - return filepath.Join(dir, logFileName), os.Stderr + return "", nil, false } - path := filepath.Join(dir, logFileName) rotateIfNeeded(path, dir) - f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) if err != nil { - return path, os.Stderr + return "", nil, false } logFile = f - return path, f + return path, f, true } func rotateIfNeeded(path, dir string) { diff --git a/internal/redis/redis_impl.go b/internal/redis/redis_impl.go index 8d41a28..93db691 100644 --- a/internal/redis/redis_impl.go +++ b/internal/redis/redis_impl.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "fmt" "net" + "net/url" "strconv" "strings" "sync" @@ -174,8 +175,31 @@ func (r *RedisClientImpl) toDisplayKey(key string) string { return strings.TrimPrefix(key, prefix) } +// sanitizeRedisPassword 对 Redis 密码进行防御性 URL 解码。 +// 当密码中包含 URL 编码序列(如 %40)时,尝试解码还原原始字符。 +// 这可以防止前端 URI 构建中 encodeURIComponent 编码后的密码被误传入。 +func sanitizeRedisPassword(password string) string { + if password == "" { + return password + } + // 仅当密码中包含 '%' 且后跟两位十六进制数字时,才尝试 URL 解码 + if !strings.Contains(password, "%") { + return password + } + decoded, err := url.QueryUnescape(password) + if err != nil { + // 解码失败,使用原始密码 + return password + } + if decoded != password { + logger.Warnf("Redis 密码检测到 URL 编码,已自动解码(原长度=%d 解码后长度=%d)", len(password), len(decoded)) + } + return decoded +} + // Connect establishes a connection to Redis func (r *RedisClientImpl) Connect(config connection.ConnectionConfig) error { + config.Password = sanitizeRedisPassword(config.Password) r.config = config if r.config.RedisDB < 0 || r.config.RedisDB > 15 { r.config.RedisDB = 0 diff --git a/internal/redis/redis_impl_test.go b/internal/redis/redis_impl_test.go new file mode 100644 index 0000000..7014ab8 --- /dev/null +++ b/internal/redis/redis_impl_test.go @@ -0,0 +1,81 @@ +package redis + +import "testing" + +func TestSanitizeRedisPassword(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "empty password", + input: "", + expected: "", + }, + { + name: "plain password without special chars", + input: "mypassword123", + expected: "mypassword123", + }, + { + name: "password with @ not encoded", + input: "p@ssword", + expected: "p@ssword", + }, + { + name: "password with @ URL-encoded as %40", + input: "p%40ssword", + expected: "p@ssword", + }, + { + name: "password with multiple encoded chars", + input: "p%40ss%23word", + expected: "p@ss#word", + }, + { + name: "password with + encoded as %2B", + input: "p%2Bss", + expected: "p+ss", + }, + { + name: "password that is purely encoded", + input: "%40%23%24", + expected: "@#$", + }, + { + name: "password with invalid percent encoding", + input: "p%ZZssword", + expected: "p%ZZssword", + }, + { + name: "password with trailing percent", + input: "password%", + expected: "password%", + }, + { + name: "password with literal percent not encoding anything", + input: "100%safe", + expected: "100%safe", + }, + { + name: "password with space encoded as %20", + input: "my%20pass", + expected: "my pass", + }, + { + name: "complex password with mixed content", + input: "P%40ss%23w0rd!", + expected: "P@ss#w0rd!", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := sanitizeRedisPassword(tt.input) + if result != tt.expected { + t.Errorf("sanitizeRedisPassword(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go index 51ad364..15feab3 100644 --- a/internal/ssh/ssh.go +++ b/internal/ssh/ssh.go @@ -2,10 +2,13 @@ package ssh import ( "context" + "crypto/sha256" + "encoding/hex" "fmt" "io" "net" "os" + "strconv" "sync" "time" @@ -69,7 +72,7 @@ func connectSSH(config connection.SSHConfig) (*ssh.Client, error) { } } } - + if config.Password != "" { authMethods = append(authMethods, ssh.Password(config.Password)) } @@ -105,7 +108,7 @@ func RegisterSSHNetwork(sshConfig connection.SSHConfig) (string, error) { // Generate unique network name netName := fmt.Sprintf("ssh_%s_%d", sshConfig.Host, time.Now().UnixNano()) logger.Infof("注册 SSH 网络:%s(地址=%s:%d 用户=%s)", netName, sshConfig.Host, sshConfig.Port, sshConfig.User) - + mysql.RegisterDialContext(netName, func(ctx context.Context, addr string) (net.Conn, error) { return dialContext(ctx, client, "tcp", addr) }) @@ -115,12 +118,58 @@ func RegisterSSHNetwork(sshConfig connection.SSHConfig) (string, error) { // sshClientCache stores SSH clients to avoid creating multiple connections var ( - sshClientCache = make(map[string]*ssh.Client) + sshClientCache = make(map[sshClientCacheKey]*ssh.Client) sshClientCacheMu sync.RWMutex - localForwarders = make(map[string]*LocalForwarder) + localForwarders = make(map[forwarderCacheKey]*LocalForwarder) forwarderMu sync.RWMutex ) +type sshClientCacheKey struct { + host string + port int + user string + auth string +} + +type forwarderCacheKey struct { + ssh sshClientCacheKey + remoteHost string + remotePort int +} + +func sshAuthFingerprint(config connection.SSHConfig) string { + hasher := sha256.New() + _, _ = hasher.Write([]byte(config.Password)) + _, _ = hasher.Write([]byte{0}) + _, _ = hasher.Write([]byte(config.KeyPath)) + if config.KeyPath != "" { + if st, err := os.Stat(config.KeyPath); err == nil { + _, _ = hasher.Write([]byte{0}) + _, _ = hasher.Write([]byte(st.ModTime().UTC().Format(time.RFC3339Nano))) + _, _ = hasher.Write([]byte{0}) + _, _ = hasher.Write([]byte(strconv.FormatInt(st.Size(), 10))) + } else { + _, _ = hasher.Write([]byte{0}) + _, _ = hasher.Write([]byte("stat_err")) + } + } + sum := hasher.Sum(nil) + return hex.EncodeToString(sum[:8]) +} + +func newSSHClientCacheKey(config connection.SSHConfig) sshClientCacheKey { + return sshClientCacheKey{ + host: config.Host, + port: config.Port, + user: config.User, + auth: sshAuthFingerprint(config), + } +} + +func formatSSHClientKeyForLog(key sshClientCacheKey) string { + return fmt.Sprintf("%s:%d 用户=%s", key.host, key.port, key.user) +} + // LocalForwarder represents a local port forwarder through SSH type LocalForwarder struct { LocalAddr string @@ -249,9 +298,13 @@ func (f *LocalForwarder) IsClosed() bool { // GetOrCreateLocalForwarder returns a cached forwarder or creates a new one func GetOrCreateLocalForwarder(sshConfig connection.SSHConfig, remoteHost string, remotePort int) (*LocalForwarder, error) { - key := fmt.Sprintf("%s:%d:%s->%s:%d", - sshConfig.Host, sshConfig.Port, sshConfig.User, - remoteHost, remotePort) + key := forwarderCacheKey{ + ssh: newSSHClientCacheKey(sshConfig), + remoteHost: remoteHost, + remotePort: remotePort, + } + logKey := fmt.Sprintf("%s:%d:%s->%s:%d", + sshConfig.Host, sshConfig.Port, sshConfig.User, remoteHost, remotePort) forwarderMu.RLock() forwarder, exists := localForwarders[key] @@ -259,7 +312,7 @@ func GetOrCreateLocalForwarder(sshConfig connection.SSHConfig, remoteHost string // Check if exists and is still valid if exists && forwarder != nil && !forwarder.IsClosed() { - logger.Infof("复用已有端口转发:%s", key) + logger.Infof("复用已有端口转发:%s", logKey) return forwarder, nil } @@ -287,24 +340,18 @@ func CloseAllForwarders() { forwarderMu.Lock() defer forwarderMu.Unlock() - for key, forwarder := range localForwarders { + for _, forwarder := range localForwarders { if forwarder != nil { _ = forwarder.Close() - logger.Infof("已关闭端口转发:%s", key) + logger.Infof("已关闭端口转发:本地 %s -> 远程 %s", forwarder.LocalAddr, forwarder.RemoteAddr) } } - localForwarders = make(map[string]*LocalForwarder) -} - - -// getSSHClientCacheKey generates a unique cache key for SSH config -func getSSHClientCacheKey(config connection.SSHConfig) string { - return fmt.Sprintf("%s:%d:%s", config.Host, config.Port, config.User) + localForwarders = make(map[forwarderCacheKey]*LocalForwarder) } // GetOrCreateSSHClient returns a cached SSH client or creates a new one func GetOrCreateSSHClient(config connection.SSHConfig) (*ssh.Client, error) { - key := getSSHClientCacheKey(config) + key := newSSHClientCacheKey(config) sshClientCacheMu.RLock() client, exists := sshClientCache[key] @@ -315,11 +362,11 @@ func GetOrCreateSSHClient(config connection.SSHConfig) (*ssh.Client, error) { session, err := client.NewSession() if err == nil { session.Close() - logger.Infof("复用已有 SSH 连接:%s", key) + logger.Infof("复用已有 SSH 连接:%s", formatSSHClientKeyForLog(key)) return client, nil } // Connection is dead, remove from cache - logger.Warnf("SSH 连接已断开,重新建立:%s (错误: %v)", key, err) + logger.Warnf("SSH 连接已断开,重新建立:%s (错误: %v)", formatSSHClientKeyForLog(key), err) sshClientCacheMu.Lock() delete(sshClientCache, key) sshClientCacheMu.Unlock() @@ -338,7 +385,7 @@ func GetOrCreateSSHClient(config connection.SSHConfig) (*ssh.Client, error) { sshClientCache[key] = client sshClientCacheMu.Unlock() - logger.Infof("已缓存 SSH 连接:%s", key) + logger.Infof("已缓存 SSH 连接:%s", formatSSHClientKeyForLog(key)) return client, nil } @@ -367,9 +414,8 @@ func CloseAllSSHClients() { for key, client := range sshClientCache { if client != nil { _ = client.Close() - logger.Infof("已关闭 SSH 连接:%s", key) + logger.Infof("已关闭 SSH 连接:%s", formatSSHClientKeyForLog(key)) } } - sshClientCache = make(map[string]*ssh.Client) + sshClientCache = make(map[sshClientCacheKey]*ssh.Client) } - diff --git a/internal/ssh/ssh_cache_key_test.go b/internal/ssh/ssh_cache_key_test.go new file mode 100644 index 0000000..16322c2 --- /dev/null +++ b/internal/ssh/ssh_cache_key_test.go @@ -0,0 +1,46 @@ +package ssh + +import ( + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestNewSSHClientCacheKey_DiffPassword(t *testing.T) { + a := newSSHClientCacheKey(connection.SSHConfig{ + Host: "127.0.0.1", + Port: 22, + User: "root", + Password: "a", + }) + b := newSSHClientCacheKey(connection.SSHConfig{ + Host: "127.0.0.1", + Port: 22, + User: "root", + Password: "b", + }) + if a == b { + t.Fatalf("expected different cache key when password differs") + } + if a.host != b.host || a.port != b.port || a.user != b.user { + t.Fatalf("expected host/port/user to stay identical") + } +} + +func TestNewSSHClientCacheKey_DiffKeyPath(t *testing.T) { + a := newSSHClientCacheKey(connection.SSHConfig{ + Host: "127.0.0.1", + Port: 22, + User: "root", + KeyPath: "/tmp/a.key", + }) + b := newSSHClientCacheKey(connection.SSHConfig{ + Host: "127.0.0.1", + Port: 22, + User: "root", + KeyPath: "/tmp/b.key", + }) + if a == b { + t.Fatalf("expected different cache key when keyPath differs") + } +}