mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-12 12:19:47 +08:00
Compare commits
85 Commits
release/0.
...
v0.5.9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dfabd77615 | ||
|
|
76f65cb96c | ||
|
|
4a2dda8aa2 | ||
|
|
8bdc6e8086 | ||
|
|
1eb2f6dffe | ||
|
|
5c5e1fc68f | ||
|
|
fb70f1420c | ||
|
|
d75596921c | ||
|
|
d251594fd9 | ||
|
|
7598bf372b | ||
|
|
64021ffd2a | ||
|
|
fbd785400f | ||
|
|
b573fd95cc | ||
|
|
a097d96380 | ||
|
|
6ee0fea110 | ||
|
|
e6b822c967 | ||
|
|
0ab10d2e80 | ||
|
|
064cdc34be | ||
|
|
c62f4b7d3c | ||
|
|
304a4926d2 | ||
|
|
d1d3fa26f1 | ||
|
|
cabf84a041 | ||
|
|
fc8e62b997 | ||
|
|
9b02720169 | ||
|
|
eb36dcc5a2 | ||
|
|
1a3f137438 | ||
|
|
5f94cd3911 | ||
|
|
bb257c35bc | ||
|
|
b0eb93bfa3 | ||
|
|
11b8e0f12a | ||
|
|
1dabac1a65 | ||
|
|
e013288967 | ||
|
|
8c5fee1c7a | ||
|
|
ec05f518a9 | ||
|
|
2c9aa640fd | ||
|
|
d467322ebe | ||
|
|
9f7cc58fad | ||
|
|
97bf891df3 | ||
|
|
72a9692200 | ||
|
|
e26a456eae | ||
|
|
eaa45f17fd | ||
|
|
f101a59d32 | ||
|
|
501ad9e9a3 | ||
|
|
482a7fce2e | ||
|
|
e6af5f966b | ||
|
|
eef973b7fc | ||
|
|
d8b6b4ef8d | ||
|
|
4d58cc6e26 | ||
|
|
b0bdddad9b | ||
|
|
a73ca36a32 | ||
|
|
92e9381fcc | ||
|
|
c4c7e379d1 | ||
|
|
695713c779 | ||
|
|
6ad690cffc | ||
|
|
ca49b37dc7 | ||
|
|
c8c0c5f20a | ||
|
|
d61d7ec39b | ||
|
|
e964c8ecf8 | ||
|
|
7644462180 | ||
|
|
3bd02e2e09 | ||
|
|
22bd1c4c28 | ||
|
|
0daf702d25 | ||
|
|
058c74e49a | ||
|
|
b85c7529ec | ||
|
|
e521d2125f | ||
|
|
450fdfa59e | ||
|
|
c87b15b22a | ||
|
|
89c81823bc | ||
|
|
797ba27d20 | ||
|
|
ed1f40e04a | ||
|
|
2b190e564f | ||
|
|
1c050aefd0 | ||
|
|
75a5a322e0 | ||
|
|
61d6197fe3 | ||
|
|
6157161293 | ||
|
|
0f843a7dcf | ||
|
|
fb65b553e9 | ||
|
|
1a5bf79dd3 | ||
|
|
dea096d4c2 | ||
|
|
04f8b266d3 | ||
|
|
b53227cb15 | ||
|
|
0246d7fae5 | ||
|
|
4aa177ed37 | ||
|
|
4f5a7bd94b | ||
|
|
00c6f9871f |
26
.github/release.yaml
vendored
Normal file
26
.github/release.yaml
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
changelog:
|
||||
categories:
|
||||
- title: 新功能
|
||||
labels:
|
||||
- feature
|
||||
- enhancement
|
||||
- feat
|
||||
- title: 问题修复
|
||||
labels:
|
||||
- bug
|
||||
- fix
|
||||
- title: 文档与流程
|
||||
labels:
|
||||
- docs
|
||||
- documentation
|
||||
- ci
|
||||
- workflow
|
||||
- chore
|
||||
- title: 重构与优化
|
||||
labels:
|
||||
- refactor
|
||||
- perf
|
||||
- optimization
|
||||
- title: 其他更新
|
||||
labels:
|
||||
- '*'
|
||||
3
.github/workflows/release-winget.yml
vendored
3
.github/workflows/release-winget.yml
vendored
@@ -10,6 +10,9 @@ on:
|
||||
description: 'Tag of release you want to publish'
|
||||
type: string
|
||||
|
||||
env:
|
||||
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: "true"
|
||||
|
||||
jobs:
|
||||
publish:
|
||||
runs-on: windows-latest
|
||||
|
||||
85
.github/workflows/release.yml
vendored
85
.github/workflows/release.yml
vendored
@@ -8,6 +8,9 @@ on:
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
env:
|
||||
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: "true"
|
||||
|
||||
jobs:
|
||||
# Phase 1: Build in parallel and output artifacts
|
||||
build:
|
||||
@@ -88,6 +91,26 @@ jobs:
|
||||
with:
|
||||
node-version: '20'
|
||||
|
||||
- name: Install UPX (Windows)
|
||||
if: contains(matrix.platform, 'windows')
|
||||
shell: pwsh
|
||||
run: |
|
||||
$UPX_VERSION = "4.2.4"
|
||||
$url = "https://github.com/upx/upx/releases/download/v${UPX_VERSION}/upx-${UPX_VERSION}-win64.zip"
|
||||
$zipPath = "$env:RUNNER_TEMP\upx.zip"
|
||||
$extractPath = "$env:RUNNER_TEMP\upx"
|
||||
Write-Host "📥 从 GitHub Releases 下载 UPX v${UPX_VERSION} ..."
|
||||
Invoke-WebRequest -Uri $url -OutFile $zipPath -UseBasicParsing
|
||||
Expand-Archive -Path $zipPath -DestinationPath $extractPath -Force
|
||||
$upxDir = Get-ChildItem -Path $extractPath -Directory | Select-Object -First 1
|
||||
"$($upxDir.FullName)" | Out-File -FilePath $env:GITHUB_PATH -Append -Encoding utf8
|
||||
$upxCmd = Join-Path $upxDir.FullName "upx.exe"
|
||||
if (!(Test-Path $upxCmd)) {
|
||||
Write-Error "❌ 未检测到 upx,无法保证 Windows 产物经过压缩"
|
||||
exit 1
|
||||
}
|
||||
& $upxCmd --version
|
||||
|
||||
# Linux Dependencies (GTK3, WebKit2GTK required by Wails)
|
||||
- name: Install Linux Dependencies
|
||||
if: contains(matrix.platform, 'linux')
|
||||
@@ -102,6 +125,9 @@ jobs:
|
||||
sudo apt-get install -y libwebkit2gtk-4.0-dev
|
||||
fi
|
||||
|
||||
sudo apt-get install -y upx-ucl || sudo apt-get install -y upx
|
||||
upx --version
|
||||
|
||||
# AppImage 运行/打包可能需要 FUSE2。不同发行版/版本包名不同,做兼容兜底。
|
||||
sudo apt-get install -y libfuse2 || sudo apt-get install -y libfuse2t64 || true
|
||||
|
||||
@@ -277,6 +303,13 @@ jobs:
|
||||
exit 1
|
||||
fi
|
||||
APP_NAME=$(basename "$APP_PATH")
|
||||
|
||||
APP_BIN=$(find "$APP_PATH/Contents/MacOS" -maxdepth 1 -type f | head -n 1)
|
||||
if [ -z "$APP_BIN" ]; then
|
||||
echo "❌ 未找到 macOS 应用主程序!"
|
||||
exit 1
|
||||
fi
|
||||
echo "ℹ️ macOS 产物不执行 UPX 压缩,保留原始主程序。"
|
||||
|
||||
echo "🔏 正在进行 Ad-hoc 签名..."
|
||||
# 注意:Ad-hoc + hardened runtime(--options runtime)在未配置 entitlements 时,
|
||||
@@ -301,7 +334,7 @@ jobs:
|
||||
mv "$DMG_NAME" "../../$FINAL_NAME"
|
||||
|
||||
# Windows Packaging
|
||||
- name: Package Windows Portable Zip
|
||||
- name: Package Windows EXE
|
||||
if: contains(matrix.platform, 'windows')
|
||||
shell: pwsh
|
||||
run: |
|
||||
@@ -312,7 +345,6 @@ jobs:
|
||||
}
|
||||
$target = "${{ matrix.build_name }}"
|
||||
$finalExeName = "GoNavi-$version-${{ matrix.os_name }}-${{ matrix.arch_name }}${{ matrix.artifact_suffix }}.exe"
|
||||
$finalZipName = "GoNavi-$version-${{ matrix.os_name }}-${{ matrix.arch_name }}${{ matrix.artifact_suffix }}.zip"
|
||||
|
||||
if (Test-Path "$target.exe") {
|
||||
$finalExe = "$target.exe"
|
||||
@@ -324,11 +356,39 @@ jobs:
|
||||
exit 1
|
||||
}
|
||||
|
||||
Write-Host "📦 生成 Windows 可执行文件 $finalExeName..."
|
||||
Copy-Item -LiteralPath $finalExe -Destination "..\\..\\$finalExeName" -Force
|
||||
$isArm64Target = "${{ matrix.arch_name }}".ToLowerInvariant() -eq "arm64"
|
||||
if ($isArm64Target) {
|
||||
Write-Warning "⚠️ UPX 当前不支持 win64/arm64,跳过压缩并保留原始 EXE。"
|
||||
$LASTEXITCODE = 0
|
||||
} else {
|
||||
$upxCmd = Get-Command upx -ErrorAction SilentlyContinue
|
||||
if ($null -eq $upxCmd) {
|
||||
Write-Error "❌ 未找到 upx,无法保证 Windows 产物经过压缩"
|
||||
exit 1
|
||||
}
|
||||
$beforeBytes = (Get-Item -LiteralPath $finalExe).Length
|
||||
Write-Host "🗜️ 使用 UPX 压缩 $finalExe ..."
|
||||
& upx --best --lzma --force $finalExe | Out-Host
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Error "❌ UPX 压缩失败($LASTEXITCODE)"
|
||||
exit 1
|
||||
}
|
||||
& upx -t $finalExe | Out-Host
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Error "❌ UPX 校验失败($LASTEXITCODE)"
|
||||
exit 1
|
||||
}
|
||||
$afterBytes = (Get-Item -LiteralPath $finalExe).Length
|
||||
if ($afterBytes -lt $beforeBytes) {
|
||||
$savedBytes = $beforeBytes - $afterBytes
|
||||
Write-Host ("✅ UPX 压缩完成:{0:N2}MB -> {1:N2}MB,减少 {2:N2}MB" -f ($beforeBytes / 1MB), ($afterBytes / 1MB), ($savedBytes / 1MB))
|
||||
} else {
|
||||
Write-Host ("ℹ️ UPX 压缩完成:{0:N2}MB -> {1:N2}MB" -f ($beforeBytes / 1MB), ($afterBytes / 1MB))
|
||||
}
|
||||
}
|
||||
|
||||
Write-Host "📦 生成 Windows 压缩包 $finalZipName..."
|
||||
Compress-Archive -LiteralPath $finalExe -DestinationPath "..\\..\\$finalZipName" -Force
|
||||
Write-Host "📦 输出 Windows 可执行文件 $finalExeName..."
|
||||
Copy-Item -LiteralPath $finalExe -Destination "..\\..\\$finalExeName" -Force
|
||||
|
||||
# Linux Packaging (tar.gz and AppImage)
|
||||
- name: Package Linux
|
||||
@@ -347,6 +407,17 @@ jobs:
|
||||
fi
|
||||
|
||||
chmod +x "$TARGET"
|
||||
BEFORE_BYTES=$(wc -c <"$TARGET" | tr -d '[:space:]')
|
||||
echo "🗜️ 正在使用 UPX 压缩 Linux 可执行文件: $TARGET ..."
|
||||
upx --best --lzma --force "$TARGET"
|
||||
upx -t "$TARGET"
|
||||
AFTER_BYTES=$(wc -c <"$TARGET" | tr -d '[:space:]')
|
||||
if [ "$AFTER_BYTES" -lt "$BEFORE_BYTES" ]; then
|
||||
SAVED_BYTES=$((BEFORE_BYTES - AFTER_BYTES))
|
||||
awk -v b="$BEFORE_BYTES" -v a="$AFTER_BYTES" -v s="$SAVED_BYTES" 'BEGIN { printf "✅ Linux UPX 压缩完成:%.2fMB -> %.2fMB,减少 %.2fMB\n", b/1024/1024, a/1024/1024, s/1024/1024 }'
|
||||
else
|
||||
awk -v b="$BEFORE_BYTES" -v a="$AFTER_BYTES" 'BEGIN { printf "ℹ️ Linux UPX 压缩完成:%.2fMB -> %.2fMB\n", b/1024/1024, a/1024/1024 }'
|
||||
fi
|
||||
|
||||
# 1. Create tar.gz
|
||||
echo "📦 正在打包 $TAR_NAME..."
|
||||
@@ -419,7 +490,6 @@ jobs:
|
||||
path: |
|
||||
GoNavi-*.dmg
|
||||
GoNavi-*.exe
|
||||
GoNavi-*.zip
|
||||
GoNavi-*.tar.gz
|
||||
GoNavi-*.AppImage
|
||||
drivers/**
|
||||
@@ -550,5 +620,6 @@ jobs:
|
||||
files: release-assets/*
|
||||
draft: true
|
||||
make_latest: true
|
||||
generate_release_notes: true
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
412
.github/workflows/test-build-all-platforms.yml
vendored
Normal file
412
.github/workflows/test-build-all-platforms.yml
vendored
Normal file
@@ -0,0 +1,412 @@
|
||||
name: Test Build All Platforms (Manual)
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
build_label:
|
||||
description: "测试包标识(仅用于文件名)"
|
||||
required: false
|
||||
default: "test"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: "true"
|
||||
|
||||
concurrency:
|
||||
group: test-build-${{ github.ref }}
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Build ${{ matrix.platform }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- os: macos-latest
|
||||
platform: darwin/amd64
|
||||
os_name: MacOS
|
||||
arch_name: Amd64
|
||||
build_name: gonavi-test-darwin-amd64
|
||||
wails_tags: ""
|
||||
artifact_suffix: ""
|
||||
build_optional_agents: true
|
||||
linux_webkit: ""
|
||||
- os: macos-latest
|
||||
platform: darwin/arm64
|
||||
os_name: MacOS
|
||||
arch_name: Arm64
|
||||
build_name: gonavi-test-darwin-arm64
|
||||
wails_tags: ""
|
||||
artifact_suffix: ""
|
||||
build_optional_agents: true
|
||||
linux_webkit: ""
|
||||
- os: windows-latest
|
||||
platform: windows/amd64
|
||||
os_name: Windows
|
||||
arch_name: Amd64
|
||||
build_name: gonavi-test-windows-amd64
|
||||
wails_tags: ""
|
||||
artifact_suffix: ""
|
||||
build_optional_agents: true
|
||||
linux_webkit: ""
|
||||
- os: windows-latest
|
||||
platform: windows/arm64
|
||||
os_name: Windows
|
||||
arch_name: Arm64
|
||||
build_name: gonavi-test-windows-arm64
|
||||
wails_tags: ""
|
||||
artifact_suffix: ""
|
||||
build_optional_agents: true
|
||||
linux_webkit: ""
|
||||
- os: ubuntu-22.04
|
||||
platform: linux/amd64
|
||||
os_name: Linux
|
||||
arch_name: Amd64
|
||||
build_name: gonavi-test-linux-amd64
|
||||
wails_tags: ""
|
||||
artifact_suffix: ""
|
||||
build_optional_agents: true
|
||||
linux_webkit: "4.0"
|
||||
- os: ubuntu-24.04
|
||||
platform: linux/amd64
|
||||
os_name: Linux
|
||||
arch_name: Amd64
|
||||
build_name: gonavi-test-linux-amd64-webkit41
|
||||
wails_tags: "webkit2_41"
|
||||
artifact_suffix: "-WebKit41"
|
||||
build_optional_agents: false
|
||||
linux_webkit: "4.1"
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.24'
|
||||
check-latest: true
|
||||
|
||||
- name: Setup Node
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
|
||||
- name: Install UPX (Windows)
|
||||
if: contains(matrix.platform, 'windows')
|
||||
shell: pwsh
|
||||
run: |
|
||||
$UPX_VERSION = "4.2.4"
|
||||
$url = "https://github.com/upx/upx/releases/download/v${UPX_VERSION}/upx-${UPX_VERSION}-win64.zip"
|
||||
$zipPath = "$env:RUNNER_TEMP\upx.zip"
|
||||
$extractPath = "$env:RUNNER_TEMP\upx"
|
||||
Write-Host "📥 从 GitHub Releases 下载 UPX v${UPX_VERSION} ..."
|
||||
Invoke-WebRequest -Uri $url -OutFile $zipPath -UseBasicParsing
|
||||
Expand-Archive -Path $zipPath -DestinationPath $extractPath -Force
|
||||
$upxDir = Get-ChildItem -Path $extractPath -Directory | Select-Object -First 1
|
||||
"$($upxDir.FullName)" | Out-File -FilePath $env:GITHUB_PATH -Append -Encoding utf8
|
||||
$upxCmd = Join-Path $upxDir.FullName "upx.exe"
|
||||
if (!(Test-Path $upxCmd)) {
|
||||
Write-Error "❌ 未检测到 upx,无法保证 Windows 测试产物经过压缩"
|
||||
exit 1
|
||||
}
|
||||
& $upxCmd --version
|
||||
|
||||
- name: Install Linux Dependencies
|
||||
if: contains(matrix.platform, 'linux')
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libgtk-3-dev
|
||||
|
||||
if [ "${{ matrix.linux_webkit }}" = "4.1" ]; then
|
||||
sudo apt-get install -y libwebkit2gtk-4.1-dev libsoup-3.0-dev
|
||||
else
|
||||
sudo apt-get install -y libwebkit2gtk-4.0-dev
|
||||
fi
|
||||
|
||||
sudo apt-get install -y upx-ucl || sudo apt-get install -y upx
|
||||
upx --version
|
||||
|
||||
sudo apt-get install -y libfuse2 || sudo apt-get install -y libfuse2t64 || true
|
||||
|
||||
LINUXDEPLOY_URL="https://github.com/linuxdeploy/linuxdeploy/releases/download/continuous/linuxdeploy-x86_64.AppImage"
|
||||
PLUGIN_URL="https://github.com/linuxdeploy/linuxdeploy-plugin-gtk/releases/download/continuous/linuxdeploy-plugin-gtk-x86_64.AppImage"
|
||||
|
||||
wget --retry-connrefused --waitretry=1 --read-timeout=20 --timeout=15 --tries=3 -O /tmp/linuxdeploy "$LINUXDEPLOY_URL" || {
|
||||
echo "skip-appimage=true" >> "$GITHUB_ENV"
|
||||
}
|
||||
wget --retry-connrefused --waitretry=1 --read-timeout=20 --timeout=15 --tries=3 -O /tmp/linuxdeploy-plugin-gtk "$PLUGIN_URL" || {
|
||||
echo "skip-appimage=true" >> "$GITHUB_ENV"
|
||||
}
|
||||
|
||||
if [ "${skip-appimage:-false}" != "true" ]; then
|
||||
chmod +x /tmp/linuxdeploy /tmp/linuxdeploy-plugin-gtk
|
||||
fi
|
||||
|
||||
- name: Install Wails
|
||||
run: go install github.com/wailsapp/wails/v2/cmd/wails@v2.11.0
|
||||
|
||||
- name: Setup MSYS2 Toolchain For DuckDB (Windows AMD64)
|
||||
id: msys2_duckdb
|
||||
if: ${{ matrix.build_optional_agents && matrix.platform == 'windows/amd64' }}
|
||||
continue-on-error: true
|
||||
uses: msys2/setup-msys2@v2
|
||||
with:
|
||||
msystem: UCRT64
|
||||
update: true
|
||||
install: >-
|
||||
mingw-w64-ucrt-x86_64-gcc
|
||||
|
||||
- name: Configure DuckDB CGO Toolchain (Windows AMD64)
|
||||
if: ${{ matrix.build_optional_agents && matrix.platform == 'windows/amd64' }}
|
||||
shell: pwsh
|
||||
run: |
|
||||
function Find-MingwBin([string[]]$candidates) {
|
||||
foreach ($bin in $candidates) {
|
||||
if ([string]::IsNullOrWhiteSpace($bin)) {
|
||||
continue
|
||||
}
|
||||
$gcc = Join-Path $bin 'gcc.exe'
|
||||
$gxx = Join-Path $bin 'g++.exe'
|
||||
if ((Test-Path $gcc) -and (Test-Path $gxx)) {
|
||||
return $bin
|
||||
}
|
||||
}
|
||||
return $null
|
||||
}
|
||||
|
||||
$msys2Location = "${{ steps.msys2_duckdb.outputs['msys2-location'] }}"
|
||||
$candidateBins = @()
|
||||
if (-not [string]::IsNullOrWhiteSpace($msys2Location)) {
|
||||
$candidateBins += Join-Path $msys2Location 'ucrt64\bin'
|
||||
}
|
||||
$candidateBins += @(
|
||||
'C:\msys64\ucrt64\bin',
|
||||
'D:\a\_temp\msys64\ucrt64\bin'
|
||||
)
|
||||
$candidateBins = @($candidateBins | Select-Object -Unique)
|
||||
|
||||
$mingwBin = Find-MingwBin $candidateBins
|
||||
if (-not $mingwBin) {
|
||||
Write-Error "❌ 未找到可用的 DuckDB UCRT64 编译器。"
|
||||
exit 1
|
||||
}
|
||||
|
||||
$gcc = Join-Path $mingwBin 'gcc.exe'
|
||||
$gxx = Join-Path $mingwBin 'g++.exe'
|
||||
"$mingwBin" | Out-File -FilePath $env:GITHUB_PATH -Append -Encoding utf8
|
||||
"CC=$gcc" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
|
||||
"CXX=$gxx" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
|
||||
|
||||
- name: Build App
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
BUILD_LABEL="${{ inputs.build_label }}"
|
||||
if [ -z "$BUILD_LABEL" ]; then
|
||||
BUILD_LABEL="test"
|
||||
fi
|
||||
APP_VERSION="${BUILD_LABEL}-${GITHUB_RUN_NUMBER}"
|
||||
if [ -n "${{ matrix.wails_tags }}" ]; then
|
||||
wails build -platform "${{ matrix.platform }}" -clean -o "${{ matrix.build_name }}" -tags "${{ matrix.wails_tags }}" -ldflags "-s -w -X GoNavi-Wails/internal/app.AppVersion=${APP_VERSION}"
|
||||
else
|
||||
wails build -platform "${{ matrix.platform }}" -clean -o "${{ matrix.build_name }}" -ldflags "-s -w -X GoNavi-Wails/internal/app.AppVersion=${APP_VERSION}"
|
||||
fi
|
||||
|
||||
- name: Build Optional Driver Agents
|
||||
if: ${{ matrix.build_optional_agents }}
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
TARGET_PLATFORM="${{ matrix.platform }}"
|
||||
GOOS="${TARGET_PLATFORM%%/*}"
|
||||
GOARCH="${TARGET_PLATFORM##*/}"
|
||||
DRIVERS=(mariadb doris sphinx sqlserver sqlite duckdb dameng kingbase highgo vastbase mongodb tdengine clickhouse)
|
||||
OUTDIR="drivers/${{ matrix.os_name }}"
|
||||
mkdir -p "$OUTDIR"
|
||||
|
||||
for DRIVER in "${DRIVERS[@]}"; do
|
||||
BUILD_DRIVER="$DRIVER"
|
||||
if [ "$DRIVER" = "doris" ]; then
|
||||
BUILD_DRIVER="diros"
|
||||
fi
|
||||
if [ "$DRIVER" = "duckdb" ] && [ "$GOOS" = "windows" ] && [ "$GOARCH" != "amd64" ]; then
|
||||
echo "跳过 DuckDB driver: ${GOOS}/${GOARCH}"
|
||||
continue
|
||||
fi
|
||||
TAG="gonavi_${BUILD_DRIVER}_driver"
|
||||
OUTPUT="${DRIVER}-driver-agent-${GOOS}-${GOARCH}"
|
||||
if [ "$GOOS" = "windows" ]; then
|
||||
OUTPUT="${OUTPUT}.exe"
|
||||
fi
|
||||
OUTPUT_PATH="${OUTDIR}/${OUTPUT}"
|
||||
if [ "$DRIVER" = "duckdb" ]; then
|
||||
CGO_ENABLED=1 GOOS="$GOOS" GOARCH="$GOARCH" go build -tags "$TAG" -trimpath -ldflags "-s -w" -o "$OUTPUT_PATH" ./cmd/optional-driver-agent
|
||||
else
|
||||
CGO_ENABLED=0 GOOS="$GOOS" GOARCH="$GOARCH" go build -tags "$TAG" -trimpath -ldflags "-s -w" -o "$OUTPUT_PATH" ./cmd/optional-driver-agent
|
||||
fi
|
||||
done
|
||||
|
||||
- name: Package macOS
|
||||
if: contains(matrix.platform, 'darwin')
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
brew install create-dmg
|
||||
LABEL="${{ inputs.build_label }}"
|
||||
if [ -z "$LABEL" ]; then
|
||||
LABEL="test"
|
||||
fi
|
||||
cd build/bin
|
||||
APP_PATH=$(find . -maxdepth 1 -name "*.app" | head -n 1)
|
||||
if [ -z "$APP_PATH" ]; then
|
||||
echo "未找到 .app 应用包"
|
||||
exit 1
|
||||
fi
|
||||
APP_NAME=$(basename "$APP_PATH")
|
||||
APP_BIN=$(find "$APP_PATH/Contents/MacOS" -maxdepth 1 -type f | head -n 1)
|
||||
if [ -z "$APP_BIN" ]; then
|
||||
echo "未找到 macOS 应用主程序"
|
||||
exit 1
|
||||
fi
|
||||
echo "ℹ️ macOS 产物不执行 UPX 压缩,保留原始主程序。"
|
||||
codesign --force --deep --sign - "$APP_NAME"
|
||||
ZIP_NAME="GoNavi-${LABEL}-${{ matrix.os_name }}-${{ matrix.arch_name }}-run${GITHUB_RUN_NUMBER}.zip"
|
||||
DMG_NAME="GoNavi-${LABEL}-${{ matrix.os_name }}-${{ matrix.arch_name }}-run${GITHUB_RUN_NUMBER}.dmg"
|
||||
mkdir -p ../../artifacts
|
||||
ditto -c -k --sequesterRsrc --keepParent "$APP_NAME" "../../artifacts/$ZIP_NAME"
|
||||
create-dmg \
|
||||
--volname "GoNavi Test Installer" \
|
||||
--window-pos 200 120 \
|
||||
--window-size 800 400 \
|
||||
--icon-size 100 \
|
||||
--icon "$APP_NAME" 200 190 \
|
||||
--hide-extension "$APP_NAME" \
|
||||
--app-drop-link 600 185 \
|
||||
"$DMG_NAME" \
|
||||
"$APP_NAME"
|
||||
mv "$DMG_NAME" "../../artifacts/$DMG_NAME"
|
||||
shasum -a 256 "../../artifacts/$ZIP_NAME" > "../../artifacts/$ZIP_NAME.sha256"
|
||||
shasum -a 256 "../../artifacts/$DMG_NAME" > "../../artifacts/$DMG_NAME.sha256"
|
||||
|
||||
- name: Package Windows
|
||||
if: contains(matrix.platform, 'windows')
|
||||
shell: pwsh
|
||||
run: |
|
||||
$label = "${{ inputs.build_label }}"
|
||||
if ([string]::IsNullOrWhiteSpace($label)) { $label = 'test' }
|
||||
Set-Location build/bin
|
||||
$target = "${{ matrix.build_name }}"
|
||||
$finalExeName = "GoNavi-$label-${{ matrix.os_name }}-${{ matrix.arch_name }}-run$env:GITHUB_RUN_NUMBER.exe"
|
||||
if (Test-Path "$target.exe") {
|
||||
$finalExe = "$target.exe"
|
||||
} elseif (Test-Path "$target") {
|
||||
Rename-Item -Path "$target" -NewName "$target.exe"
|
||||
$finalExe = "$target.exe"
|
||||
} else {
|
||||
Write-Error "未找到构建产物 '$target'"
|
||||
exit 1
|
||||
}
|
||||
$isArm64Target = "${{ matrix.arch_name }}".ToLowerInvariant() -eq "arm64"
|
||||
if ($isArm64Target) {
|
||||
Write-Warning "⚠️ UPX 当前不支持 win64/arm64,跳过压缩并保留原始 EXE。"
|
||||
$LASTEXITCODE = 0
|
||||
} else {
|
||||
$upxCmd = Get-Command upx -ErrorAction SilentlyContinue
|
||||
if ($null -eq $upxCmd) {
|
||||
Write-Error "❌ 未找到 upx,无法保证 Windows 测试产物经过压缩"
|
||||
exit 1
|
||||
}
|
||||
$beforeBytes = (Get-Item -LiteralPath $finalExe).Length
|
||||
Write-Host "🗜️ 使用 UPX 压缩 $finalExe ..."
|
||||
& upx --best --lzma --force $finalExe | Out-Host
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Error "❌ UPX 压缩失败($LASTEXITCODE)"
|
||||
exit 1
|
||||
}
|
||||
& upx -t $finalExe | Out-Host
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Error "❌ UPX 校验失败($LASTEXITCODE)"
|
||||
exit 1
|
||||
}
|
||||
$afterBytes = (Get-Item -LiteralPath $finalExe).Length
|
||||
if ($afterBytes -lt $beforeBytes) {
|
||||
$savedBytes = $beforeBytes - $afterBytes
|
||||
Write-Host ("✅ UPX 压缩完成:{0:N2}MB -> {1:N2}MB,减少 {2:N2}MB" -f ($beforeBytes / 1MB), ($afterBytes / 1MB), ($savedBytes / 1MB))
|
||||
} else {
|
||||
Write-Host ("ℹ️ UPX 压缩完成:{0:N2}MB -> {1:N2}MB" -f ($beforeBytes / 1MB), ($afterBytes / 1MB))
|
||||
}
|
||||
}
|
||||
New-Item -ItemType Directory -Force -Path ..\..\artifacts | Out-Null
|
||||
Copy-Item -LiteralPath $finalExe -Destination "..\..\artifacts\$finalExeName" -Force
|
||||
Get-FileHash "..\..\artifacts\$finalExeName" -Algorithm SHA256 | ForEach-Object { "{0} *{1}" -f $_.Hash.ToLower(), (Split-Path $_.Path -Leaf) } | Out-File "..\..\artifacts\$finalExeName.sha256" -Encoding ascii
|
||||
|
||||
- name: Package Linux
|
||||
if: contains(matrix.platform, 'linux')
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
LABEL="${{ inputs.build_label }}"
|
||||
if [ -z "$LABEL" ]; then
|
||||
LABEL="test"
|
||||
fi
|
||||
cd build/bin
|
||||
TARGET="${{ matrix.build_name }}"
|
||||
TAR_NAME="GoNavi-${LABEL}-${{ matrix.os_name }}-${{ matrix.arch_name }}${{ matrix.artifact_suffix }}-run${GITHUB_RUN_NUMBER}.tar.gz"
|
||||
APPIMAGE_NAME="GoNavi-${LABEL}-${{ matrix.os_name }}-${{ matrix.arch_name }}${{ matrix.artifact_suffix }}-run${GITHUB_RUN_NUMBER}.AppImage"
|
||||
mkdir -p ../../artifacts
|
||||
|
||||
if [ ! -f "$TARGET" ]; then
|
||||
echo "未找到构建产物 '$TARGET'"
|
||||
exit 1
|
||||
fi
|
||||
chmod +x "$TARGET"
|
||||
BEFORE_BYTES=$(wc -c <"$TARGET" | tr -d '[:space:]')
|
||||
echo "🗜️ 使用 UPX 压缩 Linux 可执行文件: $TARGET ..."
|
||||
upx --best --lzma --force "$TARGET"
|
||||
upx -t "$TARGET"
|
||||
AFTER_BYTES=$(wc -c <"$TARGET" | tr -d '[:space:]')
|
||||
if [ "$AFTER_BYTES" -lt "$BEFORE_BYTES" ]; then
|
||||
SAVED_BYTES=$((BEFORE_BYTES - AFTER_BYTES))
|
||||
awk -v b="$BEFORE_BYTES" -v a="$AFTER_BYTES" -v s="$SAVED_BYTES" 'BEGIN { printf "✅ Linux UPX 压缩完成:%.2fMB -> %.2fMB,减少 %.2fMB\n", b/1024/1024, a/1024/1024, s/1024/1024 }'
|
||||
else
|
||||
awk -v b="$BEFORE_BYTES" -v a="$AFTER_BYTES" 'BEGIN { printf "ℹ️ Linux UPX 压缩完成:%.2fMB -> %.2fMB\n", b/1024/1024, a/1024/1024 }'
|
||||
fi
|
||||
tar -czvf "../../artifacts/$TAR_NAME" "$TARGET"
|
||||
sha256sum "../../artifacts/$TAR_NAME" > "../../artifacts/$TAR_NAME.sha256"
|
||||
|
||||
if [ "${skip-appimage:-false}" = "true" ]; then
|
||||
echo "跳过 AppImage 打包"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
mkdir -p AppDir/usr/bin AppDir/usr/share/applications AppDir/usr/share/icons/hicolor/256x256/apps
|
||||
cp "$TARGET" AppDir/usr/bin/gonavi
|
||||
printf '%s\n' '[Desktop Entry]' 'Name=GoNavi' 'Exec=gonavi' 'Icon=gonavi' 'Type=Application' 'Categories=Development;Database;' 'Comment=Database Management Tool' > AppDir/usr/share/applications/gonavi.desktop
|
||||
cp AppDir/usr/share/applications/gonavi.desktop AppDir/gonavi.desktop
|
||||
if [ -f "../../build/appicon.png" ]; then
|
||||
cp "../../build/appicon.png" AppDir/usr/share/icons/hicolor/256x256/apps/gonavi.png
|
||||
cp "../../build/appicon.png" AppDir/gonavi.png
|
||||
else
|
||||
touch AppDir/gonavi.png
|
||||
cp AppDir/gonavi.png AppDir/usr/share/icons/hicolor/256x256/apps/gonavi.png
|
||||
fi
|
||||
export DEPLOY_GTK_VERSION=3
|
||||
/tmp/linuxdeploy --appdir AppDir --plugin gtk --output appimage || exit 0
|
||||
mv GoNavi*.AppImage "$APPIMAGE_NAME" 2>/dev/null || exit 0
|
||||
mv "$APPIMAGE_NAME" "../../artifacts/$APPIMAGE_NAME"
|
||||
sha256sum "../../artifacts/$APPIMAGE_NAME" > "../../artifacts/$APPIMAGE_NAME.sha256"
|
||||
|
||||
- name: Upload Artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: test-build-${{ matrix.build_name }}-run${{ github.run_number }}
|
||||
path: |
|
||||
artifacts/*
|
||||
drivers/**
|
||||
if-no-files-found: error
|
||||
retention-days: 7
|
||||
94
.github/workflows/test-macos-build.yml
vendored
Normal file
94
.github/workflows/test-macos-build.yml
vendored
Normal file
@@ -0,0 +1,94 @@
|
||||
name: Test Build macOS (Manual)
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
build_label:
|
||||
description: "测试包标识(仅用于文件名)"
|
||||
required: false
|
||||
default: "test"
|
||||
push:
|
||||
branches:
|
||||
- feature/kingbase_opt
|
||||
paths:
|
||||
- ".github/workflows/test-macos-build.yml"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: "true"
|
||||
|
||||
jobs:
|
||||
build-macos:
|
||||
name: Build macOS ${{ matrix.arch }}
|
||||
runs-on: macos-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- platform: darwin/amd64
|
||||
arch: amd64
|
||||
- platform: darwin/arm64
|
||||
arch: arm64
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.24.3"
|
||||
check-latest: true
|
||||
|
||||
- name: Setup Node
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
|
||||
- name: Install Wails
|
||||
run: go install github.com/wailsapp/wails/v2/cmd/wails@v2.11.0
|
||||
|
||||
- name: Build App
|
||||
run: |
|
||||
set -euo pipefail
|
||||
OUTPUT_NAME="gonavi-test-${{ matrix.arch }}"
|
||||
BUILD_LABEL="${{ inputs.build_label }}"
|
||||
if [ -z "$BUILD_LABEL" ]; then
|
||||
BUILD_LABEL="test"
|
||||
fi
|
||||
APP_VERSION="${BUILD_LABEL}-${GITHUB_RUN_NUMBER}"
|
||||
wails build \
|
||||
-platform "${{ matrix.platform }}" \
|
||||
-clean \
|
||||
-o "$OUTPUT_NAME" \
|
||||
-ldflags "-s -w -X GoNavi-Wails/internal/app.AppVersion=${APP_VERSION}"
|
||||
|
||||
- name: Package Zip
|
||||
run: |
|
||||
set -euo pipefail
|
||||
APP_PATH="build/bin/gonavi-test-${{ matrix.arch }}.app"
|
||||
if [ ! -d "$APP_PATH" ]; then
|
||||
APP_PATH=$(find build/bin -maxdepth 1 -name "*.app" | head -n 1 || true)
|
||||
fi
|
||||
if [ -z "$APP_PATH" ] || [ ! -d "$APP_PATH" ]; then
|
||||
echo "未找到 .app 产物"
|
||||
ls -la build/bin || true
|
||||
exit 1
|
||||
fi
|
||||
LABEL="${{ inputs.build_label }}"
|
||||
if [ -z "$LABEL" ]; then
|
||||
LABEL="test"
|
||||
fi
|
||||
ZIP_NAME="GoNavi-${LABEL}-macos-${{ matrix.arch }}-run${GITHUB_RUN_NUMBER}.zip"
|
||||
mkdir -p artifacts
|
||||
ditto -c -k --sequesterRsrc --keepParent "$APP_PATH" "artifacts/$ZIP_NAME"
|
||||
shasum -a 256 "artifacts/$ZIP_NAME" > "artifacts/$ZIP_NAME.sha256"
|
||||
|
||||
- name: Upload Artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: gonavi-macos-${{ matrix.arch }}-run${{ github.run_number }}
|
||||
path: artifacts/*
|
||||
if-no-files-found: error
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -17,8 +17,10 @@ dist/
|
||||
GoNavi-Wails
|
||||
GoNavi-Wails.exe
|
||||
.ace-tool/
|
||||
.superpowers/
|
||||
.claude/
|
||||
tmpclaude-*
|
||||
.gemini/
|
||||
**/tmpclaude-*
|
||||
|
||||
CLAUDE.md
|
||||
**/CLAUDE.md
|
||||
|
||||
@@ -79,7 +79,8 @@ Because external pull requests are merged directly into `main`, maintainers must
|
||||
|
||||
### 1. Sync `main` -> `dev` (required)
|
||||
|
||||
Every change merged into `main` must be synced into `dev`:
|
||||
The automatic GitHub Actions sync workflow has been removed.
|
||||
Maintainers should sync `main` back to `dev` manually when needed:
|
||||
|
||||
```bash
|
||||
git checkout dev
|
||||
@@ -114,7 +115,7 @@ git push origin v0.6.0
|
||||
|
||||
### 4. Sync `main` back to `dev` after release
|
||||
|
||||
After the release, sync `main` back into `dev` again:
|
||||
After the release, the same automation still applies. If needed, you can run the workflow manually (`workflow_dispatch`) or execute the fallback commands:
|
||||
|
||||
```bash
|
||||
git checkout dev
|
||||
|
||||
@@ -79,7 +79,8 @@ feature/* / fix/* -> dev -> release/* -> main -> tag(vX.Y.Z)
|
||||
|
||||
### 1. main → dev 同步(必做)
|
||||
|
||||
任何合入 `main` 的变更,都必须同步到 `dev`:
|
||||
仓库已移除 GitHub Actions 自动回灌 workflow。
|
||||
当前统一采用手动方式将 `main` 同步回 `dev`:
|
||||
|
||||
```bash
|
||||
git checkout dev
|
||||
@@ -114,7 +115,7 @@ git push origin v0.6.0
|
||||
|
||||
### 4. main 回流到 dev(发版后必做)
|
||||
|
||||
发布完成后,再次将 `main` 回流到 `dev`,确保开发线与发布线一致:
|
||||
发布完成后,仍沿用同一套自动化流程;如有需要,也可以手动触发 `workflow_dispatch`,或执行以下兜底命令,确保开发线与发布线一致:
|
||||
|
||||
```bash
|
||||
git checkout dev
|
||||
|
||||
@@ -154,6 +154,7 @@ Artifacts are generated in `build/bin`.
|
||||
|
||||
The repository includes a release workflow.
|
||||
Push a `v*` tag to trigger automated build and release.
|
||||
Release notes are generated automatically from merged pull requests and categorized by `.github/release.yaml`.
|
||||
|
||||
Target artifacts include:
|
||||
- macOS (AMD64 / ARM64)
|
||||
|
||||
@@ -147,6 +147,7 @@ wails build -clean
|
||||
### 跨平台发布(GitHub Actions)
|
||||
|
||||
仓库内置发布流水线,推送 `v*` Tag 可自动构建并发布 Release。
|
||||
Release 更新说明会基于已合并 Pull Request 自动生成,并按 `.github/release.yaml` 分类。
|
||||
|
||||
支持目标:
|
||||
- macOS (AMD64 / ARM64)
|
||||
|
||||
228
build-driver-agents.sh
Executable file
228
build-driver-agents.sh
Executable file
@@ -0,0 +1,228 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
DEFAULT_DRIVERS=(mariadb doris sphinx sqlserver sqlite duckdb dameng kingbase highgo vastbase mongodb tdengine clickhouse)
|
||||
|
||||
usage() {
|
||||
cat <<'EOF'
|
||||
用法:
|
||||
./build-driver-agents.sh [选项]
|
||||
|
||||
选项:
|
||||
--drivers <列表> 指定驱动列表(逗号分隔),例如:kingbase,mongodb
|
||||
--platform <GOOS/GOARCH>
|
||||
目标平台,默认使用当前 Go 环境(go env GOOS/GOARCH)
|
||||
--out-dir <目录> 输出目录根路径,默认:dist/driver-agents
|
||||
--bundle-name <文件名> 驱动总包 zip 名称,默认:GoNavi-DriverAgents.zip
|
||||
--strict 任一驱动构建失败即中断(默认失败后继续,最后汇总)
|
||||
-h, --help 显示帮助
|
||||
|
||||
示例:
|
||||
./build-driver-agents.sh
|
||||
./build-driver-agents.sh --drivers kingbase
|
||||
./build-driver-agents.sh --platform windows/amd64 --drivers kingbase,mongodb
|
||||
EOF
|
||||
}
|
||||
|
||||
normalize_driver() {
|
||||
local name
|
||||
name="$(echo "${1:-}" | tr '[:upper:]' '[:lower:]' | xargs)"
|
||||
case "$name" in
|
||||
doris|diros) echo "doris" ;;
|
||||
mariadb|sphinx|sqlserver|sqlite|duckdb|dameng|kingbase|highgo|vastbase|mongodb|tdengine|clickhouse)
|
||||
echo "$name"
|
||||
;;
|
||||
*)
|
||||
return 1
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
build_driver_name() {
|
||||
case "$1" in
|
||||
doris) echo "diros" ;;
|
||||
*) echo "$1" ;;
|
||||
esac
|
||||
}
|
||||
|
||||
platform_dir_name() {
|
||||
case "$1" in
|
||||
windows) echo "Windows" ;;
|
||||
darwin) echo "MacOS" ;;
|
||||
linux) echo "Linux" ;;
|
||||
*) echo "Unknown" ;;
|
||||
esac
|
||||
}
|
||||
|
||||
driver_csv=""
|
||||
target_platform=""
|
||||
out_root="dist/driver-agents"
|
||||
bundle_name="GoNavi-DriverAgents.zip"
|
||||
strict_mode="false"
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--drivers)
|
||||
driver_csv="${2:-}"
|
||||
shift 2
|
||||
;;
|
||||
--platform)
|
||||
target_platform="${2:-}"
|
||||
shift 2
|
||||
;;
|
||||
--out-dir)
|
||||
out_root="${2:-}"
|
||||
shift 2
|
||||
;;
|
||||
--bundle-name)
|
||||
bundle_name="${2:-}"
|
||||
shift 2
|
||||
;;
|
||||
--strict)
|
||||
strict_mode="true"
|
||||
shift
|
||||
;;
|
||||
-h|--help)
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
echo "❌ 未知参数:$1"
|
||||
usage
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if ! command -v go >/dev/null 2>&1; then
|
||||
echo "❌ 未找到 Go,请先安装 Go 并确保 go 在 PATH 中。"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ -z "$target_platform" ]]; then
|
||||
target_platform="$(go env GOOS)/$(go env GOARCH)"
|
||||
fi
|
||||
|
||||
if [[ "$target_platform" != */* ]]; then
|
||||
echo "❌ --platform 参数格式错误,应为 GOOS/GOARCH,例如 darwin/arm64"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
goos="${target_platform%%/*}"
|
||||
goarch="${target_platform##*/}"
|
||||
platform_key="${goos}-${goarch}"
|
||||
platform_dir="$(platform_dir_name "$goos")"
|
||||
|
||||
declare -a drivers=()
|
||||
if [[ -n "$driver_csv" ]]; then
|
||||
IFS=',' read -r -a raw_drivers <<<"$driver_csv"
|
||||
for item in "${raw_drivers[@]}"; do
|
||||
normalized="$(normalize_driver "$item")" || {
|
||||
echo "❌ 不支持的驱动:$item"
|
||||
exit 1
|
||||
}
|
||||
drivers+=("$normalized")
|
||||
done
|
||||
else
|
||||
drivers=("${DEFAULT_DRIVERS[@]}")
|
||||
fi
|
||||
|
||||
output_dir="${out_root%/}/${platform_key}"
|
||||
bundle_stage_dir="$(mktemp -d "${TMPDIR:-/tmp}/gonavi-driver-bundle.XXXXXX")"
|
||||
bundle_platform_dir="$bundle_stage_dir/$platform_dir"
|
||||
|
||||
cleanup() {
|
||||
rm -rf "$bundle_stage_dir"
|
||||
}
|
||||
trap cleanup EXIT
|
||||
|
||||
mkdir -p "$output_dir" "$bundle_platform_dir"
|
||||
output_dir_abs="$(cd "$output_dir" && pwd)"
|
||||
bundle_zip_path="$output_dir_abs/$bundle_name"
|
||||
|
||||
declare -a built_assets=()
|
||||
declare -a failed_drivers=()
|
||||
declare -a skipped_drivers=()
|
||||
|
||||
echo "🚀 开始构建 optional-driver-agent"
|
||||
echo " 平台:$goos/$goarch"
|
||||
echo " 输出目录:$output_dir_abs"
|
||||
echo " 驱动列表:${drivers[*]}"
|
||||
|
||||
for driver in "${drivers[@]}"; do
|
||||
if [[ "$driver" == "duckdb" && "$goos" == "windows" && "$goarch" != "amd64" ]]; then
|
||||
echo "⚠️ 跳过 duckdb(仅支持 windows/amd64)"
|
||||
skipped_drivers+=("$driver")
|
||||
continue
|
||||
fi
|
||||
|
||||
build_driver="$(build_driver_name "$driver")"
|
||||
tag="gonavi_${build_driver}_driver"
|
||||
asset_name="${driver}-driver-agent-${goos}-${goarch}"
|
||||
if [[ "$goos" == "windows" ]]; then
|
||||
asset_name="${asset_name}.exe"
|
||||
fi
|
||||
output_path="$output_dir_abs/$asset_name"
|
||||
|
||||
cgo_enabled=0
|
||||
if [[ "$driver" == "duckdb" ]]; then
|
||||
cgo_enabled=1
|
||||
fi
|
||||
|
||||
echo "🔧 构建 $driver -> $asset_name (tag=$tag, CGO_ENABLED=$cgo_enabled)"
|
||||
set +e
|
||||
CGO_ENABLED="$cgo_enabled" GOOS="$goos" GOARCH="$goarch" GOTOOLCHAIN=auto \
|
||||
go build -tags "$tag" -trimpath -ldflags "-s -w" -o "$output_path" ./cmd/optional-driver-agent
|
||||
build_exit=$?
|
||||
set -e
|
||||
|
||||
if [[ $build_exit -ne 0 ]]; then
|
||||
echo "❌ 构建失败:$driver"
|
||||
failed_drivers+=("$driver")
|
||||
if [[ "$strict_mode" == "true" ]]; then
|
||||
exit $build_exit
|
||||
fi
|
||||
continue
|
||||
fi
|
||||
|
||||
cp "$output_path" "$bundle_platform_dir/$asset_name"
|
||||
built_assets+=("$asset_name")
|
||||
done
|
||||
|
||||
if [[ ${#built_assets[@]} -eq 0 ]]; then
|
||||
echo "❌ 未成功构建任何驱动代理。"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
rm -f "$bundle_zip_path"
|
||||
if command -v zip >/dev/null 2>&1; then
|
||||
(
|
||||
cd "$bundle_stage_dir"
|
||||
zip -qry "$bundle_zip_path" "$platform_dir"
|
||||
)
|
||||
elif command -v ditto >/dev/null 2>&1; then
|
||||
(
|
||||
cd "$bundle_stage_dir"
|
||||
ditto -c -k --sequesterRsrc --keepParent "$platform_dir" "$bundle_zip_path"
|
||||
)
|
||||
else
|
||||
echo "❌ 未找到 zip/ditto,无法生成驱动总包 zip。"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "✅ 构建完成"
|
||||
echo " 单文件输出目录:$output_dir_abs"
|
||||
echo " 驱动总包:$bundle_zip_path"
|
||||
echo " 已构建:${built_assets[*]}"
|
||||
if [[ ${#skipped_drivers[@]} -gt 0 ]]; then
|
||||
echo " 已跳过:${skipped_drivers[*]}"
|
||||
fi
|
||||
if [[ ${#failed_drivers[@]} -gt 0 ]]; then
|
||||
echo "⚠️ 构建失败驱动:${failed_drivers[*]}"
|
||||
exit 2
|
||||
fi
|
||||
351
build-release.sh
351
build-release.sh
@@ -20,6 +20,75 @@ RED='\033[0;31m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m'
|
||||
|
||||
get_file_size_bytes() {
|
||||
local target="$1"
|
||||
if [ ! -f "$target" ]; then
|
||||
echo 0
|
||||
return
|
||||
fi
|
||||
if stat -f%z "$target" >/dev/null 2>&1; then
|
||||
stat -f%z "$target"
|
||||
return
|
||||
fi
|
||||
if stat -c%s "$target" >/dev/null 2>&1; then
|
||||
stat -c%s "$target"
|
||||
return
|
||||
fi
|
||||
wc -c <"$target" | tr -d '[:space:]'
|
||||
}
|
||||
|
||||
format_size_mb() {
|
||||
local bytes="${1:-0}"
|
||||
awk -v b="$bytes" 'BEGIN { printf "%.2fMB", b / 1024 / 1024 }'
|
||||
}
|
||||
|
||||
try_compress_binary_with_upx() {
|
||||
local exe_path="$1"
|
||||
local label="$2"
|
||||
if [ ! -f "$exe_path" ]; then
|
||||
echo -e "${RED} ❌ 未找到 ${label} 文件:$exe_path${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! command -v upx >/dev/null 2>&1; then
|
||||
echo -e "${RED} ❌ 未找到 upx,${label} 必须进行压缩后才能继续打包。${NC}"
|
||||
case "$(uname -s)" in
|
||||
Darwin)
|
||||
echo " 安装命令: brew install upx"
|
||||
;;
|
||||
Linux)
|
||||
echo " 安装命令: sudo apt-get install -y upx-ucl (或对应发行版包管理器)"
|
||||
;;
|
||||
esac
|
||||
exit 1
|
||||
fi
|
||||
|
||||
local before_bytes after_bytes
|
||||
before_bytes=$(get_file_size_bytes "$exe_path")
|
||||
echo " 🗜️ 正在使用 UPX 压缩 ${label}..."
|
||||
if upx --best --lzma --force "$exe_path" >/dev/null 2>&1; then
|
||||
if ! upx -t "$exe_path" >/dev/null 2>&1; then
|
||||
echo -e "${RED} ❌ UPX 校验失败:${label}${NC}"
|
||||
exit 1
|
||||
fi
|
||||
after_bytes=$(get_file_size_bytes "$exe_path")
|
||||
if [ "$after_bytes" -lt "$before_bytes" ]; then
|
||||
local saved_bytes=$((before_bytes - after_bytes))
|
||||
echo " ✅ UPX 压缩完成: $(format_size_mb "$before_bytes") -> $(format_size_mb "$after_bytes"),减少 $(format_size_mb "$saved_bytes")"
|
||||
else
|
||||
echo " ℹ️ UPX 压缩完成: $(format_size_mb "$before_bytes") -> $(format_size_mb "$after_bytes")"
|
||||
fi
|
||||
else
|
||||
echo -e "${RED} ❌ UPX 压缩失败:${label}${NC}"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
MAC_VOLICON_PATH="build/darwin/icon.icns"
|
||||
if [ ! -f "$MAC_VOLICON_PATH" ]; then
|
||||
MAC_VOLICON_PATH=""
|
||||
fi
|
||||
|
||||
echo -e "${GREEN}🚀 开始构建 $APP_NAME $VERSION...${NC}"
|
||||
|
||||
# 清理并创建输出目录
|
||||
@@ -36,47 +105,101 @@ if [ $? -eq 0 ]; then
|
||||
|
||||
# 移动 .app 到 dist
|
||||
mv "$APP_SRC" "$DIST_DIR/$APP_DEST_NAME"
|
||||
|
||||
APP_BIN_PATH=$(find "$DIST_DIR/$APP_DEST_NAME/Contents/MacOS" -maxdepth 1 -type f -print -quit)
|
||||
if [ -n "$APP_BIN_PATH" ] && [ -f "$APP_BIN_PATH" ]; then
|
||||
echo -e "${YELLOW} ⚠️ macOS arm64 不再执行 UPX 压缩,保留原始主程序。${NC}"
|
||||
else
|
||||
echo -e "${RED} ❌ 未找到 macOS arm64 主程序文件。${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 创建 DMG
|
||||
if command -v create-dmg &> /dev/null; then
|
||||
echo " 📦 正在打包 DMG (arm64)..."
|
||||
# 移除已存在的 DMG (以防万一)
|
||||
rm -f "$DIST_DIR/$DMG_NAME"
|
||||
|
||||
create-dmg \
|
||||
--volname "${APP_NAME} ${VERSION}" \
|
||||
--volicon "build/appicon.icns" \
|
||||
--window-pos 200 120 \
|
||||
--window-size 800 400 \
|
||||
--icon-size 100 \
|
||||
--icon "$APP_DEST_NAME" 200 190 \
|
||||
--hide-extension "$APP_DEST_NAME" \
|
||||
--app-drop-link 600 185 \
|
||||
"$DIST_DIR/$DMG_NAME" \
|
||||
"$DIST_DIR/$APP_DEST_NAME"
|
||||
|
||||
# 检查是否生成了 rw.* 的临时文件并重命名 (create-dmg 有时会有此行为)
|
||||
if [ ! -f "$DIST_DIR/$DMG_NAME" ]; then
|
||||
RW_FILE=$(find "$DIST_DIR" -name "rw.*.dmg" -print -quit)
|
||||
if [ -n "$RW_FILE" ]; then
|
||||
echo -e "${YELLOW} ⚠️ 检测到临时文件名,正在重命名...${NC}"
|
||||
mv "$RW_FILE" "$DIST_DIR/$DMG_NAME"
|
||||
fi
|
||||
# Ad-hoc 代码签名(无 Apple Developer 账号时防止 Gatekeeper 报已损坏)
|
||||
echo " 🔏 正在对 .app 进行 ad-hoc 签名 (arm64)..."
|
||||
codesign --force --deep --sign - "$DIST_DIR/$APP_DEST_NAME"
|
||||
|
||||
# 创建 DMG
|
||||
if command -v create-dmg &> /dev/null; then
|
||||
echo " 📦 正在打包 DMG (arm64)..."
|
||||
# 移除已存在的 DMG (以防万一)
|
||||
rm -f "$DIST_DIR/$DMG_NAME"
|
||||
# create-dmg 的 source 需要是“包含 .app 的目录”,不能直接传 .app 路径。
|
||||
STAGE_DIR=$(mktemp -d "$DIST_DIR/.dmg-stage-${APP_NAME}-${VERSION}-arm64.XXXXXX")
|
||||
if [ -z "$STAGE_DIR" ] || [ ! -d "$STAGE_DIR" ]; then
|
||||
echo -e "${RED} ❌ 创建 DMG 临时目录失败,跳过 DMG 打包。${NC}"
|
||||
else
|
||||
if command -v ditto &> /dev/null; then
|
||||
ditto "$DIST_DIR/$APP_DEST_NAME" "$STAGE_DIR/$APP_DEST_NAME"
|
||||
else
|
||||
cp -R "$DIST_DIR/$APP_DEST_NAME" "$STAGE_DIR/$APP_DEST_NAME"
|
||||
fi
|
||||
|
||||
# --sandbox-safe 会跳过 Finder 的 AppleScript 排版,避免打包过程中弹出/打开挂载窗口(CI/本地静默打包更友好)。
|
||||
CREATE_DMG_ARGS=(--volname "${APP_NAME} ${VERSION}" --format UDZO --sandbox-safe)
|
||||
if [ -n "$MAC_VOLICON_PATH" ]; then
|
||||
CREATE_DMG_ARGS+=(--volicon "$MAC_VOLICON_PATH")
|
||||
else
|
||||
echo -e "${YELLOW} ⚠️ 未找到 macOS 卷图标 (build/darwin/icon.icns),跳过 --volicon。${NC}"
|
||||
fi
|
||||
|
||||
# 删除中间的 .app 文件,保持目录整洁
|
||||
rm -rf "$DIST_DIR/$APP_DEST_NAME"
|
||||
|
||||
if [ -f "$DIST_DIR/$DMG_NAME" ]; then
|
||||
echo " ✅ 已生成 $DMG_NAME"
|
||||
else
|
||||
echo -e "${RED} ❌ DMG 生成失败,请检查 create-dmg 输出。${NC}"
|
||||
create-dmg "${CREATE_DMG_ARGS[@]}" \
|
||||
--window-pos 200 120 \
|
||||
--window-size 800 400 \
|
||||
--icon-size 100 \
|
||||
--icon "$APP_DEST_NAME" 200 190 \
|
||||
--hide-extension "$APP_DEST_NAME" \
|
||||
--app-drop-link 600 185 \
|
||||
"$DIST_DIR/$DMG_NAME" \
|
||||
"$STAGE_DIR"
|
||||
|
||||
CREATE_DMG_EXIT_CODE=$?
|
||||
rm -rf "$STAGE_DIR"
|
||||
|
||||
if [ $CREATE_DMG_EXIT_CODE -ne 0 ]; then
|
||||
echo -e "${RED} ❌ create-dmg 执行失败 (exit=$CREATE_DMG_EXIT_CODE),保留 .app 以便排查。${NC}"
|
||||
else
|
||||
# create-dmg 可能会在失败时遗留 rw.*.dmg 中间产物;不要直接当作最终产物使用
|
||||
if [ ! -f "$DIST_DIR/$DMG_NAME" ]; then
|
||||
RW_FILE=$(find "$DIST_DIR" -maxdepth 1 -name "rw.*.dmg" -print -quit)
|
||||
if [ -n "$RW_FILE" ]; then
|
||||
echo -e "${YELLOW} ⚠️ 检测到 create-dmg 中间产物: $(basename "$RW_FILE"),正在转换为可分发 DMG...${NC}"
|
||||
hdiutil convert "$RW_FILE" -format UDZO -o "$DIST_DIR/$DMG_NAME" >/dev/null 2>&1
|
||||
rm -f "$RW_FILE"
|
||||
fi
|
||||
fi
|
||||
|
||||
# 防御性:即使生成了目标文件,也要确保不是 UDRW(UDRW 在 Finder 下可能表现为“已损坏/无法打开”)
|
||||
if [ -f "$DIST_DIR/$DMG_NAME" ] && command -v hdiutil &> /dev/null; then
|
||||
DMG_FORMAT=$(hdiutil imageinfo "$DIST_DIR/$DMG_NAME" 2>/dev/null | awk -F': ' '/^Format:/{print $2; exit}')
|
||||
if [ "$DMG_FORMAT" = "UDRW" ]; then
|
||||
echo -e "${YELLOW} ⚠️ 检测到 UDRW(可写原始映像),正在转换为 UDZO...${NC}"
|
||||
TMP_UDZO="$DIST_DIR/.tmp.$DMG_NAME"
|
||||
rm -f "$TMP_UDZO"
|
||||
hdiutil convert "$DIST_DIR/$DMG_NAME" -format UDZO -o "$TMP_UDZO" >/dev/null 2>&1 && mv "$TMP_UDZO" "$DIST_DIR/$DMG_NAME"
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ -f "$DIST_DIR/$DMG_NAME" ] && command -v hdiutil &> /dev/null; then
|
||||
hdiutil verify "$DIST_DIR/$DMG_NAME" >/dev/null 2>&1
|
||||
if [ $? -ne 0 ]; then
|
||||
echo -e "${RED} ❌ DMG 校验失败,保留 .app 以便排查。${NC}"
|
||||
else
|
||||
# 删除中间的 .app 文件,保持目录整洁
|
||||
rm -rf "$DIST_DIR/$APP_DEST_NAME"
|
||||
echo " ✅ 已生成 $DMG_NAME"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
else
|
||||
echo -e "${YELLOW} ⚠️ 未找到 create-dmg 工具,跳过 DMG 打包,仅保留 .app。${NC}"
|
||||
echo " 安装命令: brew install create-dmg"
|
||||
fi
|
||||
else
|
||||
|
||||
if [ ! -f "$DIST_DIR/$DMG_NAME" ]; then
|
||||
echo -e "${RED} ❌ DMG 生成失败,请检查 create-dmg 输出。${NC}"
|
||||
fi
|
||||
fi
|
||||
else
|
||||
echo -e "${YELLOW} ⚠️ 未找到 create-dmg 工具,跳过 DMG 打包,仅保留 .app。${NC}"
|
||||
echo " 安装命令: brew install create-dmg"
|
||||
fi
|
||||
else
|
||||
echo -e "${RED} ❌ macOS arm64 构建失败。${NC}"
|
||||
fi
|
||||
|
||||
@@ -89,44 +212,96 @@ if [ $? -eq 0 ]; then
|
||||
DMG_NAME="${APP_NAME}-${VERSION}-mac-amd64.dmg"
|
||||
|
||||
mv "$APP_SRC" "$DIST_DIR/$APP_DEST_NAME"
|
||||
|
||||
if command -v create-dmg &> /dev/null; then
|
||||
echo " 📦 正在打包 DMG (amd64)..."
|
||||
rm -f "$DIST_DIR/$DMG_NAME"
|
||||
|
||||
create-dmg \
|
||||
--volname "${APP_NAME} ${VERSION}" \
|
||||
--volicon "build/appicon.icns" \
|
||||
--window-pos 200 120 \
|
||||
--window-size 800 400 \
|
||||
--icon-size 100 \
|
||||
--icon "$APP_DEST_NAME" 200 190 \
|
||||
--hide-extension "$APP_DEST_NAME" \
|
||||
--app-drop-link 600 185 \
|
||||
"$DIST_DIR/$DMG_NAME" \
|
||||
"$DIST_DIR/$APP_DEST_NAME"
|
||||
|
||||
# 检查是否生成了 rw.* 的临时文件并重命名
|
||||
if [ ! -f "$DIST_DIR/$DMG_NAME" ]; then
|
||||
RW_FILE=$(find "$DIST_DIR" -name "rw.*.dmg" -print -quit)
|
||||
if [ -n "$RW_FILE" ]; then
|
||||
echo -e "${YELLOW} ⚠️ 检测到临时文件名,正在重命名...${NC}"
|
||||
mv "$RW_FILE" "$DIST_DIR/$DMG_NAME"
|
||||
fi
|
||||
fi
|
||||
|
||||
rm -rf "$DIST_DIR/$APP_DEST_NAME"
|
||||
|
||||
if [ -f "$DIST_DIR/$DMG_NAME" ]; then
|
||||
echo " ✅ 已生成 $DMG_NAME"
|
||||
else
|
||||
echo -e "${RED} ❌ DMG 生成失败。${NC}"
|
||||
fi
|
||||
APP_BIN_PATH=$(find "$DIST_DIR/$APP_DEST_NAME/Contents/MacOS" -maxdepth 1 -type f -print -quit)
|
||||
if [ -n "$APP_BIN_PATH" ] && [ -f "$APP_BIN_PATH" ]; then
|
||||
echo -e "${YELLOW} ⚠️ macOS amd64 不再执行 UPX 压缩,保留原始主程序。${NC}"
|
||||
else
|
||||
echo -e "${YELLOW} ⚠️ 未找到 create-dmg 工具。${NC}"
|
||||
echo -e "${RED} ❌ 未找到 macOS amd64 主程序文件。${NC}"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo -e "${RED} ❌ macOS amd64 构建失败。${NC}"
|
||||
|
||||
# Ad-hoc 代码签名
|
||||
echo " 🔏 正在对 .app 进行 ad-hoc 签名 (amd64)..."
|
||||
codesign --force --deep --sign - "$DIST_DIR/$APP_DEST_NAME"
|
||||
|
||||
if command -v create-dmg &> /dev/null; then
|
||||
echo " 📦 正在打包 DMG (amd64)..."
|
||||
rm -f "$DIST_DIR/$DMG_NAME"
|
||||
# create-dmg 的 source 需要是“包含 .app 的目录”,不能直接传 .app 路径。
|
||||
STAGE_DIR=$(mktemp -d "$DIST_DIR/.dmg-stage-${APP_NAME}-${VERSION}-amd64.XXXXXX")
|
||||
if [ -z "$STAGE_DIR" ] || [ ! -d "$STAGE_DIR" ]; then
|
||||
echo -e "${RED} ❌ 创建 DMG 临时目录失败,跳过 DMG 打包。${NC}"
|
||||
else
|
||||
if command -v ditto &> /dev/null; then
|
||||
ditto "$DIST_DIR/$APP_DEST_NAME" "$STAGE_DIR/$APP_DEST_NAME"
|
||||
else
|
||||
cp -R "$DIST_DIR/$APP_DEST_NAME" "$STAGE_DIR/$APP_DEST_NAME"
|
||||
fi
|
||||
|
||||
# --sandbox-safe 会跳过 Finder 的 AppleScript 排版,避免打包过程中弹出/打开挂载窗口(CI/本地静默打包更友好)。
|
||||
CREATE_DMG_ARGS=(--volname "${APP_NAME} ${VERSION}" --format UDZO --sandbox-safe)
|
||||
if [ -n "$MAC_VOLICON_PATH" ]; then
|
||||
CREATE_DMG_ARGS+=(--volicon "$MAC_VOLICON_PATH")
|
||||
else
|
||||
echo -e "${YELLOW} ⚠️ 未找到 macOS 卷图标 (build/darwin/icon.icns),跳过 --volicon。${NC}"
|
||||
fi
|
||||
|
||||
create-dmg "${CREATE_DMG_ARGS[@]}" \
|
||||
--window-pos 200 120 \
|
||||
--window-size 800 400 \
|
||||
--icon-size 100 \
|
||||
--icon "$APP_DEST_NAME" 200 190 \
|
||||
--hide-extension "$APP_DEST_NAME" \
|
||||
--app-drop-link 600 185 \
|
||||
"$DIST_DIR/$DMG_NAME" \
|
||||
"$STAGE_DIR"
|
||||
|
||||
CREATE_DMG_EXIT_CODE=$?
|
||||
rm -rf "$STAGE_DIR"
|
||||
|
||||
if [ $CREATE_DMG_EXIT_CODE -ne 0 ]; then
|
||||
echo -e "${RED} ❌ create-dmg 执行失败 (exit=$CREATE_DMG_EXIT_CODE),保留 .app 以便排查。${NC}"
|
||||
else
|
||||
if [ ! -f "$DIST_DIR/$DMG_NAME" ]; then
|
||||
RW_FILE=$(find "$DIST_DIR" -maxdepth 1 -name "rw.*.dmg" -print -quit)
|
||||
if [ -n "$RW_FILE" ]; then
|
||||
echo -e "${YELLOW} ⚠️ 检测到 create-dmg 中间产物: $(basename "$RW_FILE"),正在转换为可分发 DMG...${NC}"
|
||||
hdiutil convert "$RW_FILE" -format UDZO -o "$DIST_DIR/$DMG_NAME" >/dev/null 2>&1
|
||||
rm -f "$RW_FILE"
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ -f "$DIST_DIR/$DMG_NAME" ] && command -v hdiutil &> /dev/null; then
|
||||
DMG_FORMAT=$(hdiutil imageinfo "$DIST_DIR/$DMG_NAME" 2>/dev/null | awk -F': ' '/^Format:/{print $2; exit}')
|
||||
if [ "$DMG_FORMAT" = "UDRW" ]; then
|
||||
echo -e "${YELLOW} ⚠️ 检测到 UDRW(可写原始映像),正在转换为 UDZO...${NC}"
|
||||
TMP_UDZO="$DIST_DIR/.tmp.$DMG_NAME"
|
||||
rm -f "$TMP_UDZO"
|
||||
hdiutil convert "$DIST_DIR/$DMG_NAME" -format UDZO -o "$TMP_UDZO" >/dev/null 2>&1 && mv "$TMP_UDZO" "$DIST_DIR/$DMG_NAME"
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ -f "$DIST_DIR/$DMG_NAME" ] && command -v hdiutil &> /dev/null; then
|
||||
hdiutil verify "$DIST_DIR/$DMG_NAME" >/dev/null 2>&1
|
||||
if [ $? -ne 0 ]; then
|
||||
echo -e "${RED} ❌ DMG 校验失败,保留 .app 以便排查。${NC}"
|
||||
else
|
||||
rm -rf "$DIST_DIR/$APP_DEST_NAME"
|
||||
echo " ✅ 已生成 $DMG_NAME"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ ! -f "$DIST_DIR/$DMG_NAME" ]; then
|
||||
echo -e "${RED} ❌ DMG 生成失败。${NC}"
|
||||
fi
|
||||
fi
|
||||
else
|
||||
echo -e "${YELLOW} ⚠️ 未找到 create-dmg 工具。${NC}"
|
||||
fi
|
||||
else
|
||||
echo -e "${RED} ❌ macOS amd64 构建失败。${NC}"
|
||||
fi
|
||||
|
||||
# --- Windows AMD64 构建 ---
|
||||
@@ -134,7 +309,9 @@ echo -e "${GREEN}🪟 正在构建 Windows (amd64)...${NC}"
|
||||
if command -v x86_64-w64-mingw32-gcc &> /dev/null; then
|
||||
wails build -platform windows/amd64 -clean -ldflags "$LDFLAGS"
|
||||
if [ $? -eq 0 ]; then
|
||||
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}.exe" "$DIST_DIR/${APP_NAME}-${VERSION}-windows-amd64.exe"
|
||||
TARGET_EXE="$DIST_DIR/${APP_NAME}-${VERSION}-windows-amd64.exe"
|
||||
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}.exe" "$TARGET_EXE"
|
||||
try_compress_binary_with_upx "$TARGET_EXE" "Windows amd64 可执行文件"
|
||||
echo " ✅ 已生成 ${APP_NAME}-${VERSION}-windows-amd64.exe"
|
||||
else
|
||||
echo -e "${RED} ❌ Windows amd64 构建失败。${NC}"
|
||||
@@ -148,7 +325,9 @@ echo -e "${GREEN}🪟 正在构建 Windows (arm64)...${NC}"
|
||||
if command -v aarch64-w64-mingw32-gcc &> /dev/null; then
|
||||
wails build -platform windows/arm64 -clean -ldflags "$LDFLAGS"
|
||||
if [ $? -eq 0 ]; then
|
||||
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}.exe" "$DIST_DIR/${APP_NAME}-${VERSION}-windows-arm64.exe"
|
||||
TARGET_EXE="$DIST_DIR/${APP_NAME}-${VERSION}-windows-arm64.exe"
|
||||
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}.exe" "$TARGET_EXE"
|
||||
echo -e "${YELLOW} ⚠️ 当前 UPX 不支持 win64/arm64,跳过 Windows arm64 压缩。${NC}"
|
||||
echo " ✅ 已生成 ${APP_NAME}-${VERSION}-windows-arm64.exe"
|
||||
else
|
||||
echo -e "${RED} ❌ Windows arm64 构建失败。${NC}"
|
||||
@@ -168,8 +347,10 @@ if [ "$CURRENT_OS" = "Linux" ] && [ "$CURRENT_ARCH" = "x86_64" ]; then
|
||||
# 本机 Linux amd64,直接构建
|
||||
wails build -platform linux/amd64 -clean -ldflags "$LDFLAGS"
|
||||
if [ $? -eq 0 ]; then
|
||||
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64"
|
||||
chmod +x "$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64"
|
||||
TARGET_LINUX_BIN="$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64"
|
||||
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$TARGET_LINUX_BIN"
|
||||
chmod +x "$TARGET_LINUX_BIN"
|
||||
try_compress_binary_with_upx "$TARGET_LINUX_BIN" "Linux amd64 可执行文件"
|
||||
# 打包为 tar.gz
|
||||
cd "$DIST_DIR"
|
||||
tar -czvf "${APP_NAME}-${VERSION}-linux-amd64.tar.gz" "${APP_NAME}-${VERSION}-linux-amd64"
|
||||
@@ -186,8 +367,10 @@ elif command -v x86_64-linux-gnu-gcc &> /dev/null; then
|
||||
export CGO_ENABLED=1
|
||||
wails build -platform linux/amd64 -clean -ldflags "$LDFLAGS"
|
||||
if [ $? -eq 0 ]; then
|
||||
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64"
|
||||
chmod +x "$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64"
|
||||
TARGET_LINUX_BIN="$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64"
|
||||
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$TARGET_LINUX_BIN"
|
||||
chmod +x "$TARGET_LINUX_BIN"
|
||||
try_compress_binary_with_upx "$TARGET_LINUX_BIN" "Linux amd64 可执行文件"
|
||||
cd "$DIST_DIR"
|
||||
tar -czvf "${APP_NAME}-${VERSION}-linux-amd64.tar.gz" "${APP_NAME}-${VERSION}-linux-amd64"
|
||||
rm "${APP_NAME}-${VERSION}-linux-amd64"
|
||||
@@ -208,8 +391,10 @@ if [ "$CURRENT_OS" = "Linux" ] && [ "$CURRENT_ARCH" = "aarch64" ]; then
|
||||
# 本机 Linux arm64,直接构建
|
||||
wails build -platform linux/arm64 -clean -ldflags "$LDFLAGS"
|
||||
if [ $? -eq 0 ]; then
|
||||
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64"
|
||||
chmod +x "$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64"
|
||||
TARGET_LINUX_BIN="$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64"
|
||||
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$TARGET_LINUX_BIN"
|
||||
chmod +x "$TARGET_LINUX_BIN"
|
||||
try_compress_binary_with_upx "$TARGET_LINUX_BIN" "Linux arm64 可执行文件"
|
||||
cd "$DIST_DIR"
|
||||
tar -czvf "${APP_NAME}-${VERSION}-linux-arm64.tar.gz" "${APP_NAME}-${VERSION}-linux-arm64"
|
||||
rm "${APP_NAME}-${VERSION}-linux-arm64"
|
||||
@@ -225,8 +410,10 @@ elif command -v aarch64-linux-gnu-gcc &> /dev/null; then
|
||||
export CGO_ENABLED=1
|
||||
wails build -platform linux/arm64 -clean -ldflags "$LDFLAGS"
|
||||
if [ $? -eq 0 ]; then
|
||||
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64"
|
||||
chmod +x "$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64"
|
||||
TARGET_LINUX_BIN="$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64"
|
||||
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$TARGET_LINUX_BIN"
|
||||
chmod +x "$TARGET_LINUX_BIN"
|
||||
try_compress_binary_with_upx "$TARGET_LINUX_BIN" "Linux arm64 可执行文件"
|
||||
cd "$DIST_DIR"
|
||||
tar -czvf "${APP_NAME}-${VERSION}-linux-arm64.tar.gz" "${APP_NAME}-${VERSION}-linux-arm64"
|
||||
rm "${APP_NAME}-${VERSION}-linux-arm64"
|
||||
|
||||
@@ -5,6 +5,23 @@
|
||||
<link rel="icon" type="image/svg+xml" href="/logo.svg" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>GoNavi</title>
|
||||
<script>
|
||||
if (typeof window !== 'undefined' && !window.go) {
|
||||
window.go = {
|
||||
app: {
|
||||
App: new Proxy({}, { get: () => async () => ({ success: false }) })
|
||||
}
|
||||
};
|
||||
}
|
||||
if (typeof window !== 'undefined' && !window.runtime) {
|
||||
window.runtime = new Proxy({}, {
|
||||
get: (target, prop) => {
|
||||
if (prop === 'Environment') return async () => ({ platform: 'darwin' });
|
||||
return typeof prop === 'string' && prop.startsWith('WindowIs') ? () => false : () => {};
|
||||
}
|
||||
});
|
||||
}
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
|
||||
@@ -37,6 +37,91 @@ body, #root {
|
||||
padding-right: 8px;
|
||||
}
|
||||
|
||||
.redis-viewer-workbench .ant-tree {
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
.redis-viewer-workbench .ant-tree .ant-tree-list-holder-inner,
|
||||
.redis-viewer-workbench .ant-tree .ant-tree-list-holder-inner .ant-tree-treenode {
|
||||
width: 100% !important;
|
||||
}
|
||||
|
||||
.redis-viewer-workbench .ant-tree .ant-tree-node-content-wrapper {
|
||||
min-height: 36px;
|
||||
border-radius: 14px;
|
||||
transition: background-color 0.2s ease, border-color 0.2s ease, color 0.2s ease;
|
||||
background: transparent !important;
|
||||
border: none !important;
|
||||
box-shadow: none !important;
|
||||
outline: none !important;
|
||||
flex: 1 1 auto;
|
||||
min-width: 0;
|
||||
width: auto !important;
|
||||
}
|
||||
|
||||
.redis-viewer-workbench .ant-tree .ant-tree-node-content-wrapper:hover,
|
||||
.redis-viewer-workbench .ant-tree .ant-tree-node-content-wrapper:active,
|
||||
.redis-viewer-workbench .ant-tree .ant-tree-node-content-wrapper:focus,
|
||||
.redis-viewer-workbench .ant-tree .ant-tree-node-content-wrapper:focus-visible,
|
||||
.redis-viewer-workbench .ant-tree .ant-tree-node-content-wrapper.ant-tree-node-selected,
|
||||
.redis-viewer-workbench .ant-tree .ant-tree-node-content-wrapper.ant-tree-node-selected:hover {
|
||||
background: transparent !important;
|
||||
border-color: transparent !important;
|
||||
box-shadow: none !important;
|
||||
outline: none !important;
|
||||
}
|
||||
|
||||
.redis-viewer-workbench .ant-tree .ant-tree-treenode {
|
||||
padding: 2px 0;
|
||||
width: 100%;
|
||||
border-radius: 14px;
|
||||
transition: background-color 0.2s ease, border-color 0.2s ease, color 0.2s ease;
|
||||
border: none;
|
||||
align-items: center;
|
||||
position: relative;
|
||||
z-index: 0;
|
||||
display: flex !important;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.redis-viewer-workbench .ant-tree .ant-tree-switcher {
|
||||
width: 0 !important;
|
||||
min-width: 0 !important;
|
||||
margin-inline-end: 0 !important;
|
||||
padding: 0 !important;
|
||||
overflow: hidden !important;
|
||||
background: transparent !important;
|
||||
}
|
||||
|
||||
.redis-viewer-workbench .ant-tree .ant-tree-switcher:hover,
|
||||
.redis-viewer-workbench .ant-tree .ant-tree-switcher:active,
|
||||
.redis-viewer-workbench .ant-tree .ant-tree-switcher:focus {
|
||||
background: transparent !important;
|
||||
}
|
||||
|
||||
.redis-viewer-workbench .redis-tree-expander-button:hover,
|
||||
.redis-viewer-workbench .redis-tree-expander-button:focus-visible {
|
||||
background: transparent !important;
|
||||
outline: none;
|
||||
}
|
||||
|
||||
.redis-viewer-workbench .ant-radio-group .ant-radio-button-wrapper {
|
||||
border-radius: 10px;
|
||||
margin-inline-end: 6px;
|
||||
}
|
||||
|
||||
.redis-viewer-workbench .ant-radio-group .ant-radio-button-wrapper:last-child {
|
||||
margin-inline-end: 0;
|
||||
}
|
||||
|
||||
.redis-viewer-workbench .ant-table {
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
.redis-viewer-workbench .ant-table-wrapper .ant-table-thead > tr > th {
|
||||
font-weight: 700;
|
||||
}
|
||||
|
||||
/* Scrollbar styling for dark mode */
|
||||
body[data-theme='dark'] ::-webkit-scrollbar {
|
||||
width: 10px;
|
||||
@@ -97,6 +182,16 @@ body[data-theme='dark'] .ant-tree .ant-tree-node-content-wrapper.ant-tree-node-s
|
||||
color: rgba(255, 236, 179, 0.98) !important;
|
||||
}
|
||||
|
||||
body[data-theme='dark'] .redis-viewer-workbench .ant-tree .ant-tree-treenode:hover {
|
||||
background: rgba(255, 255, 255, 0.05) !important;
|
||||
}
|
||||
|
||||
body[data-theme='dark'] .redis-viewer-workbench .ant-tree .ant-tree-treenode.ant-tree-treenode-selected,
|
||||
body[data-theme='dark'] .redis-viewer-workbench .ant-tree .ant-tree-treenode.ant-tree-treenode-selected:hover {
|
||||
background: linear-gradient(90deg, rgba(246, 196, 83, 0.22), rgba(246, 196, 83, 0.08)) !important;
|
||||
border: 1px solid rgba(246, 196, 83, 0.24) !important;
|
||||
}
|
||||
|
||||
body[data-theme='dark'] .ant-checkbox-checked .ant-checkbox-inner {
|
||||
background-color: #f6c453 !important;
|
||||
border-color: #f6c453 !important;
|
||||
@@ -135,6 +230,41 @@ body[data-theme='dark'] .ant-table-tbody .ant-table-row.ant-table-row-selected:h
|
||||
background: rgba(246, 196, 83, 0.26) !important;
|
||||
}
|
||||
|
||||
body[data-theme='dark'] .redis-viewer-workbench .ant-radio-button-wrapper {
|
||||
background: rgba(255, 255, 255, 0.04);
|
||||
border-color: rgba(255, 255, 255, 0.08);
|
||||
color: rgba(230, 234, 242, 0.9);
|
||||
}
|
||||
|
||||
body[data-theme='dark'] .redis-viewer-workbench .ant-radio-button-wrapper-checked:not(.ant-radio-button-wrapper-disabled) {
|
||||
background: rgba(246, 196, 83, 0.16);
|
||||
border-color: rgba(246, 196, 83, 0.3);
|
||||
color: #f6c453;
|
||||
}
|
||||
|
||||
body[data-theme='light'] .redis-viewer-workbench .ant-tree .ant-tree-treenode:hover {
|
||||
background: rgba(15, 23, 42, 0.04) !important;
|
||||
}
|
||||
|
||||
body[data-theme='light'] .redis-viewer-workbench .ant-tree .ant-tree-treenode.ant-tree-treenode-selected,
|
||||
body[data-theme='light'] .redis-viewer-workbench .ant-tree .ant-tree-treenode.ant-tree-treenode-selected:hover {
|
||||
color: rgba(15, 23, 42, 0.92) !important;
|
||||
background: linear-gradient(90deg, rgba(22, 119, 255, 0.12), rgba(22, 119, 255, 0.04)) !important;
|
||||
border: 1px solid rgba(22, 119, 255, 0.18) !important;
|
||||
}
|
||||
|
||||
body[data-theme='light'] .redis-viewer-workbench .ant-radio-button-wrapper {
|
||||
background: rgba(255, 255, 255, 0.72);
|
||||
border-color: rgba(15, 23, 42, 0.08);
|
||||
color: rgba(51, 65, 85, 0.88);
|
||||
}
|
||||
|
||||
body[data-theme='light'] .redis-viewer-workbench .ant-radio-button-wrapper-checked:not(.ant-radio-button-wrapper-disabled) {
|
||||
background: rgba(22, 119, 255, 0.1);
|
||||
border-color: rgba(22, 119, 255, 0.22);
|
||||
color: #1677ff;
|
||||
}
|
||||
|
||||
/* 连接配置弹窗:滚动仅在弹窗 body 内部,不使用外层 wrap 滚动条 */
|
||||
.connection-modal-wrap {
|
||||
overflow: hidden !important;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,11 @@
|
||||
import React, { useState, useEffect, useRef } from 'react';
|
||||
import { Modal, Form, Select, Button, message, Steps, Transfer, Card, Alert, Divider, Typography, Progress, Checkbox, Table, Drawer, Tabs } from 'antd';
|
||||
import React, { useState, useEffect, useMemo, useRef } from 'react';
|
||||
import { Modal, Form, Select, Input, Button, message, Steps, Transfer, Card, Alert, Divider, Typography, Progress, Checkbox, Table, Drawer, Tabs, theme as antdTheme } from 'antd';
|
||||
import { DatabaseOutlined, RocketOutlined, SwapOutlined, TableOutlined } from '@ant-design/icons';
|
||||
import { useStore } from '../store';
|
||||
import { DBGetDatabases, DBGetTables, DataSync, DataSyncAnalyze, DataSyncPreview } from '../../wailsjs/go/app/App';
|
||||
import { SavedConnection } from '../types';
|
||||
import { EventsOn } from '../../wailsjs/runtime/runtime';
|
||||
import { normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance';
|
||||
|
||||
const { Title, Text } = Typography;
|
||||
const { Step } = Steps;
|
||||
@@ -21,6 +23,12 @@ type TableDiffSummary = {
|
||||
deletes?: number;
|
||||
same?: number;
|
||||
message?: string;
|
||||
targetTableExists?: boolean;
|
||||
plannedAction?: string;
|
||||
warnings?: string[];
|
||||
unsupportedObjects?: string[];
|
||||
indexesToCreate?: number;
|
||||
indexesSkipped?: number;
|
||||
};
|
||||
type TableOps = {
|
||||
insert: boolean;
|
||||
@@ -31,10 +39,135 @@ type TableOps = {
|
||||
selectedDeletePks?: string[];
|
||||
};
|
||||
|
||||
type WorkflowType = 'sync' | 'migration';
|
||||
|
||||
const quoteSqlIdent = (dbType: string, ident: string): string => {
|
||||
const raw = String(ident || '').trim();
|
||||
if (!raw) return raw;
|
||||
const t = String(dbType || '').toLowerCase();
|
||||
if (t === 'mysql' || t === 'mariadb' || t === 'diros' || t === 'sphinx' || t === 'clickhouse' || t === 'tdengine') {
|
||||
return `\`${raw.replace(/`/g, '``')}\``;
|
||||
}
|
||||
if (t === 'sqlserver') {
|
||||
return `[${raw.replace(/]/g, ']]')}]`;
|
||||
}
|
||||
return `"${raw.replace(/"/g, '""')}"`;
|
||||
};
|
||||
|
||||
const quoteSqlTable = (dbType: string, tableName: string): string => {
|
||||
const raw = String(tableName || '').trim();
|
||||
if (!raw) return raw;
|
||||
if (!raw.includes('.')) return quoteSqlIdent(dbType, raw);
|
||||
return raw
|
||||
.split('.')
|
||||
.map((part) => quoteSqlIdent(dbType, part))
|
||||
.join('.');
|
||||
};
|
||||
|
||||
const toSqlLiteral = (value: any, dbType: string): string => {
|
||||
if (value === null || value === undefined) return 'NULL';
|
||||
if (typeof value === 'number') return Number.isFinite(value) ? String(value) : 'NULL';
|
||||
if (typeof value === 'bigint') return value.toString();
|
||||
if (typeof value === 'boolean') {
|
||||
const t = String(dbType || '').toLowerCase();
|
||||
if (t === 'sqlserver') return value ? '1' : '0';
|
||||
return value ? 'TRUE' : 'FALSE';
|
||||
}
|
||||
if (value instanceof Date) {
|
||||
return `'${value.toISOString().replace(/'/g, "''")}'`;
|
||||
}
|
||||
if (typeof value === 'object') {
|
||||
try {
|
||||
return `'${JSON.stringify(value).replace(/'/g, "''")}'`;
|
||||
} catch {
|
||||
return `'${String(value).replace(/'/g, "''")}'`;
|
||||
}
|
||||
}
|
||||
return `'${String(value).replace(/'/g, "''")}'`;
|
||||
};
|
||||
|
||||
const resolveRedisDbIndex = (raw?: string): number => {
|
||||
const value = Number(String(raw || '').trim());
|
||||
return Number.isInteger(value) && value >= 0 && value <= 15 ? value : 0;
|
||||
};
|
||||
|
||||
const buildSqlPreview = (
|
||||
previewData: any,
|
||||
tableName: string,
|
||||
dbType: string,
|
||||
ops?: TableOps,
|
||||
): { sqlText: string; statementCount: number } => {
|
||||
if (!previewData || !tableName) return { sqlText: '', statementCount: 0 };
|
||||
const tableExpr = quoteSqlTable(dbType, tableName);
|
||||
const pkCol = String(previewData.pkColumn || 'id');
|
||||
const statements: string[] = [];
|
||||
|
||||
const insertRows = Array.isArray(previewData.inserts) ? previewData.inserts : [];
|
||||
const updateRows = Array.isArray(previewData.updates) ? previewData.updates : [];
|
||||
const deleteRows = Array.isArray(previewData.deletes) ? previewData.deletes : [];
|
||||
|
||||
const selectedInsert = new Set((ops?.selectedInsertPks || []).map((v) => String(v)));
|
||||
const selectedUpdate = new Set((ops?.selectedUpdatePks || []).map((v) => String(v)));
|
||||
const selectedDelete = new Set((ops?.selectedDeletePks || []).map((v) => String(v)));
|
||||
|
||||
if (ops?.insert !== false) {
|
||||
insertRows.forEach((rowWrap: any) => {
|
||||
const pk = String(rowWrap?.pk ?? '');
|
||||
if (selectedInsert.size > 0 && !selectedInsert.has(pk)) return;
|
||||
const row = rowWrap?.row || {};
|
||||
const columns = Object.keys(row);
|
||||
if (columns.length === 0) return;
|
||||
const colExpr = columns.map((c) => quoteSqlIdent(dbType, c)).join(', ');
|
||||
const valExpr = columns.map((c) => toSqlLiteral(row[c], dbType)).join(', ');
|
||||
statements.push(`INSERT INTO ${tableExpr} (${colExpr}) VALUES (${valExpr});`);
|
||||
});
|
||||
}
|
||||
|
||||
if (ops?.update !== false) {
|
||||
updateRows.forEach((rowWrap: any) => {
|
||||
const pk = String(rowWrap?.pk ?? '');
|
||||
if (selectedUpdate.size > 0 && !selectedUpdate.has(pk)) return;
|
||||
const source = rowWrap?.source || {};
|
||||
const changedColumns = Array.isArray(rowWrap?.changedColumns)
|
||||
? rowWrap.changedColumns
|
||||
: Object.keys(source).filter((k) => k !== pkCol);
|
||||
const setCols = changedColumns.filter((c: string) => String(c) !== pkCol);
|
||||
if (setCols.length === 0) return;
|
||||
const setExpr = setCols
|
||||
.map((c: string) => `${quoteSqlIdent(dbType, c)} = ${toSqlLiteral(source[c], dbType)}`)
|
||||
.join(', ');
|
||||
statements.push(
|
||||
`UPDATE ${tableExpr} SET ${setExpr} WHERE ${quoteSqlIdent(dbType, pkCol)} = ${toSqlLiteral(pk, dbType)};`,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
if (ops?.delete) {
|
||||
deleteRows.forEach((rowWrap: any) => {
|
||||
const pk = String(rowWrap?.pk ?? '');
|
||||
if (selectedDelete.size > 0 && !selectedDelete.has(pk)) return;
|
||||
statements.push(
|
||||
`DELETE FROM ${tableExpr} WHERE ${quoteSqlIdent(dbType, pkCol)} = ${toSqlLiteral(pk, dbType)};`,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
sqlText: statements.join('\n'),
|
||||
statementCount: statements.length,
|
||||
};
|
||||
};
|
||||
|
||||
const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open, onClose }) => {
|
||||
const connections = useStore((state) => state.connections);
|
||||
const themeMode = useStore((state) => state.theme);
|
||||
const appearance = useStore((state) => state.appearance);
|
||||
const [currentStep, setCurrentStep] = useState(0);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const { token } = antdTheme.useToken();
|
||||
const darkMode = themeMode === 'dark';
|
||||
const resolvedAppearance = resolveAppearanceValues(appearance);
|
||||
const effectiveOpacity = normalizeOpacityForPlatform(resolvedAppearance.opacity);
|
||||
|
||||
// Step 1: Config
|
||||
const [sourceConnId, setSourceConnId] = useState<string>('');
|
||||
@@ -50,9 +183,13 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
const [selectedTables, setSelectedTables] = useState<string[]>([]);
|
||||
|
||||
// Options
|
||||
const [workflowType, setWorkflowType] = useState<WorkflowType>('sync');
|
||||
const [syncContent, setSyncContent] = useState<'data' | 'schema' | 'both'>('data');
|
||||
const [syncMode, setSyncMode] = useState<string>('insert_update');
|
||||
const [autoAddColumns, setAutoAddColumns] = useState<boolean>(true);
|
||||
const [targetTableStrategy, setTargetTableStrategy] = useState<'existing_only' | 'auto_create_if_missing' | 'smart'>('existing_only');
|
||||
const [createIndexes, setCreateIndexes] = useState<boolean>(false);
|
||||
const [mongoCollectionName, setMongoCollectionName] = useState<string>('');
|
||||
const [showSameTables, setShowSameTables] = useState<boolean>(false);
|
||||
const [analyzing, setAnalyzing] = useState<boolean>(false);
|
||||
const [diffTables, setDiffTables] = useState<TableDiffSummary[]>([]);
|
||||
@@ -128,9 +265,12 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
setSourceDb('');
|
||||
setTargetDb('');
|
||||
setSelectedTables([]);
|
||||
setWorkflowType('sync');
|
||||
setSyncContent('data');
|
||||
setSyncMode('insert_update');
|
||||
setAutoAddColumns(true);
|
||||
setTargetTableStrategy('existing_only');
|
||||
setCreateIndexes(false);
|
||||
setShowSameTables(false);
|
||||
setAnalyzing(false);
|
||||
setDiffTables([]);
|
||||
@@ -148,36 +288,66 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
}
|
||||
}, [open]);
|
||||
|
||||
useEffect(() => {
|
||||
if (workflowType === 'migration') {
|
||||
if (syncMode === 'insert_update') {
|
||||
setSyncMode('insert_only');
|
||||
}
|
||||
if (syncContent === 'schema') {
|
||||
setSyncContent('both');
|
||||
}
|
||||
if (targetTableStrategy === 'existing_only') {
|
||||
setTargetTableStrategy('smart');
|
||||
}
|
||||
if (!createIndexes) {
|
||||
setCreateIndexes(true);
|
||||
}
|
||||
} else {
|
||||
if (targetTableStrategy !== 'existing_only') {
|
||||
setTargetTableStrategy('existing_only');
|
||||
}
|
||||
if (createIndexes) {
|
||||
setCreateIndexes(false);
|
||||
}
|
||||
}
|
||||
}, [workflowType]);
|
||||
|
||||
const handleSourceConnChange = async (connId: string) => {
|
||||
setSourceConnId(connId);
|
||||
setSourceDb('');
|
||||
const conn = connections.find(c => c.id === connId);
|
||||
if (conn) {
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await DBGetDatabases(normalizeConnConfig(conn) as any);
|
||||
if (res.success) {
|
||||
setSourceDbs((res.data as any[]).map((r: any) => r.Database || r.database || r.username));
|
||||
}
|
||||
} catch(e) { message.error("Failed to fetch source databases"); }
|
||||
setLoading(false);
|
||||
}
|
||||
if (conn) {
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await DBGetDatabases(normalizeConnConfig(conn) as any);
|
||||
if (res.success) {
|
||||
const dbRows = Array.isArray(res.data) ? res.data : [];
|
||||
setSourceDbs(dbRows
|
||||
.map((r: any) => r?.Database || r?.database || r?.username)
|
||||
.filter((name: any) => typeof name === 'string' && name.trim() !== ''));
|
||||
}
|
||||
} catch(e) { message.error("Failed to fetch source databases"); }
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleTargetConnChange = async (connId: string) => {
|
||||
setTargetConnId(connId);
|
||||
setTargetDb('');
|
||||
const conn = connections.find(c => c.id === connId);
|
||||
if (conn) {
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await DBGetDatabases(normalizeConnConfig(conn) as any);
|
||||
if (res.success) {
|
||||
setTargetDbs((res.data as any[]).map((r: any) => r.Database || r.database || r.username));
|
||||
}
|
||||
} catch(e) { message.error("Failed to fetch target databases"); }
|
||||
setLoading(false);
|
||||
}
|
||||
if (conn) {
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await DBGetDatabases(normalizeConnConfig(conn) as any);
|
||||
if (res.success) {
|
||||
const dbRows = Array.isArray(res.data) ? res.data : [];
|
||||
setTargetDbs(dbRows
|
||||
.map((r: any) => r?.Database || r?.database || r?.username)
|
||||
.filter((name: any) => typeof name === 'string' && name.trim() !== ''));
|
||||
}
|
||||
} catch(e) { message.error("Failed to fetch target databases"); }
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const nextToTables = async () => {
|
||||
@@ -189,14 +359,17 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
try {
|
||||
const conn = connections.find(c => c.id === sourceConnId);
|
||||
if (conn) {
|
||||
const config = normalizeConnConfig(conn, sourceDb);
|
||||
const res = await DBGetTables(config as any, sourceDb);
|
||||
if (res.success) {
|
||||
// DBGetTables returns [{Table: "name"}, ...]
|
||||
const tables = (res.data as any[]).map((row: any) => row.Table || row.table || row.TABLE_NAME || Object.values(row)[0]);
|
||||
setAllTables(tables as string[]);
|
||||
setCurrentStep(1);
|
||||
} else {
|
||||
const config = normalizeConnConfig(conn, sourceDb);
|
||||
const res = await DBGetTables(config as any, sourceDb);
|
||||
if (res.success) {
|
||||
// DBGetTables returns [{Table: "name"}, ...]
|
||||
const tableRows = Array.isArray(res.data) ? res.data : [];
|
||||
const tables = tableRows
|
||||
.map((row: any) => row?.Table || row?.table || row?.TABLE_NAME || Object.values(row || {})[0])
|
||||
.filter((name: any) => typeof name === 'string' && name.trim() !== '');
|
||||
setAllTables(tables as string[]);
|
||||
setCurrentStep(1);
|
||||
} else {
|
||||
message.error(res.message);
|
||||
}
|
||||
}
|
||||
@@ -236,6 +409,9 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
content: syncContent,
|
||||
mode: "insert_update",
|
||||
autoAddColumns,
|
||||
targetTableStrategy,
|
||||
createIndexes,
|
||||
mongoCollectionName: mongoCollectionName.trim(),
|
||||
jobId,
|
||||
};
|
||||
|
||||
@@ -286,6 +462,9 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
content: "data",
|
||||
mode: "insert_update",
|
||||
autoAddColumns,
|
||||
targetTableStrategy,
|
||||
createIndexes,
|
||||
mongoCollectionName: mongoCollectionName.trim(),
|
||||
};
|
||||
|
||||
try {
|
||||
@@ -362,6 +541,9 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
content: syncContent,
|
||||
mode: syncMode,
|
||||
autoAddColumns,
|
||||
targetTableStrategy,
|
||||
createIndexes,
|
||||
mongoCollectionName: mongoCollectionName.trim(),
|
||||
tableOptions,
|
||||
jobId,
|
||||
};
|
||||
@@ -402,10 +584,139 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
);
|
||||
};
|
||||
|
||||
const previewSql = useMemo(() => {
|
||||
if (!previewData || !previewTable) return { sqlText: '', statementCount: 0 };
|
||||
const targetType = String(connections.find(c => c.id === targetConnId)?.config?.type || '');
|
||||
const ops = tableOptions[previewTable] || { insert: true, update: true, delete: false };
|
||||
return buildSqlPreview(previewData, previewTable, targetType, ops);
|
||||
}, [previewData, previewTable, targetConnId, connections, tableOptions]);
|
||||
|
||||
const analysisWarnings = useMemo(() => {
|
||||
const items: string[] = [];
|
||||
diffTables.forEach((table) => {
|
||||
(table.warnings || []).forEach((warning) => items.push(`${table.table}: ${warning}`));
|
||||
(table.unsupportedObjects || []).forEach((warning) => items.push(`${table.table}: ${warning}`));
|
||||
});
|
||||
return Array.from(new Set(items));
|
||||
}, [diffTables]);
|
||||
|
||||
const isMigrationWorkflow = workflowType === 'migration';
|
||||
const sourceConn = useMemo(() => connections.find(c => c.id === sourceConnId), [connections, sourceConnId]);
|
||||
const targetConn = useMemo(() => connections.find(c => c.id === targetConnId), [connections, targetConnId]);
|
||||
const sourceType = String(sourceConn?.config?.type || '').toLowerCase();
|
||||
const targetType = String(targetConn?.config?.type || '').toLowerCase();
|
||||
const isRedisMongoKeyspaceMigration = isMigrationWorkflow && (
|
||||
(sourceType === 'redis' && targetType === 'mongodb') ||
|
||||
(sourceType === 'mongodb' && targetType === 'redis')
|
||||
);
|
||||
const defaultMongoCollectionName = useMemo(() => {
|
||||
if (sourceType === 'redis' && targetType === 'mongodb') {
|
||||
return `redis_db_${resolveRedisDbIndex(sourceDb || sourceConn?.config?.database)}_keys`;
|
||||
}
|
||||
if (sourceType === 'mongodb' && targetType === 'redis') {
|
||||
return selectedTables[0] || `redis_db_${resolveRedisDbIndex(targetDb || targetConn?.config?.database)}_keys`;
|
||||
}
|
||||
return '';
|
||||
}, [sourceType, targetType, sourceDb, targetDb, sourceConn, targetConn, selectedTables]);
|
||||
|
||||
const modalPanelStyle = useMemo(() => ({
|
||||
background: darkMode
|
||||
? 'linear-gradient(180deg, rgba(16,22,34,0.96) 0%, rgba(10,14,24,0.98) 100%)'
|
||||
: 'linear-gradient(180deg, rgba(255,255,255,0.98) 0%, rgba(246,248,252,0.98) 100%)',
|
||||
border: darkMode ? '1px solid rgba(255,255,255,0.08)' : '1px solid rgba(16,24,40,0.08)',
|
||||
boxShadow: darkMode ? '0 24px 56px rgba(0,0,0,0.36)' : '0 18px 44px rgba(15,23,42,0.14)',
|
||||
backdropFilter: darkMode ? 'blur(18px)' : 'none',
|
||||
}), [darkMode]);
|
||||
|
||||
const shellCardStyle = useMemo<React.CSSProperties>(() => ({
|
||||
borderRadius: 18,
|
||||
border: darkMode ? '1px solid rgba(255,255,255,0.08)' : '1px solid rgba(15,23,42,0.08)',
|
||||
background: darkMode ? 'rgba(255,255,255,0.03)' : `rgba(255,255,255,${Math.max(effectiveOpacity, 0.88)})`,
|
||||
boxShadow: darkMode ? '0 12px 32px rgba(0,0,0,0.22)' : '0 10px 24px rgba(15,23,42,0.08)',
|
||||
overflow: 'hidden',
|
||||
}), [darkMode, effectiveOpacity]);
|
||||
|
||||
const heroPanelStyle = useMemo<React.CSSProperties>(() => ({
|
||||
padding: 18,
|
||||
borderRadius: 18,
|
||||
border: darkMode ? '1px solid rgba(255,214,102,0.12)' : '1px solid rgba(24,144,255,0.12)',
|
||||
background: darkMode
|
||||
? 'linear-gradient(135deg, rgba(255,214,102,0.10) 0%, rgba(255,255,255,0.03) 100%)'
|
||||
: 'linear-gradient(135deg, rgba(24,144,255,0.10) 0%, rgba(255,255,255,0.95) 100%)',
|
||||
marginBottom: 18,
|
||||
}), [darkMode]);
|
||||
|
||||
const badgeStyle = useMemo<React.CSSProperties>(() => ({
|
||||
display: 'inline-flex',
|
||||
alignItems: 'center',
|
||||
gap: 6,
|
||||
padding: '6px 10px',
|
||||
borderRadius: 999,
|
||||
border: darkMode ? '1px solid rgba(255,255,255,0.10)' : '1px solid rgba(15,23,42,0.08)',
|
||||
background: darkMode ? 'rgba(255,255,255,0.04)' : 'rgba(255,255,255,0.86)',
|
||||
color: darkMode ? 'rgba(255,255,255,0.88)' : '#334155',
|
||||
fontSize: 12,
|
||||
fontWeight: 600,
|
||||
}), [darkMode]);
|
||||
|
||||
const quietPanelStyle = useMemo<React.CSSProperties>(() => ({
|
||||
padding: 14,
|
||||
borderRadius: 16,
|
||||
border: darkMode ? '1px solid rgba(255,255,255,0.08)' : '1px solid rgba(15,23,42,0.08)',
|
||||
background: darkMode ? 'rgba(255,255,255,0.025)' : 'rgba(248,250,252,0.92)',
|
||||
}), [darkMode]);
|
||||
|
||||
const modalWorkspaceStyle = useMemo<React.CSSProperties>(() => ({
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
height: '100%',
|
||||
minHeight: 0,
|
||||
}), []);
|
||||
|
||||
const modalScrollableContentStyle = useMemo<React.CSSProperties>(() => ({
|
||||
flex: 1,
|
||||
minHeight: 0,
|
||||
overflowY: 'auto',
|
||||
overflowX: 'hidden',
|
||||
paddingRight: 4,
|
||||
overscrollBehavior: 'contain',
|
||||
}), []);
|
||||
|
||||
const modalFooterBarStyle = useMemo<React.CSSProperties>(() => ({
|
||||
marginTop: 18,
|
||||
display: 'flex',
|
||||
justifyContent: 'flex-end',
|
||||
gap: 8,
|
||||
paddingTop: 12,
|
||||
borderTop: darkMode ? '1px solid rgba(255,255,255,0.06)' : '1px solid rgba(15,23,42,0.06)',
|
||||
flex: '0 0 auto',
|
||||
}), [darkMode]);
|
||||
|
||||
const renderModalTitle = (title: string, description: string) => (
|
||||
<div style={{ display: 'flex', alignItems: 'flex-start', gap: 12 }}>
|
||||
<div style={{
|
||||
width: 38,
|
||||
height: 38,
|
||||
borderRadius: 14,
|
||||
display: 'grid',
|
||||
placeItems: 'center',
|
||||
background: darkMode ? 'rgba(255,214,102,0.12)' : 'rgba(24,144,255,0.10)',
|
||||
color: darkMode ? '#ffd666' : token.colorPrimary,
|
||||
flexShrink: 0,
|
||||
}}>
|
||||
{isMigrationWorkflow ? <RocketOutlined /> : <SwapOutlined />}
|
||||
</div>
|
||||
<div style={{ minWidth: 0 }}>
|
||||
<div style={{ fontSize: 16, fontWeight: 700, color: darkMode ? '#f8fafc' : '#0f172a' }}>{title}</div>
|
||||
<div style={{ marginTop: 4, fontSize: 12, lineHeight: 1.6, color: darkMode ? 'rgba(255,255,255,0.56)' : 'rgba(15,23,42,0.58)' }}>{description}</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Modal
|
||||
title="数据同步"
|
||||
title={renderModalTitle(isMigrationWorkflow ? '跨库迁移工作台' : '数据同步工作台', isMigrationWorkflow ? '按源库 → 目标库完成建表、导入与风险预检。' : '按已有目标表完成差异对比、同步执行与结果确认。')}
|
||||
open={open}
|
||||
onCancel={() => {
|
||||
if (syncing) {
|
||||
@@ -414,23 +725,61 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
}
|
||||
onClose();
|
||||
}}
|
||||
width={800}
|
||||
width={920}
|
||||
footer={null}
|
||||
destroyOnHidden
|
||||
closable={!syncing}
|
||||
maskClosable={!syncing}
|
||||
styles={{
|
||||
content: modalPanelStyle,
|
||||
header: { background: 'transparent', borderBottom: 'none', paddingBottom: 10 },
|
||||
body: {
|
||||
paddingTop: 8,
|
||||
height: 760,
|
||||
maxHeight: 'calc(100vh - 120px)',
|
||||
overflow: 'hidden',
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
},
|
||||
footer: { background: 'transparent', borderTop: 'none', paddingTop: 12 },
|
||||
}}
|
||||
>
|
||||
<div style={modalWorkspaceStyle}>
|
||||
<div style={{ flex: '0 0 auto' }}>
|
||||
<div style={heroPanelStyle}>
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', gap: 12, alignItems: 'flex-start', flexWrap: 'wrap' }}>
|
||||
<div style={{ minWidth: 0 }}>
|
||||
<div style={{ fontSize: 18, fontWeight: 700, color: darkMode ? '#f8fafc' : '#0f172a' }}>{isMigrationWorkflow ? '跨数据源迁移' : '数据同步'}</div>
|
||||
<div style={{ marginTop: 6, fontSize: 13, lineHeight: 1.7, color: darkMode ? 'rgba(255,255,255,0.62)' : 'rgba(15,23,42,0.62)' }}>
|
||||
{isMigrationWorkflow
|
||||
? '适合把源表迁移到另一套数据库,可按策略自动建表、导入数据并补建可兼容索引。'
|
||||
: '适合目标表已存在的场景,先做差异分析,再按勾选执行插入、更新或删除。'}
|
||||
</div>
|
||||
</div>
|
||||
<div style={{ display: 'flex', flexWrap: 'wrap', gap: 8 }}>
|
||||
<span style={badgeStyle}>{isMigrationWorkflow ? <RocketOutlined /> : <SwapOutlined />} {isMigrationWorkflow ? '迁移模式' : '同步模式'}</span>
|
||||
<span style={badgeStyle}><DatabaseOutlined /> {sourceConnId ? '已选源连接' : '待选源连接'}</span>
|
||||
<span style={badgeStyle}><TableOutlined /> {selectedTables.length || 0} 张表</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<Steps current={currentStep} style={{ marginBottom: 24 }}>
|
||||
<Step title="配置源与目标" />
|
||||
<Step title="选择表" />
|
||||
<Step title="执行结果" />
|
||||
</Steps>
|
||||
</div>
|
||||
|
||||
<div style={modalScrollableContentStyle}>
|
||||
{/* STEP 1: CONFIG */}
|
||||
{currentStep === 0 && (
|
||||
<div>
|
||||
<div style={{ display: 'flex', gap: 24, justifyContent: 'center' }}>
|
||||
<Card title="源数据库" style={{ width: 350 }}>
|
||||
<div style={{ display: 'grid', gridTemplateColumns: 'minmax(0, 1fr) 44px minmax(0, 1fr)', gap: 18, alignItems: 'stretch' }}>
|
||||
<Card
|
||||
title="源数据库"
|
||||
style={shellCardStyle}
|
||||
styles={{ header: { borderBottom: darkMode ? '1px solid rgba(255,255,255,0.08)' : '1px solid rgba(15,23,42,0.06)', fontWeight: 700 }, body: { padding: 18 } }}
|
||||
>
|
||||
<Form layout="vertical">
|
||||
<Form.Item label="连接">
|
||||
<Select value={sourceConnId} onChange={handleSourceConnChange}>
|
||||
@@ -444,8 +793,16 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</Card>
|
||||
<div style={{ display: 'flex', alignItems: 'center' }}>至</div>
|
||||
<Card title="目标数据库" style={{ width: 350 }}>
|
||||
<div style={{ display: 'grid', placeItems: 'center' }}>
|
||||
<div style={{ ...badgeStyle, width: 44, height: 44, borderRadius: 14, justifyContent: 'center', padding: 0 }}>
|
||||
<SwapOutlined />
|
||||
</div>
|
||||
</div>
|
||||
<Card
|
||||
title="目标数据库"
|
||||
style={shellCardStyle}
|
||||
styles={{ header: { borderBottom: darkMode ? '1px solid rgba(255,255,255,0.08)' : '1px solid rgba(15,23,42,0.06)', fontWeight: 700 }, body: { padding: 18 } }}
|
||||
>
|
||||
<Form layout="vertical">
|
||||
<Form.Item label="连接">
|
||||
<Select value={targetConnId} onChange={handleTargetConnChange}>
|
||||
@@ -461,27 +818,94 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
</Card>
|
||||
</div>
|
||||
|
||||
<Card title="同步选项" style={{ marginTop: 16 }}>
|
||||
<Card
|
||||
title={isMigrationWorkflow ? '迁移选项' : '同步选项'}
|
||||
style={{ ...shellCardStyle, marginTop: 18 }}
|
||||
styles={{ header: { borderBottom: darkMode ? '1px solid rgba(255,255,255,0.08)' : '1px solid rgba(15,23,42,0.06)', fontWeight: 700 }, body: { padding: 18 } }}
|
||||
>
|
||||
<div style={{ ...quietPanelStyle, marginBottom: 14 }}>
|
||||
<Text style={{ color: darkMode ? 'rgba(255,255,255,0.72)' : 'rgba(15,23,42,0.68)', lineHeight: 1.7 }}>
|
||||
先明确当前要做的是“已有目标表同步”还是“跨库迁移”,页面会按功能类型自动给出更安全的默认策略。
|
||||
</Text>
|
||||
</div>
|
||||
<Form layout="vertical">
|
||||
<Form.Item label="同步内容">
|
||||
<Form.Item label="功能类型">
|
||||
<Select value={workflowType} onChange={setWorkflowType}>
|
||||
<Option value="sync">数据同步(基于已有目标表做差异同步)</Option>
|
||||
<Option value="migration">跨库迁移(可自动建表后导入)</Option>
|
||||
</Select>
|
||||
</Form.Item>
|
||||
<Alert
|
||||
type={isMigrationWorkflow ? 'info' : 'success'}
|
||||
showIcon
|
||||
style={{ marginBottom: 12 }}
|
||||
message={isMigrationWorkflow
|
||||
? '当前为“跨库迁移”模式:适合将表迁移到另一数据源,可自动建表并导入数据。'
|
||||
: '当前为“数据同步”模式:适合目标表已存在时做增量同步或覆盖导入。'}
|
||||
/>
|
||||
<Form.Item label={isMigrationWorkflow ? '迁移内容' : '同步内容'}>
|
||||
<Select value={syncContent} onChange={setSyncContent}>
|
||||
<Option value="data">仅同步数据</Option>
|
||||
<Option value="schema">仅同步结构</Option>
|
||||
<Option value="both">同步结构 + 数据</Option>
|
||||
</Select>
|
||||
</Form.Item>
|
||||
<Form.Item label="同步模式">
|
||||
<Form.Item label={isMigrationWorkflow ? '迁移模式' : '同步模式'}>
|
||||
<Select value={syncMode} onChange={setSyncMode} disabled={syncContent === 'schema'}>
|
||||
<Option value="insert_update">增量同步(对比差异,按插入/更新/删除勾选执行)</Option>
|
||||
<Option value="insert_only">仅插入(不对比目标;无主键表将跳过)</Option>
|
||||
<Option value="full_overwrite">全量覆盖(清空目标表后插入)</Option>
|
||||
</Select>
|
||||
</Form.Item>
|
||||
<Form.Item label={isMigrationWorkflow ? '目标表处理策略' : '目标表要求'}>
|
||||
<Select value={targetTableStrategy} onChange={setTargetTableStrategy} disabled={!isMigrationWorkflow}>
|
||||
<Option value="existing_only">仅使用已有目标表</Option>
|
||||
<Option value="auto_create_if_missing">目标表不存在时自动建表后导入</Option>
|
||||
<Option value="smart">智能模式(存在则直接导入,不存在则自动建表)</Option>
|
||||
</Select>
|
||||
</Form.Item>
|
||||
{isRedisMongoKeyspaceMigration && (
|
||||
<Form.Item
|
||||
label="Mongo 集合名(可选)"
|
||||
extra={sourceType === 'redis'
|
||||
? '为空时沿用默认集合名;填写后本次 Redis 键空间会统一写入该 Mongo 集合。'
|
||||
: 'MongoDB → Redis 场景下通常直接选择源集合;这里留空即可,未显式选集合时才会回退使用该名称。'}
|
||||
>
|
||||
<Input
|
||||
value={mongoCollectionName}
|
||||
onChange={(e) => setMongoCollectionName(e.target.value)}
|
||||
placeholder={defaultMongoCollectionName || '请输入 Mongo 集合名'}
|
||||
allowClear
|
||||
maxLength={128}
|
||||
/>
|
||||
</Form.Item>
|
||||
)}
|
||||
<Form.Item>
|
||||
<Checkbox checked={autoAddColumns} onChange={(e) => setAutoAddColumns(e.target.checked)}>
|
||||
自动补齐目标表缺失字段(仅 MySQL 目标)
|
||||
自动补齐目标表缺失字段(当前支持 MySQL 目标及 MySQL → Kingbase)
|
||||
</Checkbox>
|
||||
</Form.Item>
|
||||
<Form.Item>
|
||||
<Checkbox checked={createIndexes} onChange={(e) => setCreateIndexes(e.target.checked)} disabled={!isMigrationWorkflow || targetTableStrategy === 'existing_only'}>
|
||||
自动迁移可兼容的普通索引/唯一索引(仅自动建表模式生效)
|
||||
</Checkbox>
|
||||
</Form.Item>
|
||||
{isMigrationWorkflow && targetTableStrategy !== 'existing_only' && (
|
||||
<Alert
|
||||
type="info"
|
||||
showIcon
|
||||
message="自动建表模式首期仅支持 MySQL → Kingbase;将迁移字段、主键、普通/唯一/联合索引,并显式跳过全文、空间、前缀、函数类索引。"
|
||||
style={{ marginBottom: 12 }}
|
||||
/>
|
||||
)}
|
||||
{!isMigrationWorkflow && (
|
||||
<Alert
|
||||
type="info"
|
||||
showIcon
|
||||
message="数据同步模式默认基于已有目标表执行;如需跨数据源建表导入,请切换到“跨库迁移”。"
|
||||
style={{ marginBottom: 12 }}
|
||||
/>
|
||||
)}
|
||||
{syncContent !== 'schema' && syncMode === 'full_overwrite' && (
|
||||
<Alert
|
||||
type="warning"
|
||||
@@ -496,26 +920,42 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
|
||||
{/* STEP 2: TABLES */}
|
||||
{currentStep === 1 && (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 12 }}>
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center' }}>
|
||||
<Text type="secondary">请选择需要同步的表:</Text>
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 14 }}>
|
||||
<div style={quietPanelStyle}>
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 10 }}>
|
||||
<Text type="secondary">请选择需要同步的表:</Text>
|
||||
<Checkbox checked={showSameTables} onChange={(e) => setShowSameTables(e.target.checked)}>
|
||||
显示相同表
|
||||
</Checkbox>
|
||||
</div>
|
||||
<Transfer
|
||||
</div>
|
||||
<Transfer
|
||||
dataSource={allTables.map(t => ({ key: t, title: t }))}
|
||||
titles={['源表', '已选表']}
|
||||
targetKeys={selectedTables}
|
||||
onChange={(keys) => setSelectedTables(keys as string[])}
|
||||
render={item => item.title}
|
||||
listStyle={{ width: 350, height: 280, marginTop: 0 }}
|
||||
locale={{ itemUnit: '项', itemsUnit: '项', searchPlaceholder: '搜索表', notFoundContent: '暂无数据' }}
|
||||
listStyle={{ width: 390, height: 320, marginTop: 0, borderRadius: 14, overflow: 'hidden' }}
|
||||
locale={{ itemUnit: '项', itemsUnit: '项', searchPlaceholder: '搜索表…', notFoundContent: '暂无数据' }}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{diffTables.length > 0 && (
|
||||
<div>
|
||||
<Divider orientation="left">对比结果</Divider>
|
||||
<div style={quietPanelStyle}>
|
||||
<Divider orientation="left" style={{ marginTop: 0 }}>对比结果</Divider>
|
||||
{analysisWarnings.length > 0 && (
|
||||
<Alert
|
||||
type="warning"
|
||||
showIcon
|
||||
message="预检发现风险或降级项,请在执行前确认"
|
||||
description={
|
||||
<ul style={{ margin: 0, paddingLeft: 18 }}>
|
||||
{analysisWarnings.slice(0, 8).map((item) => <li key={item}>{item}</li>)}
|
||||
{analysisWarnings.length > 8 && <li>还有 {analysisWarnings.length - 8} 项未展开</li>}
|
||||
</ul>
|
||||
}
|
||||
style={{ marginBottom: 12 }}
|
||||
/>
|
||||
)}
|
||||
<Table
|
||||
size="small"
|
||||
pagination={false}
|
||||
@@ -527,13 +967,29 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
const same = Number(t.same || 0);
|
||||
const msg = String(t.message || '').trim();
|
||||
const can = !!t.canSync;
|
||||
const warns = Array.isArray(t.warnings) ? t.warnings.length : 0;
|
||||
const unsupported = Array.isArray(t.unsupportedObjects) ? t.unsupportedObjects.length : 0;
|
||||
if (showSameTables) return true;
|
||||
if (!can) return true;
|
||||
if (msg) return true;
|
||||
if (msg || warns > 0 || unsupported > 0) return true;
|
||||
return ins > 0 || upd > 0 || del > 0 || same === 0;
|
||||
})}
|
||||
columns={[
|
||||
{ title: '表名', dataIndex: 'table', key: 'table', ellipsis: true },
|
||||
{
|
||||
title: '目标表',
|
||||
key: 'targetTableExists',
|
||||
width: 90,
|
||||
render: (_: any, r: any) => r.targetTableExists ? '已存在' : '不存在'
|
||||
},
|
||||
{
|
||||
title: '计划',
|
||||
dataIndex: 'plannedAction',
|
||||
key: 'plannedAction',
|
||||
width: 220,
|
||||
ellipsis: true,
|
||||
render: (v: any) => String(v || '')
|
||||
},
|
||||
{
|
||||
title: '插入',
|
||||
key: 'inserts',
|
||||
@@ -542,11 +998,7 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
const ops = tableOptions[r.table] || { insert: true, update: true, delete: false };
|
||||
const disabled = !r.canSync || analyzing || Number(r.inserts || 0) === 0;
|
||||
return (
|
||||
<Checkbox
|
||||
checked={!!ops.insert}
|
||||
disabled={disabled}
|
||||
onChange={(e) => updateTableOption(r.table, 'insert', e.target.checked)}
|
||||
>
|
||||
<Checkbox checked={!!ops.insert} disabled={disabled} onChange={(e) => updateTableOption(r.table, 'insert', e.target.checked)}>
|
||||
{Number(r.inserts || 0)}
|
||||
</Checkbox>
|
||||
);
|
||||
@@ -560,11 +1012,7 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
const ops = tableOptions[r.table] || { insert: true, update: true, delete: false };
|
||||
const disabled = !r.canSync || analyzing || Number(r.updates || 0) === 0;
|
||||
return (
|
||||
<Checkbox
|
||||
checked={!!ops.update}
|
||||
disabled={disabled}
|
||||
onChange={(e) => updateTableOption(r.table, 'update', e.target.checked)}
|
||||
>
|
||||
<Checkbox checked={!!ops.update} disabled={disabled} onChange={(e) => updateTableOption(r.table, 'update', e.target.checked)}>
|
||||
{Number(r.updates || 0)}
|
||||
</Checkbox>
|
||||
);
|
||||
@@ -578,18 +1026,28 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
const ops = tableOptions[r.table] || { insert: true, update: true, delete: false };
|
||||
const disabled = !r.canSync || analyzing || Number(r.deletes || 0) === 0;
|
||||
return (
|
||||
<Checkbox
|
||||
checked={!!ops.delete}
|
||||
disabled={disabled}
|
||||
onChange={(e) => updateTableOption(r.table, 'delete', e.target.checked)}
|
||||
>
|
||||
<Checkbox checked={!!ops.delete} disabled={disabled} onChange={(e) => updateTableOption(r.table, 'delete', e.target.checked)}>
|
||||
{Number(r.deletes || 0)}
|
||||
</Checkbox>
|
||||
);
|
||||
}
|
||||
},
|
||||
{ title: '相同', dataIndex: 'same', key: 'same', width: 70, render: (v: any) => Number(v || 0) },
|
||||
{ title: '消息', dataIndex: 'message', key: 'message', ellipsis: true, render: (v: any) => (v ? String(v) : '') },
|
||||
{
|
||||
title: '风险',
|
||||
key: 'warnings',
|
||||
width: 220,
|
||||
render: (_: any, r: any) => {
|
||||
const warns = [...(Array.isArray(r.warnings) ? r.warnings : []), ...(Array.isArray(r.unsupportedObjects) ? r.unsupportedObjects : [])];
|
||||
if (warns.length === 0) return '-';
|
||||
return (
|
||||
<div style={{ color: '#d48806', fontSize: 12, lineHeight: 1.5 }}>
|
||||
{warns.slice(0, 2).map((item: string) => <div key={item}>{item}</div>)}
|
||||
{warns.length > 2 && <div>还有 {warns.length - 2} 项</div>}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
},
|
||||
{
|
||||
title: '预览',
|
||||
key: 'preview',
|
||||
@@ -613,7 +1071,8 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
|
||||
{/* STEP 3: RESULT */}
|
||||
{currentStep === 2 && (
|
||||
<div>
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 14 }}>
|
||||
<div style={quietPanelStyle}>
|
||||
<Alert
|
||||
message={syncing ? "正在同步" : (syncResult?.success ? "同步完成" : "同步失败")}
|
||||
description={
|
||||
@@ -625,7 +1084,7 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
showIcon
|
||||
/>
|
||||
|
||||
<div style={{ marginTop: 12 }}>
|
||||
<div style={{ marginTop: 14 }}>
|
||||
<Progress
|
||||
percent={syncProgress.percent}
|
||||
status={syncing ? "active" : (syncResult?.success ? "success" : "exception")}
|
||||
@@ -633,7 +1092,9 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
/>
|
||||
</div>
|
||||
|
||||
<Divider orientation="left">日志</Divider>
|
||||
</div>
|
||||
<div style={quietPanelStyle}>
|
||||
<Divider orientation="left" style={{ marginTop: 0 }}>执行日志</Divider>
|
||||
<div
|
||||
ref={logBoxRef}
|
||||
onScroll={() => {
|
||||
@@ -642,14 +1103,25 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
const nearBottom = el.scrollHeight - el.scrollTop - el.clientHeight < 40;
|
||||
autoScrollRef.current = nearBottom;
|
||||
}}
|
||||
style={{ background: '#f5f5f5', padding: 12, height: 300, overflowY: 'auto', fontFamily: 'monospace' }}
|
||||
style={{
|
||||
background: darkMode ? 'rgba(255,255,255,0.03)' : 'rgba(248,250,252,0.92)',
|
||||
border: darkMode ? '1px solid rgba(255,255,255,0.08)' : '1px solid rgba(15,23,42,0.06)',
|
||||
borderRadius: 14,
|
||||
padding: 12,
|
||||
height: 300,
|
||||
overflowY: 'auto',
|
||||
fontFamily: 'SFMono-Regular, ui-monospace, Menlo, Consolas, monospace'
|
||||
}}
|
||||
>
|
||||
{syncLogs.map((item, i: number) => <div key={i}>{renderSyncLogItem(item)}</div>)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div style={{ marginTop: 24, textAlign: 'right' }}>
|
||||
</div>
|
||||
|
||||
<div style={modalFooterBarStyle}>
|
||||
{currentStep === 0 && (
|
||||
<Button type="primary" onClick={nextToTables} loading={loading}>下一步</Button>
|
||||
)}
|
||||
@@ -676,14 +1148,16 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
<Drawer
|
||||
title={`差异预览:${previewTable}`}
|
||||
styles={{ body: { background: darkMode ? 'rgba(9,13,20,0.98)' : '#f8fafc' } }}
|
||||
open={previewOpen}
|
||||
onClose={() => { setPreviewOpen(false); setPreviewTable(''); setPreviewData(null); }}
|
||||
width={900}
|
||||
>
|
||||
{previewLoading && <Alert type="info" showIcon message="正在加载差异预览..." />}
|
||||
{previewLoading && <Alert type="info" showIcon message="正在加载差异预览…" />}
|
||||
{!previewLoading && previewData && (
|
||||
<div>
|
||||
<Alert
|
||||
@@ -794,6 +1268,51 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
},
|
||||
{
|
||||
key: 'sql',
|
||||
label: `SQL(${previewSql.statementCount})`,
|
||||
children: (
|
||||
<div>
|
||||
<Alert
|
||||
type="info"
|
||||
showIcon
|
||||
message="SQL 预览会按当前勾选的插入/更新/删除与行选择范围生成,用于审核确认。"
|
||||
/>
|
||||
<div style={{ marginTop: 8, marginBottom: 8, display: 'flex', justifyContent: 'space-between', alignItems: 'center' }}>
|
||||
<Text type="secondary">共 {previewSql.statementCount} 条语句(预览数据最多 200 条/类型)</Text>
|
||||
<Button
|
||||
size="small"
|
||||
disabled={!previewSql.sqlText}
|
||||
onClick={async () => {
|
||||
try {
|
||||
await navigator.clipboard.writeText(previewSql.sqlText || '');
|
||||
message.success('SQL 已复制');
|
||||
} catch {
|
||||
message.error('复制失败,请手动复制');
|
||||
}
|
||||
}}
|
||||
>
|
||||
复制 SQL
|
||||
</Button>
|
||||
</div>
|
||||
<pre
|
||||
style={{
|
||||
margin: 0,
|
||||
padding: 10,
|
||||
border: '1px solid #f0f0f0',
|
||||
borderRadius: 6,
|
||||
background: '#fafafa',
|
||||
maxHeight: 420,
|
||||
overflow: 'auto',
|
||||
whiteSpace: 'pre-wrap',
|
||||
wordBreak: 'break-word'
|
||||
}}
|
||||
>
|
||||
{previewSql.sqlText || '-- 当前勾选范围下无 SQL 可预览'}
|
||||
</pre>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
]}
|
||||
/>
|
||||
|
||||
@@ -4,7 +4,7 @@ import { TabData, ColumnDefinition } from '../types';
|
||||
import { useStore } from '../store';
|
||||
import { DBQuery, DBGetColumns } from '../../wailsjs/go/app/App';
|
||||
import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
|
||||
import { buildOrderBySQL, buildWhereSQL, quoteIdentPart, quoteQualifiedIdent, withSortBufferTuningSQL, type FilterCondition } from '../utils/sql';
|
||||
import { buildOrderBySQL, buildPaginatedSelectSQL, buildWhereSQL, quoteIdentPart, quoteQualifiedIdent, withSortBufferTuningSQL, type FilterCondition } from '../utils/sql';
|
||||
import { buildMongoCountCommand, buildMongoFilter, buildMongoFindCommand, buildMongoSort } from '../utils/mongodb';
|
||||
import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities';
|
||||
|
||||
@@ -155,6 +155,16 @@ const reverseOrderBySQL = (orderBySQL: string): string => {
|
||||
type ViewerFilterSnapshot = {
|
||||
showFilter: boolean;
|
||||
conditions: FilterCondition[];
|
||||
currentPage: number;
|
||||
pageSize: number;
|
||||
sortInfo: { columnKey: string, order: string } | null;
|
||||
scrollTop: number;
|
||||
scrollLeft: number;
|
||||
};
|
||||
|
||||
type ViewerScrollSnapshot = {
|
||||
top: number;
|
||||
left: number;
|
||||
};
|
||||
|
||||
const viewerFilterSnapshotsByTab = new Map<string, ViewerFilterSnapshot>();
|
||||
@@ -175,15 +185,23 @@ const normalizeViewerFilterConditions = (conditions: FilterCondition[] | undefin
|
||||
const getViewerFilterSnapshot = (tabId: string): ViewerFilterSnapshot => {
|
||||
const cached = viewerFilterSnapshotsByTab.get(String(tabId || '').trim());
|
||||
if (!cached) {
|
||||
return { showFilter: false, conditions: [] };
|
||||
return { showFilter: false, conditions: [], currentPage: 1, pageSize: 100, sortInfo: null, scrollTop: 0, scrollLeft: 0 };
|
||||
}
|
||||
return {
|
||||
showFilter: cached.showFilter === true,
|
||||
conditions: normalizeViewerFilterConditions(cached.conditions),
|
||||
currentPage: Number.isFinite(Number(cached.currentPage)) && Number(cached.currentPage) > 0 ? Number(cached.currentPage) : 1,
|
||||
pageSize: Number.isFinite(Number(cached.pageSize)) && Number(cached.pageSize) > 0 ? Number(cached.pageSize) : 100,
|
||||
sortInfo: cached.sortInfo && cached.sortInfo.columnKey && (cached.sortInfo.order === 'ascend' || cached.sortInfo.order === 'descend')
|
||||
? { columnKey: String(cached.sortInfo.columnKey), order: cached.sortInfo.order }
|
||||
: null,
|
||||
scrollTop: Number.isFinite(Number(cached.scrollTop)) ? Number(cached.scrollTop) : 0,
|
||||
scrollLeft: Number.isFinite(Number(cached.scrollLeft)) ? Number(cached.scrollLeft) : 0,
|
||||
};
|
||||
};
|
||||
|
||||
const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
const initialViewerSnapshot = useMemo(() => getViewerFilterSnapshot(tab.id), [tab.id]);
|
||||
const [data, setData] = useState<any[]>([]);
|
||||
const [columnNames, setColumnNames] = useState<string[]>([]);
|
||||
const [pkColumns, setPkColumns] = useState<string[]>([]);
|
||||
@@ -204,10 +222,15 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
const latestDbNameRef = useRef<string>('');
|
||||
const latestCountSqlRef = useRef<string>('');
|
||||
const latestCountKeyRef = useRef<string>('');
|
||||
const scrollSnapshotRef = useRef<ViewerScrollSnapshot>({
|
||||
top: initialViewerSnapshot.scrollTop,
|
||||
left: initialViewerSnapshot.scrollLeft,
|
||||
});
|
||||
const initialLoadRef = useRef(false);
|
||||
|
||||
const [pagination, setPagination] = useState<ViewerPaginationState>({
|
||||
current: 1,
|
||||
pageSize: 100,
|
||||
current: initialViewerSnapshot.currentPage,
|
||||
pageSize: initialViewerSnapshot.pageSize,
|
||||
total: 0,
|
||||
totalKnown: false,
|
||||
totalApprox: false,
|
||||
@@ -215,30 +238,51 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
totalCountCancelled: false,
|
||||
});
|
||||
|
||||
const [sortInfo, setSortInfo] = useState<{ columnKey: string, order: string } | null>(null);
|
||||
const [sortInfo, setSortInfo] = useState<{ columnKey: string, order: string } | null>(initialViewerSnapshot.sortInfo);
|
||||
|
||||
const [showFilter, setShowFilter] = useState<boolean>(() => getViewerFilterSnapshot(tab.id).showFilter);
|
||||
const [filterConditions, setFilterConditions] = useState<FilterCondition[]>(() => getViewerFilterSnapshot(tab.id).conditions);
|
||||
const [showFilter, setShowFilter] = useState<boolean>(initialViewerSnapshot.showFilter);
|
||||
const [filterConditions, setFilterConditions] = useState<FilterCondition[]>(initialViewerSnapshot.conditions);
|
||||
const duckdbSafeSelectCacheRef = useRef<Record<string, string>>({});
|
||||
const currentConnConfig = connections.find(c => c.id === tab.connectionId)?.config;
|
||||
const currentConnCaps = getDataSourceCapabilities(currentConnConfig);
|
||||
const currentConnType = currentConnCaps.type;
|
||||
const forceReadOnly = currentConnCaps.forceReadOnlyQueryResult;
|
||||
const persistViewerSnapshot = useCallback((tabId: string, overrides?: Partial<ViewerFilterSnapshot>) => {
|
||||
const normalizedTabId = String(tabId || '').trim();
|
||||
if (!normalizedTabId) return;
|
||||
viewerFilterSnapshotsByTab.set(normalizedTabId, {
|
||||
showFilter,
|
||||
conditions: normalizeViewerFilterConditions(filterConditions),
|
||||
currentPage: pagination.current,
|
||||
pageSize: pagination.pageSize,
|
||||
sortInfo,
|
||||
scrollTop: scrollSnapshotRef.current.top,
|
||||
scrollLeft: scrollSnapshotRef.current.left,
|
||||
...overrides,
|
||||
});
|
||||
}, [showFilter, filterConditions, pagination.current, pagination.pageSize, sortInfo]);
|
||||
|
||||
useEffect(() => {
|
||||
const snapshot = getViewerFilterSnapshot(tab.id);
|
||||
setShowFilter(snapshot.showFilter);
|
||||
setFilterConditions(snapshot.conditions);
|
||||
setSortInfo(snapshot.sortInfo);
|
||||
scrollSnapshotRef.current = { top: snapshot.scrollTop, left: snapshot.scrollLeft };
|
||||
initialLoadRef.current = false;
|
||||
}, [tab.id]);
|
||||
|
||||
useEffect(() => {
|
||||
viewerFilterSnapshotsByTab.set(tab.id, {
|
||||
showFilter,
|
||||
conditions: normalizeViewerFilterConditions(filterConditions),
|
||||
});
|
||||
}, [tab.id, showFilter, filterConditions]);
|
||||
persistViewerSnapshot(tab.id);
|
||||
}, [tab.id, persistViewerSnapshot]);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
persistViewerSnapshot(tab.id);
|
||||
};
|
||||
}, [tab.id, persistViewerSnapshot]);
|
||||
|
||||
useEffect(() => {
|
||||
const snapshot = getViewerFilterSnapshot(tab.id);
|
||||
setPkColumns([]);
|
||||
pkKeyRef.current = '';
|
||||
countKeyRef.current = '';
|
||||
@@ -250,16 +294,27 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
latestDbNameRef.current = '';
|
||||
latestCountSqlRef.current = '';
|
||||
latestCountKeyRef.current = '';
|
||||
scrollSnapshotRef.current = { top: snapshot.scrollTop, left: snapshot.scrollLeft };
|
||||
initialLoadRef.current = false;
|
||||
setPagination(prev => ({
|
||||
...prev,
|
||||
current: 1,
|
||||
current: snapshot.currentPage,
|
||||
pageSize: snapshot.pageSize,
|
||||
total: 0,
|
||||
totalKnown: false,
|
||||
totalApprox: false,
|
||||
totalCountLoading: false,
|
||||
totalCountCancelled: false,
|
||||
}));
|
||||
}, [tab.connectionId, tab.dbName, tab.tableName]);
|
||||
}, [tab.id, tab.connectionId, tab.dbName, tab.tableName]);
|
||||
|
||||
const handleTableScrollSnapshotChange = useCallback((snapshot: ViewerScrollSnapshot) => {
|
||||
scrollSnapshotRef.current = snapshot;
|
||||
persistViewerSnapshot(tab.id, {
|
||||
scrollTop: snapshot.top,
|
||||
scrollLeft: snapshot.left,
|
||||
});
|
||||
}, [tab.id, persistViewerSnapshot]);
|
||||
|
||||
const handleDuckDBManualCount = useCallback(async () => {
|
||||
if (latestDbTypeRef.current !== 'duckdb') {
|
||||
@@ -410,7 +465,7 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
if (pageRowCount > 0) {
|
||||
const tailOffset = Math.max(0, totalRows - (offset + pageRowCount));
|
||||
if (tailOffset < offset) {
|
||||
sql = `${baseSql}${reverseOrderSQL} LIMIT ${pageRowCount} OFFSET ${tailOffset}`;
|
||||
sql = buildPaginatedSelectSQL(dbType, baseSql, reverseOrderSQL, pageRowCount, tailOffset);
|
||||
useClickHouseReversePagination = true;
|
||||
clickHouseReverseLimit = pageRowCount;
|
||||
clickHouseReverseHasMore = currentPage < totalPages;
|
||||
@@ -419,7 +474,7 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
}
|
||||
if (!useClickHouseReversePagination) {
|
||||
// 大表性能:打开表不阻塞在 COUNT(*),先通过多取 1 条判断是否还有下一页;总数在后台统计并异步回填。
|
||||
sql += ` LIMIT ${size + 1} OFFSET ${offset}`;
|
||||
sql = buildPaginatedSelectSQL(dbType, baseSql, orderBySQL, size + 1, offset);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -489,8 +544,7 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
|
||||
if (safeSelect) {
|
||||
let fallbackSql = `SELECT ${safeSelect} FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
|
||||
fallbackSql += buildOrderBySQL(dbType, sortInfo, pkColumns);
|
||||
fallbackSql += ` LIMIT ${size + 1} OFFSET ${offset}`;
|
||||
fallbackSql = buildPaginatedSelectSQL(dbType, fallbackSql, buildOrderBySQL(dbType, sortInfo, pkColumns), size + 1, offset);
|
||||
executedSql = fallbackSql;
|
||||
resData = await executeDataQuery(fallbackSql, '复杂类型降级重试');
|
||||
}
|
||||
@@ -765,8 +819,13 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
}, [tab.tableName, currentConnConfig?.type, filterConditions, sortInfo, pkColumns]);
|
||||
|
||||
useEffect(() => {
|
||||
fetchData(1, pagination.pageSize);
|
||||
}, [tab, sortInfo, filterConditions]); // Initial load and re-load on sort/filter
|
||||
if (!initialLoadRef.current) {
|
||||
initialLoadRef.current = true;
|
||||
fetchData(pagination.current, pagination.pageSize);
|
||||
return;
|
||||
}
|
||||
fetchData(1, pagination.pageSize);
|
||||
}, [tab.id, tab.connectionId, tab.dbName, tab.tableName, sortInfo, filterConditions]); // Initial load and re-load on sort/filter
|
||||
|
||||
return (
|
||||
<div style={{ flex: '1 1 auto', minHeight: 0, minWidth: 0, height: '100%', width: '100%', overflow: 'hidden', display: 'flex', flexDirection: 'column' }}>
|
||||
@@ -792,6 +851,8 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
readOnly={forceReadOnly}
|
||||
sortInfoExternal={sortInfo}
|
||||
exportSqlWithFilter={exportSqlWithFilter || undefined}
|
||||
scrollSnapshot={scrollSnapshotRef.current}
|
||||
onScrollSnapshotChange={handleTableScrollSnapshotChange}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -3,7 +3,7 @@ import { Alert, Button, Collapse, Input, Modal, Progress, Select, Space, Switch,
|
||||
import { DeleteOutlined, DownloadOutlined, FileSearchOutlined, FolderOpenOutlined, InfoCircleFilled, ReloadOutlined } from '@ant-design/icons';
|
||||
import { EventsOn } from '../../wailsjs/runtime/runtime';
|
||||
import { useStore } from '../store';
|
||||
import { normalizeOpacityForPlatform } from '../utils/appearance';
|
||||
import { normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance';
|
||||
import {
|
||||
CheckDriverNetworkStatus,
|
||||
DownloadDriverPackage,
|
||||
@@ -166,7 +166,8 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG
|
||||
const theme = useStore((state) => state.theme);
|
||||
const appearance = useStore((state) => state.appearance);
|
||||
const darkMode = theme === 'dark';
|
||||
const opacity = normalizeOpacityForPlatform(appearance.opacity);
|
||||
const resolvedAppearance = resolveAppearanceValues(appearance);
|
||||
const opacity = normalizeOpacityForPlatform(resolvedAppearance.opacity);
|
||||
const modalContentRef = useRef<HTMLDivElement | null>(null);
|
||||
const tableContainerRef = useRef<HTMLDivElement | null>(null);
|
||||
const tableScrollTargetsRef = useRef<HTMLElement[]>([]);
|
||||
@@ -846,7 +847,7 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG
|
||||
const installDriverFromLocalFile = useCallback(async (row: DriverStatusRow) => {
|
||||
const fileRes = await SelectDriverPackageFile(downloadDir);
|
||||
if (!fileRes?.success) {
|
||||
if (String(fileRes?.message || '') !== 'Cancelled') {
|
||||
if (String(fileRes?.message || '') !== '已取消') {
|
||||
message.error(fileRes?.message || '选择本地驱动包文件失败');
|
||||
}
|
||||
return;
|
||||
@@ -862,7 +863,7 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG
|
||||
const installDriversFromDirectory = useCallback(async () => {
|
||||
const directoryRes = await SelectDriverPackageDirectory(downloadDir);
|
||||
if (!directoryRes?.success) {
|
||||
if (String(directoryRes?.message || '') !== 'Cancelled') {
|
||||
if (String(directoryRes?.message || '') !== '已取消') {
|
||||
message.error(directoryRes?.message || '选择本地驱动包目录失败');
|
||||
}
|
||||
return;
|
||||
@@ -1223,7 +1224,7 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG
|
||||
paddingRight: 18,
|
||||
},
|
||||
}}
|
||||
destroyOnClose
|
||||
destroyOnHidden
|
||||
footer={(
|
||||
<div className="driver-manager-footer">
|
||||
<div
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import React, { useRef, useEffect } from 'react';
|
||||
import { Table, Tag, Button, Tooltip } from 'antd';
|
||||
import { ClearOutlined, CloseOutlined, CaretRightOutlined, BugOutlined } from '@ant-design/icons';
|
||||
import { Table, Tag, Button, Tooltip, Empty } from 'antd';
|
||||
import { ClearOutlined, CloseOutlined, BugOutlined, ClockCircleOutlined } from '@ant-design/icons';
|
||||
import { useStore } from '../store';
|
||||
import { normalizeOpacityForPlatform } from '../utils/appearance';
|
||||
import { normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance';
|
||||
|
||||
interface LogPanelProps {
|
||||
height: number;
|
||||
@@ -16,7 +16,8 @@ const LogPanel: React.FC<LogPanelProps> = ({ height, onClose, onResizeStart }) =
|
||||
const theme = useStore(state => state.theme);
|
||||
const appearance = useStore(state => state.appearance);
|
||||
const darkMode = theme === 'dark';
|
||||
const opacity = normalizeOpacityForPlatform(appearance.opacity);
|
||||
const resolvedAppearance = resolveAppearanceValues(appearance);
|
||||
const opacity = normalizeOpacityForPlatform(resolvedAppearance.opacity);
|
||||
|
||||
// Background Helper
|
||||
const getBg = (darkHex: string) => {
|
||||
@@ -28,10 +29,25 @@ const LogPanel: React.FC<LogPanelProps> = ({ height, onClose, onResizeStart }) =
|
||||
return `rgba(${r}, ${g}, ${b}, ${opacity})`;
|
||||
};
|
||||
const bgMain = getBg('#1d1d1d');
|
||||
const panelDividerColor = darkMode ? 'rgba(255,255,255,0.08)' : 'rgba(0,0,0,0.08)';
|
||||
const shellOpacity = darkMode ? Math.max(0.18, opacity * 0.82) : Math.max(0.28, opacity * 0.92);
|
||||
const shellOpacityStrong = darkMode ? Math.max(0.22, opacity * 0.9) : Math.max(0.34, opacity * 0.96);
|
||||
const panelDividerColor = darkMode
|
||||
? `rgba(255,255,255,${Math.max(0.04, opacity * 0.10)})`
|
||||
: `rgba(0,0,0,${Math.max(0.04, opacity * 0.08)})`;
|
||||
const panelMutedTextColor = darkMode ? 'rgba(255,255,255,0.62)' : 'rgba(0,0,0,0.58)';
|
||||
const logScrollbarThumb = darkMode ? 'rgba(255, 255, 255, 0.34)' : 'rgba(0, 0, 0, 0.26)';
|
||||
const logScrollbarThumbHover = darkMode ? 'rgba(255, 255, 255, 0.5)' : 'rgba(0, 0, 0, 0.36)';
|
||||
const panelShellBg = darkMode
|
||||
? `linear-gradient(180deg, rgba(15,20,30,${shellOpacity}) 0%, rgba(9,13,22,${shellOpacityStrong}) 100%)`
|
||||
: `linear-gradient(180deg, rgba(255,255,255,${shellOpacityStrong}) 0%, rgba(246,248,252,${shellOpacity}) 100%)`;
|
||||
const panelAccentColor = darkMode ? '#ffd666' : '#1677ff';
|
||||
const panelShadow = darkMode
|
||||
? `0 12px 28px rgba(0,0,0,${Math.max(0.05, opacity * 0.18)})`
|
||||
: `0 12px 24px rgba(15,23,42,${Math.max(0.02, opacity * 0.08)})`;
|
||||
const logScrollbarThumb = darkMode
|
||||
? `rgba(255, 255, 255, ${Math.max(0.18, opacity * 0.34)})`
|
||||
: `rgba(0, 0, 0, ${Math.max(0.12, opacity * 0.26)})`;
|
||||
const logScrollbarThumbHover = darkMode
|
||||
? `rgba(255, 255, 255, ${Math.max(0.28, opacity * 0.48)})`
|
||||
: `rgba(0, 0, 0, ${Math.max(0.18, opacity * 0.36)})`;
|
||||
|
||||
const columns = [
|
||||
{
|
||||
@@ -45,7 +61,7 @@ const LogPanel: React.FC<LogPanelProps> = ({ height, onClose, onResizeStart }) =
|
||||
dataIndex: 'status',
|
||||
width: 70,
|
||||
render: (status: string) => (
|
||||
<Tag color={status === 'success' ? 'success' : 'error'} style={{ marginRight: 0 }}>
|
||||
<Tag color={status === 'success' ? 'success' : 'error'} style={{ marginRight: 0, borderRadius: 999, paddingInline: 8, fontSize: 11, fontWeight: 700 }}>
|
||||
{status === 'success' ? 'OK' : 'ERR'}
|
||||
</Tag>
|
||||
)
|
||||
@@ -60,7 +76,7 @@ const LogPanel: React.FC<LogPanelProps> = ({ height, onClose, onResizeStart }) =
|
||||
title: 'SQL / Message',
|
||||
dataIndex: 'sql',
|
||||
render: (text: string, record: any) => (
|
||||
<div style={{ fontFamily: 'monospace', wordBreak: 'break-all', fontSize: '12px', lineHeight: '1.2' }}>
|
||||
<div style={{ fontFamily: 'monospace', wordBreak: 'break-all', fontSize: '12px', lineHeight: '1.45' }}>
|
||||
<div style={{ color: darkMode ? '#a6e22e' : '#005cc5' }}>{text}</div>
|
||||
{record.message && <div style={{ color: '#ff4d4f', marginTop: 2 }}>{record.message}</div>}
|
||||
{record.affectedRows !== undefined && <div style={{ color: panelMutedTextColor, marginTop: 1 }}>Affected: {record.affectedRows}</div>}
|
||||
@@ -72,12 +88,18 @@ const LogPanel: React.FC<LogPanelProps> = ({ height, onClose, onResizeStart }) =
|
||||
return (
|
||||
<div style={{
|
||||
height,
|
||||
borderTop: `1px solid ${panelDividerColor}`,
|
||||
background: bgMain,
|
||||
margin: 0,
|
||||
border: `1px solid ${panelDividerColor}`,
|
||||
borderRadius: 14,
|
||||
background: panelShellBg,
|
||||
WebkitBackdropFilter: opacity < 0.999 ? 'blur(14px)' : 'none',
|
||||
boxShadow: panelShadow,
|
||||
backdropFilter: darkMode && opacity < 0.999 ? 'blur(18px)' : 'none',
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
position: 'relative',
|
||||
zIndex: 100 // Ensure above other content
|
||||
overflow: 'hidden',
|
||||
zIndex: 100
|
||||
}}>
|
||||
{/* Resize Handle */}
|
||||
<div
|
||||
@@ -95,38 +117,53 @@ const LogPanel: React.FC<LogPanelProps> = ({ height, onClose, onResizeStart }) =
|
||||
|
||||
{/* Toolbar */}
|
||||
<div style={{
|
||||
padding: '4px 8px',
|
||||
padding: '10px 14px',
|
||||
borderBottom: `1px solid ${panelDividerColor}`,
|
||||
display: 'flex',
|
||||
justifyContent: 'space-between',
|
||||
alignItems: 'center',
|
||||
height: 32
|
||||
gap: 12,
|
||||
minHeight: 48
|
||||
}}>
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 8, fontWeight: 'bold', fontSize: '12px' }}>
|
||||
<BugOutlined /> SQL 执行日志
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 10, minWidth: 0 }}>
|
||||
<div style={{ width: 30, height: 30, borderRadius: 10, display: 'grid', placeItems: 'center', background: darkMode ? `rgba(255,214,102,${Math.max(0.10, Math.min(0.18, opacity * 0.18))})` : `rgba(24,144,255,${Math.max(0.08, Math.min(0.16, opacity * 0.16))})`, color: panelAccentColor, flexShrink: 0 }}>
|
||||
<BugOutlined />
|
||||
</div>
|
||||
<div style={{ minWidth: 0 }}>
|
||||
<div style={{ fontWeight: 700, fontSize: 13, color: darkMode ? '#f5f7ff' : '#162033' }}>SQL 执行日志</div>
|
||||
<div style={{ fontSize: 12, color: panelMutedTextColor }}>记录执行状态、耗时与错误信息,便于快速回溯。</div>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 6 }}>
|
||||
<Tooltip title="清空日志">
|
||||
<Button type="text" size="small" icon={<ClearOutlined />} onClick={clearSqlLogs} />
|
||||
<Button type="text" size="small" icon={<ClearOutlined />} onClick={clearSqlLogs} style={{ color: panelMutedTextColor }} />
|
||||
</Tooltip>
|
||||
<Tooltip title="关闭面板">
|
||||
<Button type="text" size="small" icon={<CloseOutlined />} onClick={onClose} />
|
||||
<Button type="text" size="small" icon={<CloseOutlined />} onClick={onClose} style={{ color: panelMutedTextColor }} />
|
||||
</Tooltip>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* List */}
|
||||
<div className="log-panel-scroll" style={{ flex: 1, overflow: 'auto' }}>
|
||||
<Table
|
||||
className="log-panel-table"
|
||||
dataSource={sqlLogs}
|
||||
columns={columns}
|
||||
size="small"
|
||||
pagination={false}
|
||||
rowKey="id"
|
||||
showHeader={false}
|
||||
// scroll={{ y: height - 32 }} // Let flex handle it
|
||||
/>
|
||||
<div className="log-panel-scroll" style={{ flex: 1, overflow: 'auto', padding: '8px 10px 10px' }}>
|
||||
{sqlLogs.length === 0 ? (
|
||||
<div style={{ height: '100%', minHeight: 160, display: 'grid', placeItems: 'center' }}>
|
||||
<Empty
|
||||
image={Empty.PRESENTED_IMAGE_SIMPLE}
|
||||
description={<span style={{ color: panelMutedTextColor }}>暂无 SQL 执行日志</span>}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<Table
|
||||
className="log-panel-table"
|
||||
dataSource={sqlLogs}
|
||||
columns={columns}
|
||||
size="small"
|
||||
pagination={false}
|
||||
rowKey="id"
|
||||
showHeader={false}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<style>{`
|
||||
.log-panel-scroll {
|
||||
@@ -156,6 +193,16 @@ const LogPanel: React.FC<LogPanelProps> = ({ height, onClose, onResizeStart }) =
|
||||
.log-panel-table .ant-table-tbody > tr > td {
|
||||
background: transparent !important;
|
||||
}
|
||||
.log-panel-table .ant-table-tbody > tr > td {
|
||||
padding: 8px 10px !important;
|
||||
border-bottom: 1px solid ${panelDividerColor} !important;
|
||||
}
|
||||
.log-panel-table .ant-table-tbody > tr:last-child > td {
|
||||
border-bottom: none !important;
|
||||
}
|
||||
.log-panel-table .ant-table-row:hover > td {
|
||||
background: ${darkMode ? 'rgba(255,255,255,0.03)' : 'rgba(16,24,40,0.03)'} !important;
|
||||
}
|
||||
`}</style>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -6,12 +6,20 @@ import { format } from 'sql-formatter';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { TabData, ColumnDefinition } from '../types';
|
||||
import { useStore } from '../store';
|
||||
import { DBQuery, DBQueryWithCancel, DBGetTables, DBGetAllColumns, DBGetDatabases, DBGetColumns, CancelQuery, GenerateQueryID } from '../../wailsjs/go/app/App';
|
||||
import { DBQueryWithCancel, DBQueryMulti, DBGetTables, DBGetAllColumns, DBGetDatabases, DBGetColumns, CancelQuery, GenerateQueryID } from '../../wailsjs/go/app/App';
|
||||
import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
|
||||
import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities';
|
||||
import { convertMongoShellToJsonCommand } from '../utils/mongodb';
|
||||
import { getShortcutDisplay, isEditableElement, isShortcutMatch } from '../utils/shortcuts';
|
||||
|
||||
const SQL_KEYWORDS = [
|
||||
'SELECT', 'FROM', 'WHERE', 'LIMIT', 'INSERT', 'UPDATE', 'DELETE', 'JOIN', 'LEFT', 'RIGHT',
|
||||
'INNER', 'OUTER', 'ON', 'GROUP BY', 'ORDER BY', 'AS', 'AND', 'OR', 'NOT', 'NULL', 'IS',
|
||||
'IN', 'VALUES', 'SET', 'CREATE', 'TABLE', 'DROP', 'ALTER', 'ADD', 'MODIFY', 'CHANGE',
|
||||
'COLUMN', 'KEY', 'PRIMARY', 'FOREIGN', 'REFERENCES', 'CONSTRAINT', 'DEFAULT', 'AUTO_INCREMENT',
|
||||
'COMMENT', 'SHOW', 'DESCRIBE', 'EXPLAIN',
|
||||
];
|
||||
|
||||
const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
const [query, setQuery] = useState(tab.query || 'SELECT * FROM ');
|
||||
|
||||
@@ -33,7 +41,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
const [activeResultKey, setActiveResultKey] = useState<string>('');
|
||||
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [currentQueryId, setCurrentQueryId] = useState<string>('');
|
||||
const [, setCurrentQueryId] = useState<string>('');
|
||||
const runSeqRef = useRef(0);
|
||||
const currentQueryIdRef = useRef('');
|
||||
const [isSaveModalOpen, setIsSaveModalOpen] = useState(false);
|
||||
@@ -48,7 +56,10 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
const [editorHeight, setEditorHeight] = useState(300);
|
||||
const editorRef = useRef<any>(null);
|
||||
const monacoRef = useRef<any>(null);
|
||||
const lastExternalQueryRef = useRef<string>(tab.query || '');
|
||||
const dragRef = useRef<{ startY: number, startHeight: number } | null>(null);
|
||||
const queryEditorRootRef = useRef<HTMLDivElement | null>(null);
|
||||
const editorPaneRef = useRef<HTMLDivElement | null>(null);
|
||||
const tablesRef = useRef<{dbName: string, tableName: string}[]>([]); // Store tables for autocomplete (cross-db)
|
||||
const allColumnsRef = useRef<{dbName: string, tableName: string, name: string, type: string}[]>([]); // Store all columns (cross-db)
|
||||
const visibleDbsRef = useRef<string[]>([]); // Store visible databases for cross-db intellisense
|
||||
@@ -59,6 +70,8 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
[connections]
|
||||
);
|
||||
const addSqlLog = useStore(state => state.addSqlLog);
|
||||
const addTab = useStore(state => state.addTab);
|
||||
const savedQueries = useStore(state => state.savedQueries);
|
||||
const currentConnectionIdRef = useRef(currentConnectionId);
|
||||
const currentDbRef = useRef(currentDb);
|
||||
const connectionsRef = useRef(connections);
|
||||
@@ -73,6 +86,18 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
const shortcutOptions = useStore(state => state.shortcutOptions);
|
||||
const activeTabId = useStore(state => state.activeTabId);
|
||||
|
||||
const currentSavedQuery = useMemo(() => {
|
||||
const savedId = String(tab.savedQueryId || '').trim();
|
||||
if (savedId) {
|
||||
return savedQueries.find((item) => item.id === savedId) || null;
|
||||
}
|
||||
const tabId = String(tab.id || '').trim();
|
||||
if (!tabId) {
|
||||
return null;
|
||||
}
|
||||
return savedQueries.find((item) => item.id === tabId) || null;
|
||||
}, [savedQueries, tab.id, tab.savedQueryId]);
|
||||
|
||||
useEffect(() => {
|
||||
currentConnectionIdRef.current = currentConnectionId;
|
||||
}, [currentConnectionId]);
|
||||
@@ -95,10 +120,30 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
connectionsRef.current = connections;
|
||||
}, [connections]);
|
||||
|
||||
const getCurrentQuery = () => {
|
||||
const val = editorRef.current?.getValue?.();
|
||||
if (typeof val === 'string') return val;
|
||||
return query || '';
|
||||
};
|
||||
|
||||
const syncQueryToEditor = (sql: string) => {
|
||||
const next = sql || '';
|
||||
setQuery(next);
|
||||
const editor = editorRef.current;
|
||||
if (editor && editor.getValue?.() !== next) {
|
||||
editor.setValue(next);
|
||||
}
|
||||
};
|
||||
|
||||
// If opening a saved query, load its SQL
|
||||
useEffect(() => {
|
||||
if (tab.query) setQuery(tab.query);
|
||||
}, [tab.query]);
|
||||
const incoming = tab.query || '';
|
||||
if (incoming === lastExternalQueryRef.current) {
|
||||
return;
|
||||
}
|
||||
lastExternalQueryRef.current = incoming;
|
||||
syncQueryToEditor(incoming || 'SELECT * FROM ');
|
||||
}, [tab.id, tab.query]);
|
||||
|
||||
// Fetch Database List
|
||||
useEffect(() => {
|
||||
@@ -138,7 +183,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
setDbList([]);
|
||||
}
|
||||
};
|
||||
fetchDbs();
|
||||
void fetchDbs();
|
||||
}, [currentConnectionId, connections]);
|
||||
|
||||
// Fetch Metadata for Autocomplete (Cross-database)
|
||||
@@ -190,7 +235,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
tablesRef.current = allTables;
|
||||
allColumnsRef.current = allColumns;
|
||||
};
|
||||
fetchMetadata();
|
||||
void fetchMetadata();
|
||||
}, [currentConnectionId, connections, dbList]); // dbList 变化时触发重新加载
|
||||
|
||||
// Query ID management helpers
|
||||
@@ -325,7 +370,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
const linePrefix = model.getLineContent(position.lineNumber).slice(0, position.column - 1);
|
||||
|
||||
// 0) 三段式 db.table.column 格式:当输入 db.table. 时提示列
|
||||
const threePartMatch = linePrefix.match(/([`"]?[\w]+[`"]?)\.([`"]?[\w]+[`"]?)\.(\w*)$/);
|
||||
const threePartMatch = linePrefix.match(/([`"]?\w+[`"]?)\.([`"]?\w+[`"]?)\.(\w*)$/);
|
||||
if (threePartMatch) {
|
||||
const dbPart = stripQuotes(threePartMatch[1]);
|
||||
const tablePart = stripQuotes(threePartMatch[2]);
|
||||
@@ -353,7 +398,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
}
|
||||
|
||||
// 1) 两段式 qualifier.xxx 格式
|
||||
const qualifierMatch = linePrefix.match(/([`"]?[A-Za-z_][\w]*[`"]?)\.(\w*)$/);
|
||||
const qualifierMatch = linePrefix.match(/([`"]?[A-Za-z_]\w*[`"]?)\.(\w*)$/);
|
||||
if (qualifierMatch) {
|
||||
const qualifier = stripQuotes(qualifierMatch[1]);
|
||||
const prefix = (qualifierMatch[2] || '').toLowerCase();
|
||||
@@ -418,7 +463,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
|
||||
const aliasMap: Record<string, {dbName: string, tableName: string}> = {};
|
||||
// Capture table and optional alias, support db.table format
|
||||
const aliasRegex = /\b(?:FROM|JOIN|UPDATE|INTO|DELETE\s+FROM)\s+([`"]?[\w]+[`"]?(?:\s*\.\s*[`"]?[\w]+[`"]?)?)(?:\s+(?:AS\s+)?([`"]?[\w]+[`"]?))?/gi;
|
||||
const aliasRegex = /\b(?:FROM|JOIN|UPDATE|INTO|DELETE\s+FROM)\s+([`"]?\w+[`"]?(?:\s*\.\s*[`"]?\w+[`"]?)?)(?:\s+(?:AS\s+)?([`"]?\w+[`"]?))?/gi;
|
||||
let m;
|
||||
while ((m = aliasRegex.exec(fullText)) !== null) {
|
||||
const tableIdent = normalizeQualifiedName(m[1] || '');
|
||||
@@ -447,7 +492,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
const tableInfo = aliasMap[qualifier.toLowerCase()];
|
||||
if (tableInfo) {
|
||||
// Prefer preloaded MySQL all-columns cache
|
||||
let cols: { name: string, type?: string, tableName?: string, dbName?: string }[] = [];
|
||||
let cols: { name: string, type?: string, tableName?: string, dbName?: string }[];
|
||||
if (allColumnsRef.current.length > 0) {
|
||||
cols = allColumnsRef.current
|
||||
.filter(c =>
|
||||
@@ -477,7 +522,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
}
|
||||
|
||||
// 2) global/table/column completion
|
||||
const tableRegex = /\b(?:FROM|JOIN|UPDATE|INTO|DELETE\s+FROM)\s+([`"]?[\w]+[`"]?(?:\s*\.\s*[`"]?[\w]+[`"]?)?)/gi;
|
||||
const tableRegex = /\b(?:FROM|JOIN|UPDATE|INTO|DELETE\s+FROM)\s+([`"]?\w+[`"]?(?:\s*\.\s*[`"]?\w+[`"]?)?)/gi;
|
||||
const foundTables = new Set<string>();
|
||||
let match;
|
||||
while ((match = tableRegex.exec(fullText)) !== null) {
|
||||
@@ -488,6 +533,17 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
}
|
||||
|
||||
const currentDatabase = currentDbRef.current || '';
|
||||
const wordPrefix = (word.word || '').toLowerCase();
|
||||
const startsWithPrefix = (candidate: string) => !wordPrefix || candidate.toLowerCase().startsWith(wordPrefix);
|
||||
const expectsTableName = /\b(?:FROM|JOIN|UPDATE|INTO|DELETE\s+FROM|TABLE|DESCRIBE|DESC|EXPLAIN)\s+[`"]?[\w.]*$/i.test(linePrefix.trim());
|
||||
const shouldBoostKeywords = !expectsTableName
|
||||
&& wordPrefix.length > 0
|
||||
&& SQL_KEYWORDS.some((keyword) => keyword.toLowerCase().startsWith(wordPrefix));
|
||||
const sortGroups = shouldBoostKeywords
|
||||
? { keyword: '00', columnCurrent: '10', columnOther: '11', tableCurrent: '20', tableOther: '21', db: '30' }
|
||||
: expectsTableName
|
||||
? { keyword: '20', columnCurrent: '10', columnOther: '11', tableCurrent: '00', tableOther: '01', db: '30' }
|
||||
: { keyword: '30', columnCurrent: '00', columnOther: '01', tableCurrent: '10', tableOther: '11', db: '20' };
|
||||
|
||||
// 相关列提示:匹配 SQL 中引用的表(FROM/JOIN 等)
|
||||
// 权重最高,输入 WHERE 条件时优先显示
|
||||
@@ -495,7 +551,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
.filter(c => {
|
||||
const fullIdent = `${c.dbName}.${c.tableName}`.toLowerCase();
|
||||
const shortIdent = (c.tableName || '').toLowerCase();
|
||||
return foundTables.has(fullIdent) || foundTables.has(shortIdent);
|
||||
return (foundTables.has(fullIdent) || foundTables.has(shortIdent)) && startsWithPrefix(c.name || '');
|
||||
})
|
||||
.map(c => {
|
||||
// 当前库的表字段优先级更高
|
||||
@@ -506,12 +562,18 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
insertText: c.name,
|
||||
detail: `${c.type} (${c.dbName}.${c.tableName})`,
|
||||
range,
|
||||
sortText: isCurrentDb ? '00' + c.name : '01' + c.name // FROM 表字段最优先
|
||||
sortText: isCurrentDb ? sortGroups.columnCurrent + c.name : sortGroups.columnOther + c.name,
|
||||
};
|
||||
});
|
||||
|
||||
// 表提示:当前库显示表名,其他库显示 db.table 格式
|
||||
const tableSuggestions = tablesRef.current.map(t => {
|
||||
const tableSuggestions = tablesRef.current
|
||||
.filter(t => {
|
||||
const isCurrentDb = (t.dbName || '').toLowerCase() === currentDatabase.toLowerCase();
|
||||
const label = isCurrentDb ? t.tableName : `${t.dbName}.${t.tableName}`;
|
||||
return startsWithPrefix(label || '');
|
||||
})
|
||||
.map(t => {
|
||||
const isCurrentDb = (t.dbName || '').toLowerCase() === currentDatabase.toLowerCase();
|
||||
const label = isCurrentDb ? t.tableName : `${t.dbName}.${t.tableName}`;
|
||||
const insertText = isCurrentDb ? t.tableName : `${t.dbName}.${t.tableName}`;
|
||||
@@ -521,27 +583,31 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
insertText,
|
||||
detail: `Table (${t.dbName})`,
|
||||
range,
|
||||
sortText: isCurrentDb ? '10' + t.tableName : '11' + t.tableName // 表次优先
|
||||
sortText: isCurrentDb ? sortGroups.tableCurrent + t.tableName : sortGroups.tableOther + t.tableName,
|
||||
};
|
||||
});
|
||||
|
||||
// 数据库提示
|
||||
const dbSuggestions = visibleDbsRef.current.map(db => ({
|
||||
label: db,
|
||||
kind: monaco.languages.CompletionItemKind.Module,
|
||||
insertText: db,
|
||||
detail: 'Database',
|
||||
range,
|
||||
sortText: '20' + db // 数据库最后
|
||||
}));
|
||||
const dbSuggestions = visibleDbsRef.current
|
||||
.filter((db) => startsWithPrefix(db))
|
||||
.map(db => ({
|
||||
label: db,
|
||||
kind: monaco.languages.CompletionItemKind.Module,
|
||||
insertText: db,
|
||||
detail: 'Database',
|
||||
range,
|
||||
sortText: sortGroups.db + db,
|
||||
}));
|
||||
|
||||
// 关键字提示
|
||||
const keywordSuggestions = ['SELECT', 'FROM', 'WHERE', 'LIMIT', 'INSERT', 'UPDATE', 'DELETE', 'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER', 'ON', 'GROUP BY', 'ORDER BY', 'AS', 'AND', 'OR', 'NOT', 'NULL', 'IS', 'IN', 'VALUES', 'SET', 'CREATE', 'TABLE', 'DROP', 'ALTER', 'Add', 'MODIFY', 'CHANGE', 'COLUMN', 'KEY', 'PRIMARY', 'FOREIGN', 'REFERENCES', 'CONSTRAINT', 'DEFAULT', 'AUTO_INCREMENT', 'COMMENT', 'SHOW', 'DESCRIBE', 'EXPLAIN'].map(k => ({
|
||||
const keywordSuggestions = SQL_KEYWORDS
|
||||
.filter((k) => startsWithPrefix(k))
|
||||
.map(k => ({
|
||||
label: k,
|
||||
kind: monaco.languages.CompletionItemKind.Keyword,
|
||||
insertText: k,
|
||||
range,
|
||||
sortText: '30' + k // 关键字权重最低
|
||||
sortText: sortGroups.keyword + k,
|
||||
}));
|
||||
|
||||
const suggestions = [
|
||||
@@ -557,10 +623,10 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
|
||||
const handleFormat = () => {
|
||||
try {
|
||||
const formatted = format(query, { language: 'mysql', keywordCase: sqlFormatOptions.keywordCase });
|
||||
setQuery(formatted);
|
||||
const formatted = format(getCurrentQuery(), { language: 'mysql', keywordCase: sqlFormatOptions.keywordCase });
|
||||
syncQueryToEditor(formatted);
|
||||
} catch (e) {
|
||||
message.error("格式化失败: SQL 语法可能有误");
|
||||
void message.error("格式化失败: SQL 语法可能有误");
|
||||
}
|
||||
};
|
||||
|
||||
@@ -710,6 +776,9 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
return statements;
|
||||
};
|
||||
|
||||
// DEBT: 改用 DBQueryMulti 后前端不再逐条处理语句,此函数暂时未使用。
|
||||
// 当恢复前端自动行数限制功能时需要启用。
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
const getLeadingKeyword = (sql: string): string => {
|
||||
const text = (sql || '').replace(/\r\n/g, '\n');
|
||||
const isWS = (ch: string) => ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r';
|
||||
@@ -1002,6 +1071,9 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
return -1;
|
||||
};
|
||||
|
||||
// DEBT: 改用 DBQueryMulti 后前端不再逐条处理语句,此函数暂时未使用。
|
||||
// 当恢复前端自动行数限制功能时需要启用。
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
const applyAutoLimit = (sql: string, dbType: string, maxRows: number): { sql: string; applied: boolean; maxRows: number } => {
|
||||
const normalizedType = (dbType || 'mysql').toLowerCase();
|
||||
const supportsLimit = normalizedType === 'mysql' || normalizedType === 'mariadb' || normalizedType === 'diros' || normalizedType === 'sphinx' || normalizedType === 'postgres' || normalizedType === 'kingbase' || normalizedType === 'sqlite' || normalizedType === 'duckdb' || normalizedType === 'tdengine' || normalizedType === 'clickhouse' || normalizedType === '';
|
||||
@@ -1045,7 +1117,8 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
};
|
||||
|
||||
const handleRun = async () => {
|
||||
if (!query.trim()) return;
|
||||
const currentQuery = getCurrentQuery();
|
||||
if (!currentQuery.trim()) return;
|
||||
if (!currentDb) {
|
||||
message.error("请先选择数据库");
|
||||
return;
|
||||
@@ -1086,40 +1159,35 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
};
|
||||
|
||||
try {
|
||||
const rawSQL = getSelectedSQL() || query;
|
||||
const rawSQL = getSelectedSQL() || currentQuery;
|
||||
const dbType = String((config as any).type || 'mysql');
|
||||
const normalizedDbType = dbType.trim().toLowerCase();
|
||||
const normalizedRawSQL = String(rawSQL || '').replace(/;/g, ';');
|
||||
const splitInput = normalizedDbType === 'mongodb'
|
||||
? normalizedRawSQL
|
||||
|
||||
// MongoDB 仍走逐条执行的旧路径
|
||||
const isMongoDB = normalizedDbType === 'mongodb';
|
||||
|
||||
if (isMongoDB) {
|
||||
// MongoDB: 保持逐条执行
|
||||
const splitInput = normalizedRawSQL
|
||||
.replace(/^\s*\/\/.*$/gm, '')
|
||||
.replace(/^\s*#.*$/gm, '')
|
||||
: normalizedRawSQL;
|
||||
const statements = splitSQLStatements(splitInput);
|
||||
if (statements.length === 0) {
|
||||
message.info('没有可执行的 SQL。');
|
||||
setResultSets([]);
|
||||
setActiveResultKey('');
|
||||
return;
|
||||
}
|
||||
.replace(/^\s*#.*$/gm, '');
|
||||
const statements = splitSQLStatements(splitInput);
|
||||
if (statements.length === 0) {
|
||||
message.info('没有可执行的 SQL。');
|
||||
setResultSets([]);
|
||||
setActiveResultKey('');
|
||||
return;
|
||||
}
|
||||
|
||||
const nextResultSets: ResultSet[] = [];
|
||||
const maxRows = Number(queryOptions?.maxRows) || 0;
|
||||
const forceReadOnlyResult = connCaps.forceReadOnlyQueryResult;
|
||||
const wantsLimitProbe = Number.isFinite(maxRows) && maxRows > 0;
|
||||
const probeLimit = wantsLimitProbe ? (maxRows + 1) : 0;
|
||||
let anyTruncated = false;
|
||||
const pendingPk: Array<{ resultKey: string; tableName: string }> = [];
|
||||
const nextResultSets: ResultSet[] = [];
|
||||
const maxRows = Number(queryOptions?.maxRows) || 0;
|
||||
const wantsLimitProbe = Number.isFinite(maxRows) && maxRows > 0;
|
||||
let anyTruncated = false;
|
||||
|
||||
for (let idx = 0; idx < statements.length; idx++) {
|
||||
const rawStatement = statements[idx];
|
||||
const leadingKeyword = getLeadingKeyword(rawStatement);
|
||||
const shouldAutoLimit = leadingKeyword === 'select' || leadingKeyword === 'with';
|
||||
|
||||
const limitApplied = shouldAutoLimit && wantsLimitProbe;
|
||||
const limited = limitApplied ? applyAutoLimit(rawStatement, dbType, probeLimit) : { sql: rawStatement, applied: false, maxRows: probeLimit };
|
||||
let executedSql = limited.sql;
|
||||
if (String(dbType || '').trim().toLowerCase() === 'mongodb') {
|
||||
for (let idx = 0; idx < statements.length; idx++) {
|
||||
const rawStatement = statements[idx];
|
||||
let executedSql = rawStatement;
|
||||
const shellConvert = convertMongoShellToJsonCommand(executedSql);
|
||||
if (shellConvert.recognized) {
|
||||
if (shellConvert.error) {
|
||||
@@ -1133,10 +1201,97 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
executedSql = shellConvert.command;
|
||||
}
|
||||
}
|
||||
}
|
||||
const startTime = Date.now();
|
||||
const startTime = Date.now();
|
||||
let queryId: string;
|
||||
try {
|
||||
queryId = await GenerateQueryID();
|
||||
} catch (error) {
|
||||
console.warn('GenerateQueryID failed, using local UUID fallback:', error);
|
||||
queryId = 'query-' + uuidv4();
|
||||
}
|
||||
setQueryId(queryId);
|
||||
|
||||
// Generate query ID for cancellation using backend UUID with fallback
|
||||
const res = await DBQueryWithCancel(config as any, currentDb, executedSql, queryId);
|
||||
const duration = Date.now() - startTime;
|
||||
addSqlLog({
|
||||
id: `log-${Date.now()}-query-${idx + 1}`,
|
||||
timestamp: Date.now(),
|
||||
sql: executedSql,
|
||||
status: res.success ? 'success' : 'error',
|
||||
duration,
|
||||
message: res.success ? '' : res.message,
|
||||
affectedRows: (res.success && !Array.isArray(res.data)) ? (res.data as any).affectedRows : (Array.isArray(res.data) ? res.data.length : undefined),
|
||||
dbName: currentDb
|
||||
});
|
||||
if (!res.success) {
|
||||
const prefix = statements.length > 1 ? `第 ${idx + 1} 条语句执行失败:` : '';
|
||||
message.error(prefix + res.message);
|
||||
setResultSets([]);
|
||||
setActiveResultKey('');
|
||||
return;
|
||||
}
|
||||
if (Array.isArray(res.data)) {
|
||||
let rows = (res.data as any[]) || [];
|
||||
let truncated = false;
|
||||
if (wantsLimitProbe && Number.isFinite(maxRows) && maxRows > 0 && rows.length > maxRows) {
|
||||
truncated = true;
|
||||
anyTruncated = true;
|
||||
rows = rows.slice(0, maxRows);
|
||||
}
|
||||
const cols = (res.fields && res.fields.length > 0)
|
||||
? (res.fields as string[])
|
||||
: (rows.length > 0 ? Object.keys(rows[0]) : []);
|
||||
rows.forEach((row: any, i: number) => {
|
||||
if (row && typeof row === 'object') row[GONAVI_ROW_KEY] = i;
|
||||
});
|
||||
nextResultSets.push({
|
||||
key: `result-${idx + 1}`,
|
||||
sql: rawStatement,
|
||||
exportSql: rawStatement,
|
||||
rows,
|
||||
columns: cols,
|
||||
pkColumns: [],
|
||||
readOnly: true,
|
||||
truncated
|
||||
});
|
||||
} else {
|
||||
const affected = Number((res.data as any)?.affectedRows);
|
||||
if (Number.isFinite(affected)) {
|
||||
const row = { affectedRows: affected };
|
||||
(row as any)[GONAVI_ROW_KEY] = 0;
|
||||
nextResultSets.push({
|
||||
key: `result-${idx + 1}`,
|
||||
sql: rawStatement,
|
||||
exportSql: rawStatement,
|
||||
rows: [row],
|
||||
columns: ['affectedRows'],
|
||||
pkColumns: [],
|
||||
readOnly: true
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
setResultSets(nextResultSets);
|
||||
setActiveResultKey(nextResultSets[0]?.key || '');
|
||||
if (statements.length > 1) {
|
||||
message.success(`已执行 ${statements.length} 条语句,生成 ${nextResultSets.length} 个结果集。`);
|
||||
} else if (nextResultSets.length === 0) {
|
||||
message.success('执行成功。');
|
||||
}
|
||||
if (anyTruncated && maxRows > 0) {
|
||||
message.warning(`结果集已自动限制为最多 ${maxRows} 行(可在工具栏调整)。`);
|
||||
}
|
||||
} else {
|
||||
// 非 MongoDB:使用 DBQueryMulti 一次性执行多条 SQL,后端返回多结果集
|
||||
const fullSQL = normalizedRawSQL;
|
||||
if (!fullSQL.trim()) {
|
||||
message.info('没有可执行的 SQL。');
|
||||
setResultSets([]);
|
||||
setActiveResultKey('');
|
||||
return;
|
||||
}
|
||||
|
||||
const startTime = Date.now();
|
||||
let queryId: string;
|
||||
try {
|
||||
queryId = await GenerateQueryID();
|
||||
@@ -1146,22 +1301,20 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
}
|
||||
setQueryId(queryId);
|
||||
|
||||
const res = await DBQueryWithCancel(config as any, currentDb, executedSql, queryId);
|
||||
const res = await DBQueryMulti(config as any, currentDb, fullSQL, queryId);
|
||||
const duration = Date.now() - startTime;
|
||||
|
||||
addSqlLog({
|
||||
id: `log-${Date.now()}-query-${idx + 1}`,
|
||||
id: `log-${Date.now()}-query-multi`,
|
||||
timestamp: Date.now(),
|
||||
sql: executedSql,
|
||||
sql: fullSQL,
|
||||
status: res.success ? 'success' : 'error',
|
||||
duration,
|
||||
message: res.success ? '' : res.message,
|
||||
affectedRows: (res.success && !Array.isArray(res.data)) ? (res.data as any).affectedRows : (Array.isArray(res.data) ? res.data.length : undefined),
|
||||
dbName: currentDb
|
||||
});
|
||||
|
||||
if (!res.success) {
|
||||
// 检查是否为查询取消错误
|
||||
const errorMsg = res.message.toLowerCase();
|
||||
const isCancelledError = errorMsg.includes('context canceled') ||
|
||||
errorMsg.includes('查询已取消') ||
|
||||
@@ -1169,72 +1322,49 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
errorMsg.includes('cancelled') ||
|
||||
errorMsg.includes('statement canceled') ||
|
||||
errorMsg.includes('sql: statement canceled');
|
||||
|
||||
// 确保不是超时错误
|
||||
const isTimeoutError = errorMsg.includes('context deadline exceeded') ||
|
||||
errorMsg.includes('timeout') ||
|
||||
errorMsg.includes('超时') ||
|
||||
errorMsg.includes('deadline exceeded');
|
||||
|
||||
if (isCancelledError && !isTimeoutError) {
|
||||
// 查询已被用户取消,不显示错误消息,清理状态
|
||||
setResultSets([]);
|
||||
setActiveResultKey('');
|
||||
// 清除查询ID,与handleCancel保持一致
|
||||
if (currentQueryIdRef.current) {
|
||||
clearQueryId();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const prefix = statements.length > 1 ? `第 ${idx + 1} 条语句执行失败:` : '';
|
||||
message.error(prefix + res.message);
|
||||
message.error(res.message);
|
||||
setResultSets([]);
|
||||
setActiveResultKey('');
|
||||
return;
|
||||
}
|
||||
|
||||
if (Array.isArray(res.data)) {
|
||||
let rows = (res.data as any[]) || [];
|
||||
let truncated = false;
|
||||
if (limited.applied && Number.isFinite(maxRows) && maxRows > 0 && rows.length > maxRows) {
|
||||
truncated = true;
|
||||
anyTruncated = true;
|
||||
rows = rows.slice(0, maxRows);
|
||||
}
|
||||
const cols = (res.fields && res.fields.length > 0)
|
||||
? (res.fields as string[])
|
||||
: (rows.length > 0 ? Object.keys(rows[0]) : []);
|
||||
// res.data 是 ResultSetData[] 数组
|
||||
const resultSetDataArray = Array.isArray(res.data) ? (res.data as any[]) : [];
|
||||
const nextResultSets: ResultSet[] = [];
|
||||
const maxRows = Number(queryOptions?.maxRows) || 0;
|
||||
const forceReadOnlyResult = connCaps.forceReadOnlyQueryResult;
|
||||
let anyTruncated = false;
|
||||
const pendingPk: Array<{ resultKey: string; tableName: string }> = [];
|
||||
|
||||
rows.forEach((row: any, i: number) => {
|
||||
if (row && typeof row === 'object') row[GONAVI_ROW_KEY] = i;
|
||||
});
|
||||
// 前端也拆分语句用于匹配原始 SQL(展示和表名检测)
|
||||
const statements = splitSQLStatements(fullSQL);
|
||||
|
||||
let simpleTableName: string | undefined = undefined;
|
||||
const tableMatch = rawStatement.match(/^\s*SELECT\s+\*\s+FROM\s+[`"]?(\w+)[`"]?\s*(?:WHERE.*)?(?:ORDER BY.*)?(?:LIMIT.*)?$/i);
|
||||
if (tableMatch) {
|
||||
simpleTableName = tableMatch[1];
|
||||
if (!forceReadOnlyResult) {
|
||||
pendingPk.push({ resultKey: `result-${idx + 1}`, tableName: simpleTableName });
|
||||
}
|
||||
}
|
||||
for (let idx = 0; idx < resultSetDataArray.length; idx++) {
|
||||
const rsData = resultSetDataArray[idx];
|
||||
const rawStatement = (idx < statements.length) ? statements[idx] : '';
|
||||
|
||||
nextResultSets.push({
|
||||
key: `result-${idx + 1}`,
|
||||
sql: rawStatement,
|
||||
exportSql: limited.applied ? applyAutoLimit(rawStatement, dbType, Math.max(1, Number(maxRows) || 1)).sql : rawStatement,
|
||||
rows,
|
||||
columns: cols,
|
||||
tableName: simpleTableName,
|
||||
pkColumns: [],
|
||||
readOnly: true,
|
||||
pkLoading: !!simpleTableName,
|
||||
truncated
|
||||
});
|
||||
} else {
|
||||
const affected = Number((res.data as any)?.affectedRows);
|
||||
if (Number.isFinite(affected)) {
|
||||
const row = { affectedRows: affected };
|
||||
// 检查是否为 affectedRows 类结果集
|
||||
const isAffectedResult = Array.isArray(rsData.rows) && rsData.rows.length === 1
|
||||
&& rsData.columns && rsData.columns.length === 1
|
||||
&& rsData.columns[0] === 'affectedRows';
|
||||
|
||||
if (isAffectedResult) {
|
||||
const affected = Number(rsData.rows[0]?.affectedRows);
|
||||
const row = { affectedRows: Number.isFinite(affected) ? affected : 0 };
|
||||
(row as any)[GONAVI_ROW_KEY] = 0;
|
||||
nextResultSets.push({
|
||||
key: `result-${idx + 1}`,
|
||||
@@ -1245,37 +1375,80 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
pkColumns: [],
|
||||
readOnly: true
|
||||
});
|
||||
} else {
|
||||
let rows = Array.isArray(rsData.rows) ? rsData.rows : [];
|
||||
let truncated = false;
|
||||
if (Number.isFinite(maxRows) && maxRows > 0 && rows.length > maxRows) {
|
||||
truncated = true;
|
||||
anyTruncated = true;
|
||||
rows = rows.slice(0, maxRows);
|
||||
}
|
||||
const cols = (rsData.columns && rsData.columns.length > 0)
|
||||
? rsData.columns
|
||||
: (rows.length > 0 ? Object.keys(rows[0]) : []);
|
||||
|
||||
rows.forEach((row: any, i: number) => {
|
||||
if (row && typeof row === 'object') row[GONAVI_ROW_KEY] = i;
|
||||
});
|
||||
|
||||
let simpleTableName: string | undefined = undefined;
|
||||
if (rawStatement) {
|
||||
const tableMatch = rawStatement.match(/^\s*SELECT\s+\*\s+FROM\s+[`"]?(\w+)[`"]?\s*(?:WHERE.*)?(?:ORDER BY.*)?(?:LIMIT.*)?$/i);
|
||||
if (tableMatch) {
|
||||
simpleTableName = tableMatch[1];
|
||||
if (!forceReadOnlyResult) {
|
||||
pendingPk.push({ resultKey: `result-${idx + 1}`, tableName: simpleTableName });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nextResultSets.push({
|
||||
key: `result-${idx + 1}`,
|
||||
sql: rawStatement,
|
||||
exportSql: rawStatement,
|
||||
rows,
|
||||
columns: cols,
|
||||
tableName: simpleTableName,
|
||||
pkColumns: [],
|
||||
readOnly: true,
|
||||
pkLoading: !!simpleTableName,
|
||||
truncated
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setResultSets(nextResultSets);
|
||||
setActiveResultKey(nextResultSets[0]?.key || '');
|
||||
setResultSets(nextResultSets);
|
||||
setActiveResultKey(nextResultSets[0]?.key || '');
|
||||
|
||||
pendingPk.forEach(({ resultKey, tableName }) => {
|
||||
DBGetColumns(config as any, currentDb, tableName)
|
||||
.then((resCols: any) => {
|
||||
if (runSeqRef.current !== runSeq) return;
|
||||
if (!resCols?.success) {
|
||||
pendingPk.forEach(({ resultKey, tableName }) => {
|
||||
DBGetColumns(config as any, currentDb, tableName)
|
||||
.then((resCols: any) => {
|
||||
if (runSeqRef.current !== runSeq) return;
|
||||
if (!resCols?.success) {
|
||||
setResultSets(prev => prev.map(rs => rs.key === resultKey ? { ...rs, pkLoading: false, readOnly: false } : rs));
|
||||
return;
|
||||
}
|
||||
const primaryKeys = (resCols.data as ColumnDefinition[]).filter(c => c.key === 'PRI').map(c => c.name);
|
||||
setResultSets(prev => prev.map(rs => rs.key === resultKey ? { ...rs, pkColumns: primaryKeys, pkLoading: false, readOnly: false } : rs));
|
||||
})
|
||||
.catch(() => {
|
||||
if (runSeqRef.current !== runSeq) return;
|
||||
setResultSets(prev => prev.map(rs => rs.key === resultKey ? { ...rs, pkLoading: false, readOnly: false } : rs));
|
||||
return;
|
||||
}
|
||||
const primaryKeys = (resCols.data as ColumnDefinition[]).filter(c => c.key === 'PRI').map(c => c.name);
|
||||
setResultSets(prev => prev.map(rs => rs.key === resultKey ? { ...rs, pkColumns: primaryKeys, pkLoading: false, readOnly: false } : rs));
|
||||
})
|
||||
.catch(() => {
|
||||
if (runSeqRef.current !== runSeq) return;
|
||||
setResultSets(prev => prev.map(rs => rs.key === resultKey ? { ...rs, pkLoading: false, readOnly: false } : rs));
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
if (statements.length > 1) {
|
||||
message.success(`已执行 ${statements.length} 条语句,生成 ${nextResultSets.length} 个结果集。`);
|
||||
} else if (nextResultSets.length === 0) {
|
||||
message.success('执行成功。');
|
||||
}
|
||||
if (anyTruncated && maxRows > 0) {
|
||||
message.warning(`结果集已自动限制为最多 ${maxRows} 行(可在工具栏调整)。`);
|
||||
// 后端附带的提示信息(如数据源不支持原生多语句执行的回退提示)
|
||||
if (res.message) {
|
||||
message.info(res.message);
|
||||
}
|
||||
if (resultSetDataArray.length > 1) {
|
||||
message.success(`已执行完成,生成 ${nextResultSets.length} 个结果集。`);
|
||||
} else if (nextResultSets.length === 0) {
|
||||
message.success('执行成功。');
|
||||
}
|
||||
if (anyTruncated && maxRows > 0) {
|
||||
message.warning(`结果集已自动限制为最多 ${maxRows} 行(可在工具栏调整)。`);
|
||||
}
|
||||
}
|
||||
} catch (e: any) {
|
||||
message.error("Error executing query: " + e.message);
|
||||
@@ -1319,6 +1492,46 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const handleSelectAllInEditor = (event: KeyboardEvent) => {
|
||||
if (activeTabId !== tab.id) {
|
||||
return;
|
||||
}
|
||||
if (!(event.ctrlKey || event.metaKey) || event.altKey || event.shiftKey || event.key.toLowerCase() !== 'a') {
|
||||
return;
|
||||
}
|
||||
|
||||
const editor = editorRef.current;
|
||||
if (!editor) {
|
||||
return;
|
||||
}
|
||||
|
||||
const targetNode = event.target instanceof Node ? event.target : null;
|
||||
const editorHasFocus = !!editor.hasTextFocus?.();
|
||||
const inEditorPane = !!(targetNode && editorPaneRef.current?.contains(targetNode));
|
||||
const inQueryEditor = !!(targetNode && queryEditorRootRef.current?.contains(targetNode));
|
||||
if (!editorHasFocus && !inEditorPane) {
|
||||
return;
|
||||
}
|
||||
if (!editorHasFocus && isEditableElement(event.target) && !inEditorPane) {
|
||||
return;
|
||||
}
|
||||
if (!editorHasFocus && !inQueryEditor) {
|
||||
return;
|
||||
}
|
||||
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
editor.focus?.();
|
||||
editor.trigger('keyboard', 'editor.action.selectAll', null);
|
||||
};
|
||||
|
||||
window.addEventListener('keydown', handleSelectAllInEditor, true);
|
||||
return () => {
|
||||
window.removeEventListener('keydown', handleSelectAllInEditor, true);
|
||||
};
|
||||
}, [activeTabId, tab.id]);
|
||||
|
||||
useEffect(() => {
|
||||
const binding = shortcutOptions.runQuery;
|
||||
if (!binding?.enabled || !binding.combo) {
|
||||
@@ -1361,16 +1574,60 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
};
|
||||
}, [activeTabId, tab.id, handleRun]);
|
||||
|
||||
const resolveDefaultQueryName = () => {
|
||||
const rawTitle = String(tab.title || '').trim();
|
||||
if (!rawTitle || rawTitle.startsWith('新建查询')) {
|
||||
return '未命名查询';
|
||||
}
|
||||
return rawTitle;
|
||||
};
|
||||
|
||||
const persistQuery = (payload: { id: string; name: string; createdAt?: number }) => {
|
||||
const sql = getCurrentQuery();
|
||||
const saved = {
|
||||
id: payload.id,
|
||||
name: payload.name,
|
||||
sql,
|
||||
connectionId: currentConnectionId,
|
||||
dbName: currentDb || tab.dbName || '',
|
||||
createdAt: payload.createdAt ?? Date.now(),
|
||||
};
|
||||
saveQuery(saved);
|
||||
addTab({
|
||||
...tab,
|
||||
title: payload.name,
|
||||
query: sql,
|
||||
connectionId: currentConnectionId,
|
||||
dbName: currentDb || tab.dbName || '',
|
||||
savedQueryId: payload.id,
|
||||
});
|
||||
return saved;
|
||||
};
|
||||
|
||||
const handleQuickSave = () => {
|
||||
const existed = currentSavedQuery || null;
|
||||
const fallbackSavedId = String(tab.savedQueryId || '').trim();
|
||||
const saveId = existed?.id || fallbackSavedId || '';
|
||||
if (!saveId) {
|
||||
saveForm.setFieldsValue({ name: resolveDefaultQueryName() });
|
||||
setIsSaveModalOpen(true);
|
||||
return;
|
||||
}
|
||||
const saveName = existed?.name || resolveDefaultQueryName();
|
||||
persistQuery({ id: saveId, name: saveName, createdAt: existed?.createdAt });
|
||||
message.success('查询已保存!');
|
||||
};
|
||||
|
||||
const handleSave = async () => {
|
||||
try {
|
||||
const values = await saveForm.validateFields();
|
||||
saveQuery({
|
||||
id: tab.id.startsWith('saved-') ? tab.id : `saved-${Date.now()}`,
|
||||
name: values.name,
|
||||
sql: query,
|
||||
connectionId: currentConnectionId,
|
||||
dbName: currentDb || tab.dbName || '',
|
||||
createdAt: Date.now()
|
||||
const existed = currentSavedQuery || null;
|
||||
const fallbackSavedId = String(tab.savedQueryId || '').trim();
|
||||
const nextSavedId = existed?.id || fallbackSavedId || `saved-${Date.now()}`;
|
||||
persistQuery({
|
||||
id: nextSavedId,
|
||||
name: String(values.name || '').trim() || '未命名查询',
|
||||
createdAt: existed?.createdAt,
|
||||
});
|
||||
message.success('查询已保存!');
|
||||
setIsSaveModalOpen(false);
|
||||
@@ -1386,8 +1643,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
|
||||
setActiveResultKey(prevActive => {
|
||||
if (prevActive && prevActive !== key) return prevActive;
|
||||
const nextKey = next[idx]?.key || next[idx - 1]?.key || next[0]?.key || '';
|
||||
return nextKey;
|
||||
return next[idx]?.key || next[idx - 1]?.key || next[0]?.key || '';
|
||||
});
|
||||
|
||||
return next;
|
||||
@@ -1395,7 +1651,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
};
|
||||
|
||||
return (
|
||||
<div style={{ flex: '1 1 auto', minHeight: 0, display: 'flex', flexDirection: 'column', height: '100%', overflow: 'hidden' }}>
|
||||
<div ref={queryEditorRootRef} style={{ flex: '1 1 auto', minHeight: 0, display: 'flex', flexDirection: 'column', height: '100%', overflow: 'hidden' }}>
|
||||
<style>{`
|
||||
.query-result-tabs {
|
||||
flex: 1 1 auto;
|
||||
@@ -1438,6 +1694,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
transition: none !important;
|
||||
}
|
||||
`}</style>
|
||||
<div ref={editorPaneRef}>
|
||||
<div style={{ padding: '8px', display: 'flex', gap: '8px', flexShrink: 0, alignItems: 'center' }}>
|
||||
<Select
|
||||
style={{ width: 150 }}
|
||||
@@ -1490,10 +1747,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
</Button>
|
||||
)}
|
||||
</Button.Group>
|
||||
<Button icon={<SaveOutlined />} onClick={() => {
|
||||
saveForm.setFieldsValue({ name: tab.title.replace('Query (', '').replace(')', '') });
|
||||
setIsSaveModalOpen(true);
|
||||
}}>
|
||||
<Button icon={<SaveOutlined />} onClick={handleQuickSave}>
|
||||
保存
|
||||
</Button>
|
||||
|
||||
@@ -1512,7 +1766,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
height="100%"
|
||||
defaultLanguage="sql"
|
||||
theme={darkMode ? "transparent-dark" : "transparent-light"}
|
||||
value={query}
|
||||
defaultValue={query}
|
||||
onChange={(val) => setQuery(val || '')}
|
||||
onMount={handleEditorDidMount}
|
||||
options={{
|
||||
@@ -1535,6 +1789,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
}}
|
||||
title="拖动调整高度"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div style={{ flex: 1, minHeight: 0, overflow: 'hidden', padding: 0, display: 'flex', flexDirection: 'column' }}>
|
||||
{resultSets.length > 0 ? (
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
||||
import React, { useEffect, useState, useMemo, useRef } from 'react';
|
||||
import { Tree, message, Dropdown, MenuProps, Input, Button, Modal, Form, Badge, Checkbox, Space, Select, Popover, Tooltip } from 'antd';
|
||||
import { Tree, message, Dropdown, MenuProps, Input, Button, Modal, Form, Badge, Checkbox, Space, Select, Popover, Tooltip, Progress } from 'antd';
|
||||
import {
|
||||
DatabaseOutlined,
|
||||
TableOutlined,
|
||||
@@ -27,12 +27,17 @@ import { Tree, message, Dropdown, MenuProps, Input, Button, Modal, Form, Badge,
|
||||
DisconnectOutlined,
|
||||
CloudOutlined,
|
||||
CheckSquareOutlined,
|
||||
CodeOutlined
|
||||
CodeOutlined,
|
||||
TagOutlined,
|
||||
CheckOutlined,
|
||||
FilterOutlined
|
||||
} from '@ant-design/icons';
|
||||
import { useStore } from '../store';
|
||||
import { useStore } from '../store';
|
||||
import { buildOverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme';
|
||||
import { SavedConnection } from '../types';
|
||||
import { DBGetDatabases, DBGetTables, DBQuery, DBShowCreateTable, ExportTable, OpenSQLFile, CreateDatabase, RenameDatabase, DropDatabase, RenameTable, DropTable, DropView, DropFunction, RenameView } from '../../wailsjs/go/app/App';
|
||||
import { normalizeOpacityForPlatform } from '../utils/appearance';
|
||||
import { DBGetDatabases, DBGetTables, DBQuery, DBShowCreateTable, ExportTable, OpenSQLFile, ExecuteSQLFile, CancelSQLFileExecution, CreateDatabase, RenameDatabase, DropDatabase, RenameTable, DropTable, DropView, DropFunction, RenameView } from '../../wailsjs/go/app/App';
|
||||
import { EventsOn } from '../../wailsjs/runtime/runtime';
|
||||
import { normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance';
|
||||
|
||||
const { Search } = Input;
|
||||
|
||||
@@ -73,9 +78,19 @@ const SEARCH_SCOPE_LABEL_MAP: Record<SearchScope, string> = SEARCH_SCOPE_OPTIONS
|
||||
return acc;
|
||||
}, {} as Record<SearchScope, string>);
|
||||
|
||||
|
||||
const SEARCH_SCOPE_ICON_MAP: Record<SearchScope, React.ReactNode> = {
|
||||
smart: <ThunderboltOutlined />,
|
||||
object: <TableOutlined />,
|
||||
database: <DatabaseOutlined />,
|
||||
host: <CloudOutlined />,
|
||||
tag: <TagOutlined />,
|
||||
};
|
||||
|
||||
const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> = ({ onEditConnection }) => {
|
||||
const connections = useStore(state => state.connections);
|
||||
const savedQueries = useStore(state => state.savedQueries);
|
||||
const deleteQuery = useStore(state => state.deleteQuery);
|
||||
const addConnection = useStore(state => state.addConnection);
|
||||
const addTab = useStore(state => state.addTab);
|
||||
const setActiveContext = useStore(state => state.setActiveContext);
|
||||
@@ -94,8 +109,10 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
const tableSortPreference = useStore(state => state.tableSortPreference);
|
||||
const recordTableAccess = useStore(state => state.recordTableAccess);
|
||||
const setTableSortPreference = useStore(state => state.setTableSortPreference);
|
||||
const addSqlLog = useStore(state => state.addSqlLog);
|
||||
const darkMode = theme === 'dark';
|
||||
const opacity = normalizeOpacityForPlatform(appearance.opacity);
|
||||
const resolvedAppearance = resolveAppearanceValues(appearance);
|
||||
const opacity = normalizeOpacityForPlatform(resolvedAppearance.opacity);
|
||||
const [treeData, setTreeData] = useState<TreeNode[]>([]);
|
||||
|
||||
// Background Helper (Duplicate logic for now, ideally shared)
|
||||
@@ -108,6 +125,43 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
return `rgba(${r}, ${g}, ${b}, ${opacity})`;
|
||||
};
|
||||
const bgMain = getBg('#141414');
|
||||
const overlayTheme = useMemo(() => buildOverlayWorkbenchTheme(darkMode), [darkMode]);
|
||||
const modalPanelStyle = useMemo(() => ({
|
||||
background: overlayTheme.shellBg,
|
||||
border: overlayTheme.shellBorder,
|
||||
boxShadow: overlayTheme.shellShadow,
|
||||
backdropFilter: overlayTheme.shellBackdropFilter,
|
||||
}), [overlayTheme]);
|
||||
const modalSectionStyle = useMemo(() => ({
|
||||
padding: 14,
|
||||
borderRadius: 14,
|
||||
border: overlayTheme.sectionBorder,
|
||||
background: overlayTheme.sectionBg,
|
||||
}), [overlayTheme]);
|
||||
const modalScrollSectionStyle = useMemo(() => ({
|
||||
maxHeight: 400,
|
||||
overflow: 'auto' as const,
|
||||
border: overlayTheme.sectionBorder,
|
||||
borderRadius: 14,
|
||||
padding: 12,
|
||||
background: overlayTheme.sectionBg,
|
||||
}), [overlayTheme]);
|
||||
const modalHintTextStyle = useMemo(() => ({
|
||||
color: overlayTheme.mutedText,
|
||||
fontSize: 12,
|
||||
lineHeight: 1.6,
|
||||
}), [overlayTheme]);
|
||||
const renderSidebarModalTitle = (icon: React.ReactNode, title: string, description: string) => (
|
||||
<div style={{ display: 'flex', alignItems: 'flex-start', gap: 12 }}>
|
||||
<div style={{ width: 34, height: 34, borderRadius: 12, display: 'grid', placeItems: 'center', background: overlayTheme.iconBg, color: overlayTheme.iconColor, flexShrink: 0 }}>
|
||||
{icon}
|
||||
</div>
|
||||
<div style={{ minWidth: 0 }}>
|
||||
<div style={{ fontSize: 16, fontWeight: 700, color: overlayTheme.titleText }}>{title}</div>
|
||||
<div style={{ marginTop: 4, color: overlayTheme.mutedText, fontSize: 12, lineHeight: 1.6 }}>{description}</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
const [searchValue, setSearchValue] = useState('');
|
||||
const [searchScopes, setSearchScopes] = useState<SearchScope[]>(['smart']);
|
||||
const [isSearchScopePopoverOpen, setIsSearchScopePopoverOpen] = useState(false);
|
||||
@@ -382,6 +436,16 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
password: readString(rawProxy.password, rawProxy.Password, cloned.proxyPassword, cloned.ProxyPassword),
|
||||
};
|
||||
const hasProxyDetail = Boolean(normalizedProxy.host || normalizedProxy.user || normalizedProxy.password);
|
||||
const rawHttpTunnel = (cloned.httpTunnel ?? cloned.HTTPTunnel ?? {}) as Record<string, unknown>;
|
||||
const normalizedHttpTunnel = {
|
||||
host: readString(rawHttpTunnel.host, rawHttpTunnel.Host, cloned.httpTunnelHost, cloned.HttpTunnelHost),
|
||||
port: readNumber(8080, rawHttpTunnel.port, rawHttpTunnel.Port, cloned.httpTunnelPort, cloned.HttpTunnelPort),
|
||||
user: readString(rawHttpTunnel.user, rawHttpTunnel.User, cloned.httpTunnelUser, cloned.HttpTunnelUser),
|
||||
password: readString(rawHttpTunnel.password, rawHttpTunnel.Password, cloned.httpTunnelPassword, cloned.HttpTunnelPassword),
|
||||
};
|
||||
const hasHttpTunnelDetail = Boolean(normalizedHttpTunnel.host || normalizedHttpTunnel.user || normalizedHttpTunnel.password);
|
||||
const normalizedUseHttpTunnel = readBool(hasHttpTunnelDetail, cloned.useHttpTunnel, cloned.UseHTTPTunnel);
|
||||
const normalizedUseProxy = !normalizedUseHttpTunnel && readBool(hasProxyDetail, cloned.useProxy, cloned.UseProxy);
|
||||
|
||||
const rawHosts = Array.isArray(cloned.hosts)
|
||||
? cloned.hosts
|
||||
@@ -394,8 +458,10 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
...(cloned as SavedConnection['config']),
|
||||
useSSH: readBool(hasSSHDetail, cloned.useSSH, cloned.UseSSH),
|
||||
ssh: normalizedSSH,
|
||||
useProxy: readBool(hasProxyDetail, cloned.useProxy, cloned.UseProxy),
|
||||
useProxy: normalizedUseProxy,
|
||||
proxy: normalizedProxy,
|
||||
useHttpTunnel: normalizedUseHttpTunnel,
|
||||
httpTunnel: normalizedHttpTunnel,
|
||||
hosts: normalizedHosts,
|
||||
timeout: readNumber(30, cloned.timeout, cloned.Timeout),
|
||||
};
|
||||
@@ -645,10 +711,15 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
}
|
||||
case 'oracle':
|
||||
case 'dm':
|
||||
if (!safeDbName) {
|
||||
return [{ sql: `SELECT VIEW_NAME AS view_name FROM USER_VIEWS ORDER BY VIEW_NAME` }];
|
||||
}
|
||||
return [{ sql: `SELECT OWNER AS schema_name, VIEW_NAME AS view_name FROM ALL_VIEWS WHERE OWNER = '${safeDbName.toUpperCase()}' ORDER BY VIEW_NAME` }];
|
||||
return normalizeMetadataQuerySpecs([
|
||||
{ sql: `SELECT VIEW_NAME AS view_name FROM USER_VIEWS ORDER BY VIEW_NAME` },
|
||||
{ sql: `SELECT OWNER AS schema_name, VIEW_NAME AS view_name FROM ALL_VIEWS WHERE OWNER = USER ORDER BY VIEW_NAME` },
|
||||
{
|
||||
sql: safeDbName
|
||||
? `SELECT OWNER AS schema_name, VIEW_NAME AS view_name FROM ALL_VIEWS WHERE OWNER = '${safeDbName.toUpperCase()}' ORDER BY VIEW_NAME`
|
||||
: '',
|
||||
},
|
||||
]);
|
||||
case 'sqlite':
|
||||
return [{ sql: `SELECT name AS view_name FROM sqlite_master WHERE type = 'view' ORDER BY name` }];
|
||||
case 'duckdb':
|
||||
@@ -724,17 +795,35 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
case 'kingbase':
|
||||
case 'highgo':
|
||||
case 'vastbase':
|
||||
return [{ sql: `SELECT n.nspname AS schema_name, p.proname AS routine_name, CASE WHEN p.prokind = 'p' THEN 'PROCEDURE' ELSE 'FUNCTION' END AS routine_type FROM pg_proc p JOIN pg_namespace n ON p.pronamespace = n.oid WHERE n.nspname NOT IN ('pg_catalog', 'information_schema') AND n.nspname NOT LIKE 'pg_%' ORDER BY n.nspname, routine_type, p.proname` }];
|
||||
return normalizeMetadataQuerySpecs([
|
||||
{
|
||||
// PostgreSQL 11+ / 部分 PG-like:通过 prokind 区分 FUNCTION/PROCEDURE
|
||||
sql: `SELECT n.nspname AS schema_name, p.proname AS routine_name, CASE WHEN p.prokind = 'p' THEN 'PROCEDURE' ELSE 'FUNCTION' END AS routine_type FROM pg_proc p JOIN pg_namespace n ON p.pronamespace = n.oid WHERE n.nspname NOT IN ('pg_catalog', 'information_schema') AND n.nspname NOT LIKE 'pg_%' ORDER BY n.nspname, routine_type, p.proname`,
|
||||
},
|
||||
{
|
||||
// PostgreSQL 10 / 不支持 prokind 的兼容路径
|
||||
sql: `SELECT r.routine_schema AS schema_name, r.routine_name AS routine_name, COALESCE(NULLIF(UPPER(r.routine_type), ''), 'FUNCTION') AS routine_type FROM information_schema.routines r WHERE r.routine_schema NOT IN ('pg_catalog', 'information_schema') AND r.routine_schema NOT LIKE 'pg_%' ORDER BY r.routine_schema, routine_type, r.routine_name`,
|
||||
},
|
||||
{
|
||||
// 最后兜底:仅函数列表,确保 prokind/routines 视图异常时仍可展示
|
||||
sql: `SELECT n.nspname AS schema_name, p.proname AS routine_name, 'FUNCTION' AS routine_type FROM pg_proc p JOIN pg_namespace n ON p.pronamespace = n.oid WHERE n.nspname NOT IN ('pg_catalog', 'information_schema') AND n.nspname NOT LIKE 'pg_%' ORDER BY n.nspname, p.proname`,
|
||||
},
|
||||
]);
|
||||
case 'sqlserver': {
|
||||
const safeDb = quoteSqlServerIdentifier(dbName || 'master');
|
||||
return [{ sql: `SELECT s.name AS schema_name, o.name AS routine_name, CASE o.type WHEN 'P' THEN 'PROCEDURE' WHEN 'FN' THEN 'FUNCTION' WHEN 'IF' THEN 'FUNCTION' WHEN 'TF' THEN 'FUNCTION' END AS routine_type FROM ${safeDb}.sys.objects o JOIN ${safeDb}.sys.schemas s ON o.schema_id = s.schema_id WHERE o.type IN ('P','FN','IF','TF') ORDER BY o.type, s.name, o.name` }];
|
||||
}
|
||||
case 'oracle':
|
||||
case 'dm':
|
||||
if (!safeDbName) {
|
||||
return [{ sql: `SELECT OBJECT_NAME AS routine_name, OBJECT_TYPE AS routine_type FROM USER_OBJECTS WHERE OBJECT_TYPE IN ('FUNCTION','PROCEDURE') ORDER BY OBJECT_TYPE, OBJECT_NAME` }];
|
||||
}
|
||||
return [{ sql: `SELECT OWNER AS schema_name, OBJECT_NAME AS routine_name, OBJECT_TYPE AS routine_type FROM ALL_OBJECTS WHERE OWNER = '${safeDbName.toUpperCase()}' AND OBJECT_TYPE IN ('FUNCTION','PROCEDURE') ORDER BY OBJECT_TYPE, OBJECT_NAME` }];
|
||||
return normalizeMetadataQuerySpecs([
|
||||
{ sql: `SELECT OBJECT_NAME AS routine_name, OBJECT_TYPE AS routine_type FROM USER_OBJECTS WHERE OBJECT_TYPE IN ('FUNCTION','PROCEDURE') ORDER BY OBJECT_TYPE, OBJECT_NAME` },
|
||||
{ sql: `SELECT OWNER AS schema_name, OBJECT_NAME AS routine_name, OBJECT_TYPE AS routine_type FROM ALL_OBJECTS WHERE OWNER = USER AND OBJECT_TYPE IN ('FUNCTION','PROCEDURE') ORDER BY OBJECT_TYPE, OBJECT_NAME` },
|
||||
{
|
||||
sql: safeDbName
|
||||
? `SELECT OWNER AS schema_name, OBJECT_NAME AS routine_name, OBJECT_TYPE AS routine_type FROM ALL_OBJECTS WHERE OWNER = '${safeDbName.toUpperCase()}' AND OBJECT_TYPE IN ('FUNCTION','PROCEDURE') ORDER BY OBJECT_TYPE, OBJECT_NAME`
|
||||
: '',
|
||||
},
|
||||
]);
|
||||
case 'duckdb':
|
||||
return [{
|
||||
sql: `SELECT schema_name, function_name AS routine_name, 'FUNCTION' AS routine_type FROM duckdb_functions() WHERE internal = false AND lower(function_type) = 'macro' AND COALESCE(macro_definition, '') <> '' ORDER BY schema_name, function_name`,
|
||||
@@ -1393,7 +1482,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
type: 'query',
|
||||
connectionId: q.connectionId,
|
||||
dbName: q.dbName,
|
||||
query: q.sql
|
||||
query: q.sql,
|
||||
savedQueryId: q.id,
|
||||
});
|
||||
return;
|
||||
} else if (node.type === 'redis-db') {
|
||||
@@ -1474,7 +1564,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
hide();
|
||||
if (res.success) {
|
||||
message.success('导出成功');
|
||||
} else if (res.message !== 'Cancelled') {
|
||||
} else if (res.message !== '已取消') {
|
||||
message.error('导出失败: ' + res.message);
|
||||
}
|
||||
};
|
||||
@@ -1497,7 +1587,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
hide();
|
||||
if (res.success) {
|
||||
message.success('导出成功');
|
||||
} else if (res.message !== 'Cancelled') {
|
||||
} else if (res.message !== '已取消') {
|
||||
message.error('导出失败: ' + res.message);
|
||||
}
|
||||
} catch (e: any) {
|
||||
@@ -1524,7 +1614,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
hide();
|
||||
if (res.success) {
|
||||
message.success('导出成功');
|
||||
} else if (res.message !== 'Cancelled') {
|
||||
} else if (res.message !== '已取消') {
|
||||
message.error('导出失败: ' + res.message);
|
||||
}
|
||||
} catch (e: any) {
|
||||
@@ -1713,7 +1803,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
} else {
|
||||
message.success('导出成功');
|
||||
}
|
||||
} else if (res.message !== 'Cancelled') {
|
||||
} else if (res.message !== '已取消') {
|
||||
message.error('导出失败: ' + res.message);
|
||||
}
|
||||
} catch (e: any) {
|
||||
@@ -1722,6 +1812,94 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
}
|
||||
};
|
||||
|
||||
const handleBatchClear = async () => {
|
||||
const selectedObjects = batchTables.filter(t => checkedTableKeys.includes(t.key));
|
||||
if (selectedObjects.length === 0) {
|
||||
message.warning('请至少选择一个对象');
|
||||
return;
|
||||
}
|
||||
|
||||
const { conn, dbName } = batchDbContext;
|
||||
const objectNames = selectedObjects.map(t => t.objectName);
|
||||
|
||||
const ok = await new Promise<boolean>((resolve) => {
|
||||
Modal.confirm({
|
||||
title: '确认清空选中表',
|
||||
content: `清空选中表会永久删除表中所有数据,操作不可逆,是否继续?\r\n\r\n连接: ${conn.name}\n数据库: ${dbName}`,
|
||||
okText: '继续',
|
||||
cancelText: '取消',
|
||||
onOk: () => resolve(true),
|
||||
onCancel: () => resolve(false),
|
||||
});
|
||||
});
|
||||
if (!ok) return;
|
||||
|
||||
setIsBatchModalOpen(false);
|
||||
const hide = message.loading(`正在清空选中表 (${objectNames.length})...`, 0);
|
||||
const startTime = Date.now();
|
||||
try {
|
||||
const app = (window as any).go.app.App;
|
||||
const res = await app.TruncateTables(normalizeConnConfig(conn.config), dbName, objectNames);
|
||||
hide();
|
||||
const duration = Date.now() - startTime;
|
||||
if (res.success) {
|
||||
message.success('清空成功');
|
||||
// 构造 SQL 日志
|
||||
let logSql = `/* Truncate Tables (${objectNames.length} tables) */\n`;
|
||||
if (res.data && res.data.executedSQLs && Array.isArray(res.data.executedSQLs)) {
|
||||
logSql += res.data.executedSQLs.join(';\n') + ';';
|
||||
} else {
|
||||
logSql += objectNames.map(name => name).join('; ');
|
||||
}
|
||||
addSqlLog({
|
||||
id: Date.now().toString(),
|
||||
timestamp: Date.now(),
|
||||
sql: logSql,
|
||||
status: 'success',
|
||||
duration,
|
||||
message: res.message,
|
||||
dbName,
|
||||
affectedRows: res.data?.count || 0
|
||||
});
|
||||
} else if (res.message !== '已取消') {
|
||||
message.error('清空失败: ' + res.message);
|
||||
// 记录失败的日志
|
||||
let logSql = `/* Truncate Tables (${objectNames.length} tables) - FAILED */\n`;
|
||||
if (res.data && res.data.executedSQLs && Array.isArray(res.data.executedSQLs)) {
|
||||
logSql += res.data.executedSQLs.join(';\n') + ';';
|
||||
} else {
|
||||
logSql += objectNames.map(name => name).join('; ');
|
||||
}
|
||||
addSqlLog({
|
||||
id: Date.now().toString(),
|
||||
timestamp: Date.now(),
|
||||
sql: logSql,
|
||||
status: 'error',
|
||||
duration,
|
||||
message: res.message,
|
||||
dbName
|
||||
});
|
||||
}
|
||||
} catch (e: any) {
|
||||
const duration = Date.now() - startTime;
|
||||
hide();
|
||||
const errMsg = e?.message || String(e);
|
||||
message.error('清空失败: ' + errMsg);
|
||||
// 记录异常的日志
|
||||
let logSql = `/* Truncate Tables (${objectNames.length} tables) - ERROR */\n`;
|
||||
logSql += objectNames.map(name => name).join('; ');
|
||||
addSqlLog({
|
||||
id: Date.now().toString(),
|
||||
timestamp: Date.now(),
|
||||
sql: logSql,
|
||||
status: 'error',
|
||||
duration,
|
||||
message: errMsg,
|
||||
dbName
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const handleCheckAll = (checked: boolean) => {
|
||||
if (batchSelectionScope === 'all') {
|
||||
setCheckedTableKeys(checked ? allBatchObjectKeys : []);
|
||||
@@ -1853,7 +2031,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
hide();
|
||||
if (res.success) {
|
||||
message.success(`${db.dbName} 导出成功`);
|
||||
} else if (res.message !== 'Cancelled') {
|
||||
} else if (res.message !== '已取消') {
|
||||
message.error(`${db.dbName} 导出失败: ` + res.message);
|
||||
break;
|
||||
} else {
|
||||
@@ -1882,23 +2060,127 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
};
|
||||
|
||||
const handleRunSQLFile = async (node: any) => {
|
||||
const res = await (window as any).go.app.App.OpenSQLFile();
|
||||
const res = await OpenSQLFile();
|
||||
if (res.success) {
|
||||
const sqlContent = res.data;
|
||||
const data = res.data;
|
||||
// 大文件:后端返回文件路径,走流式执行
|
||||
if (data && typeof data === 'object' && data.isLargeFile) {
|
||||
const connId = node.type === 'connection' ? node.key : node.dataRef?.id;
|
||||
const dbName = node.dataRef?.dbName || '';
|
||||
const conn = connections.find(c => c.id === connId);
|
||||
if (!conn) {
|
||||
message.error('未找到对应的连接配置');
|
||||
return;
|
||||
}
|
||||
startSQLFileExecution(conn.config, dbName, data.filePath, data.fileSizeMB);
|
||||
return;
|
||||
}
|
||||
// 小文件:加载到编辑器
|
||||
const sqlContent = data;
|
||||
const { dbName, id } = node.dataRef;
|
||||
addTab({
|
||||
id: `query-${Date.now()}`,
|
||||
title: `Import SQL`,
|
||||
title: `运行外部SQL文件`,
|
||||
type: 'query',
|
||||
connectionId: node.type === 'connection' ? node.key : node.dataRef.id,
|
||||
dbName: dbName,
|
||||
query: sqlContent
|
||||
});
|
||||
} else if (res.message !== "Cancelled") {
|
||||
message.error("读取文件失败: " + res.message);
|
||||
} else if (res.message !== '已取消') {
|
||||
message.error('读取文件失败: ' + res.message);
|
||||
}
|
||||
};
|
||||
|
||||
const handleOpenSQLFileFromToolbar = async () => {
|
||||
const ctx = useStore.getState().activeContext;
|
||||
if (!ctx?.connectionId) {
|
||||
message.warning('请先选择一个连接或数据库');
|
||||
return;
|
||||
}
|
||||
const res = await OpenSQLFile();
|
||||
if (res.success) {
|
||||
const data = res.data;
|
||||
// 大文件:后端流式执行
|
||||
if (data && typeof data === 'object' && data.isLargeFile) {
|
||||
const conn = connections.find(c => c.id === ctx.connectionId);
|
||||
if (!conn) {
|
||||
message.error('未找到对应的连接配置');
|
||||
return;
|
||||
}
|
||||
startSQLFileExecution(conn.config, ctx.dbName || '', data.filePath, data.fileSizeMB);
|
||||
return;
|
||||
}
|
||||
// 小文件
|
||||
addTab({
|
||||
id: `query-${Date.now()}`,
|
||||
title: `运行外部SQL文件`,
|
||||
type: 'query',
|
||||
connectionId: ctx.connectionId,
|
||||
dbName: ctx.dbName || undefined,
|
||||
query: data
|
||||
});
|
||||
} else if (res.message !== '已取消') {
|
||||
message.error('读取文件失败: ' + res.message);
|
||||
}
|
||||
};
|
||||
|
||||
// SQL 文件流式执行状态
|
||||
const [sqlFileExecState, setSqlFileExecState] = useState<{
|
||||
open: boolean;
|
||||
jobId: string;
|
||||
fileSizeMB: string;
|
||||
status: 'running' | 'done' | 'cancelled' | 'error';
|
||||
executed: number;
|
||||
failed: number;
|
||||
total: number;
|
||||
percent: number;
|
||||
currentSQL: string;
|
||||
resultMessage: string;
|
||||
}>({
|
||||
open: false, jobId: '', fileSizeMB: '', status: 'running',
|
||||
executed: 0, failed: 0, total: 0, percent: 0, currentSQL: '', resultMessage: ''
|
||||
});
|
||||
|
||||
const startSQLFileExecution = (config: any, dbName: string, filePath: string, fileSizeMB: string) => {
|
||||
const jobId = `sqlfile-${Date.now()}`;
|
||||
setSqlFileExecState({
|
||||
open: true, jobId, fileSizeMB, status: 'running',
|
||||
executed: 0, failed: 0, total: 0, percent: 0, currentSQL: '', resultMessage: ''
|
||||
});
|
||||
|
||||
// 监听进度事件
|
||||
const offProgress = EventsOn('sqlfile:progress', (event: any) => {
|
||||
if (!event || event.jobId !== jobId) return;
|
||||
setSqlFileExecState(prev => ({
|
||||
...prev,
|
||||
status: event.status || prev.status,
|
||||
executed: typeof event.executed === 'number' ? event.executed : prev.executed,
|
||||
failed: typeof event.failed === 'number' ? event.failed : prev.failed,
|
||||
total: typeof event.total === 'number' ? event.total : prev.total,
|
||||
percent: typeof event.percent === 'number' ? Math.min(100, event.percent) : prev.percent,
|
||||
currentSQL: typeof event.currentSQL === 'string' ? event.currentSQL : prev.currentSQL,
|
||||
}));
|
||||
});
|
||||
|
||||
// 异步执行
|
||||
ExecuteSQLFile(config, dbName, filePath, jobId).then(res => {
|
||||
offProgress();
|
||||
setSqlFileExecState(prev => ({
|
||||
...prev,
|
||||
status: res.success ? 'done' : (prev.status === 'cancelled' ? 'cancelled' : 'error'),
|
||||
percent: 100,
|
||||
resultMessage: res.message || '',
|
||||
}));
|
||||
}).catch(err => {
|
||||
offProgress();
|
||||
setSqlFileExecState(prev => ({
|
||||
...prev,
|
||||
status: 'error',
|
||||
resultMessage: String(err?.message || err),
|
||||
}));
|
||||
});
|
||||
};
|
||||
|
||||
const handleCreateDatabase = async () => {
|
||||
try {
|
||||
const values = await createDbForm.validateFields();
|
||||
@@ -2449,32 +2731,98 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
const searchScopePopoverContent = useMemo(() => {
|
||||
const smartSelected = searchScopes.includes('smart');
|
||||
const scopedOptions = SEARCH_SCOPE_OPTIONS.filter((option) => option.value !== 'smart');
|
||||
const borderColor = overlayTheme.sectionBorder.replace('1px solid ', '');
|
||||
const mutedTextColor = overlayTheme.mutedText;
|
||||
const titleColor = overlayTheme.titleText;
|
||||
const panelBg = overlayTheme.shellBg;
|
||||
const smartBg = smartSelected
|
||||
? (darkMode ? 'linear-gradient(135deg, rgba(255,214,102,0.22) 0%, rgba(255,179,71,0.16) 100%)' : 'linear-gradient(135deg, rgba(255,214,102,0.26) 0%, rgba(255,244,204,0.92) 100%)')
|
||||
: (darkMode ? 'rgba(255,255,255,0.03)' : 'rgba(255,255,255,0.72)');
|
||||
const smartBorder = smartSelected
|
||||
? (darkMode ? 'rgba(255,214,102,0.42)' : 'rgba(245,176,65,0.34)')
|
||||
: borderColor;
|
||||
const getOptionCardStyle = (checked: boolean) => ({
|
||||
display: 'flex',
|
||||
alignItems: 'center' as const,
|
||||
justifyContent: 'space-between' as const,
|
||||
gap: 12,
|
||||
padding: '10px 12px',
|
||||
borderRadius: 12,
|
||||
border: `1px solid ${checked ? (darkMode ? 'rgba(118,169,250,0.44)' : 'rgba(24,144,255,0.32)') : borderColor}`,
|
||||
background: checked
|
||||
? (darkMode ? 'rgba(64,124,255,0.18)' : 'rgba(24,144,255,0.08)')
|
||||
: (darkMode ? 'rgba(255,255,255,0.03)' : 'rgba(255,255,255,0.76)'),
|
||||
transition: 'all 120ms ease',
|
||||
});
|
||||
return (
|
||||
<div style={{ minWidth: 220, display: 'flex', flexDirection: 'column', gap: 8 }}>
|
||||
<div style={{ fontSize: 12, color: '#8c8c8c' }}>搜索范围</div>
|
||||
<Checkbox
|
||||
checked={smartSelected}
|
||||
onChange={(e) => setSearchScopeChecked('smart', e.target.checked)}
|
||||
>
|
||||
智能(推荐)
|
||||
</Checkbox>
|
||||
<div style={{ paddingLeft: 12, display: 'grid', gap: 6 }}>
|
||||
{scopedOptions.map((option) => (
|
||||
<Checkbox
|
||||
key={option.value}
|
||||
checked={searchScopes.includes(option.value)}
|
||||
onChange={(e) => setSearchScopeChecked(option.value, e.target.checked)}
|
||||
>
|
||||
{option.label}
|
||||
</Checkbox>
|
||||
))}
|
||||
<div style={{ minWidth: 280, display: 'flex', flexDirection: 'column', background: panelBg, padding: 14, gap: 12 }}>
|
||||
<div style={{ display: 'flex', alignItems: 'flex-start', justifyContent: 'space-between', gap: 12 }}>
|
||||
<div>
|
||||
<div style={{ fontSize: 12, fontWeight: 700, letterSpacing: 0.4, color: mutedTextColor, textTransform: 'uppercase' }}>搜索范围</div>
|
||||
<div style={{ marginTop: 4, fontSize: 13, lineHeight: 1.5, color: mutedTextColor }}>“智能”自动匹配最可能的命中项;手动模式支持按维度组合筛选。</div>
|
||||
</div>
|
||||
<div style={{ width: 32, height: 32, borderRadius: 10, display: 'grid', placeItems: 'center', background: darkMode ? 'rgba(255,255,255,0.05)' : 'rgba(17,24,39,0.06)', color: darkMode ? '#ffd666' : '#1677ff', flexShrink: 0 }}>
|
||||
<FilterOutlined />
|
||||
</div>
|
||||
</div>
|
||||
<div style={{ fontSize: 12, color: '#8c8c8c' }}>
|
||||
智能与其他项互斥;其他项支持多选。
|
||||
|
||||
<label style={{ display: 'block', cursor: 'pointer' }}>
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 12, padding: '12px 14px', borderRadius: 14, border: `1px solid ${smartBorder}`, background: smartBg, boxShadow: smartSelected ? (darkMode ? '0 10px 24px rgba(0,0,0,0.24)' : '0 10px 24px rgba(245,176,65,0.14)') : 'none' }}>
|
||||
<Checkbox
|
||||
checked={smartSelected}
|
||||
onChange={(e) => setSearchScopeChecked('smart', e.target.checked)}
|
||||
/>
|
||||
<div style={{ width: 30, height: 30, borderRadius: 10, display: 'grid', placeItems: 'center', background: darkMode ? 'rgba(255,214,102,0.16)' : 'rgba(255,214,102,0.3)', color: darkMode ? '#ffd666' : '#ad6800', flexShrink: 0 }}>
|
||||
{SEARCH_SCOPE_ICON_MAP.smart}
|
||||
</div>
|
||||
<div style={{ flex: 1, minWidth: 0 }}>
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 8, flexWrap: 'wrap' }}>
|
||||
<span style={{ fontSize: 14, fontWeight: 700, color: titleColor }}>智能</span>
|
||||
<span style={{ padding: '2px 8px', borderRadius: 999, fontSize: 11, fontWeight: 700, color: darkMode ? '#ffe58f' : '#ad6800', background: darkMode ? 'rgba(255,214,102,0.16)' : 'rgba(255,214,102,0.35)' }}>推荐</span>
|
||||
</div>
|
||||
<div style={{ marginTop: 3, fontSize: 12, lineHeight: 1.5, color: mutedTextColor }}>适合日常检索,自动覆盖名称、库、Host 和标签等高频维度。</div>
|
||||
</div>
|
||||
</div>
|
||||
</label>
|
||||
|
||||
<div style={{ height: 1, background: overlayTheme.divider, opacity: 0.9 }} />
|
||||
|
||||
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between', gap: 12 }}>
|
||||
<div style={{ fontSize: 12, fontWeight: 700, letterSpacing: 0.3, color: mutedTextColor, textTransform: 'uppercase' }}>手动范围</div>
|
||||
<div style={{ fontSize: 12, color: mutedTextColor }}>支持多选组合</div>
|
||||
</div>
|
||||
|
||||
<div style={{ display: 'grid', gap: 8 }}>
|
||||
{scopedOptions.map((option) => {
|
||||
const checked = searchScopes.includes(option.value);
|
||||
return (
|
||||
<label key={option.value} style={{ display: 'block', cursor: 'pointer' }}>
|
||||
<div style={getOptionCardStyle(checked)}>
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 12, minWidth: 0 }}>
|
||||
<Checkbox
|
||||
checked={checked}
|
||||
onChange={(e) => setSearchScopeChecked(option.value, e.target.checked)}
|
||||
/>
|
||||
<div style={{ width: 28, height: 28, borderRadius: 9, display: 'grid', placeItems: 'center', background: checked ? (darkMode ? 'rgba(118,169,250,0.2)' : 'rgba(24,144,255,0.12)') : (darkMode ? 'rgba(255,255,255,0.05)' : 'rgba(17,24,39,0.06)'), color: checked ? (darkMode ? '#91caff' : '#1677ff') : mutedTextColor, flexShrink: 0 }}>
|
||||
{SEARCH_SCOPE_ICON_MAP[option.value]}
|
||||
</div>
|
||||
<span style={{ fontSize: 14, fontWeight: 600, color: titleColor, whiteSpace: 'nowrap' }}>{option.label}</span>
|
||||
</div>
|
||||
<div style={{ width: 18, display: 'flex', justifyContent: 'center', color: checked ? (darkMode ? '#91caff' : '#1677ff') : 'transparent', flexShrink: 0 }}>
|
||||
<CheckOutlined />
|
||||
</div>
|
||||
</div>
|
||||
</label>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
|
||||
<div style={{ padding: '10px 12px', borderRadius: 12, background: darkMode ? 'rgba(255,255,255,0.03)' : 'rgba(17,24,39,0.04)', color: mutedTextColor, fontSize: 12, lineHeight: 1.6 }}>
|
||||
智能与其他项互斥。若你明确知道要搜的是对象、库、Host 或标签,建议切到手动范围以减少噪音结果。
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}, [searchScopes]);
|
||||
}, [darkMode, overlayTheme, searchScopes]);
|
||||
|
||||
const parseHostOnlyToken = (value: unknown): string[] => {
|
||||
const raw = String(value || '').trim();
|
||||
@@ -2829,6 +3177,12 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
});
|
||||
}
|
||||
},
|
||||
{
|
||||
key: 'open-sql-file',
|
||||
label: '运行外部SQL文件',
|
||||
icon: <FileAddOutlined />,
|
||||
onClick: () => handleRunSQLFile(node)
|
||||
},
|
||||
{ type: 'divider' },
|
||||
{
|
||||
key: 'edit',
|
||||
@@ -3015,7 +3369,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
},
|
||||
{
|
||||
key: 'run-sql',
|
||||
label: '运行 SQL 文件...',
|
||||
label: '运行外部SQL文件',
|
||||
icon: <FileAddOutlined />,
|
||||
onClick: () => handleRunSQLFile(node)
|
||||
}
|
||||
@@ -3107,13 +3461,15 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
label: '新建查询',
|
||||
icon: <ConsoleSqlOutlined />,
|
||||
onClick: () => {
|
||||
const tableName = String(node.dataRef?.tableName || '').trim();
|
||||
const queryTemplate = tableName ? `SELECT * FROM ${tableName};` : 'SELECT * FROM ';
|
||||
addTab({
|
||||
id: `query-${Date.now()}`,
|
||||
title: `新建查询`,
|
||||
type: 'query',
|
||||
connectionId: node.dataRef.id,
|
||||
dbName: node.dataRef.dbName,
|
||||
query: ''
|
||||
query: queryTemplate
|
||||
});
|
||||
}
|
||||
},
|
||||
@@ -3170,6 +3526,56 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
}
|
||||
];
|
||||
}
|
||||
|
||||
// 已存查询节点的右键菜单
|
||||
if (node.type === 'saved-query') {
|
||||
const q = node.dataRef;
|
||||
return [
|
||||
{
|
||||
key: 'open-query',
|
||||
label: '打开查询',
|
||||
icon: <ConsoleSqlOutlined />,
|
||||
onClick: () => {
|
||||
addTab({
|
||||
id: q.id,
|
||||
title: q.name,
|
||||
type: 'query',
|
||||
connectionId: q.connectionId,
|
||||
dbName: q.dbName,
|
||||
query: q.sql,
|
||||
savedQueryId: q.id,
|
||||
});
|
||||
}
|
||||
},
|
||||
{ type: 'divider' },
|
||||
{
|
||||
key: 'delete-query',
|
||||
label: '删除查询',
|
||||
icon: <DeleteOutlined />,
|
||||
danger: true,
|
||||
onClick: () => {
|
||||
Modal.confirm({
|
||||
title: '确认删除',
|
||||
content: `确定要删除已保存的查询 "${q.name}" 吗?此操作不可恢复。`,
|
||||
okButtonProps: { danger: true },
|
||||
onOk: () => {
|
||||
deleteQuery(q.id);
|
||||
// 从树中移除节点
|
||||
setTreeData(origin => {
|
||||
const removeNode = (list: TreeNode[]): TreeNode[] =>
|
||||
list
|
||||
.filter(n => n.key !== node.key)
|
||||
.map(n => n.children ? { ...n, children: removeNode(n.children) } : n);
|
||||
return removeNode(origin);
|
||||
});
|
||||
message.success('查询已删除');
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
];
|
||||
}
|
||||
|
||||
return [];
|
||||
};
|
||||
|
||||
@@ -3279,14 +3685,14 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
|
||||
return (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', height: '100%' }}>
|
||||
<div style={{ padding: '4px 8px' }}>
|
||||
<Space.Compact block size="small">
|
||||
<div style={{ padding: '4px 10px' }}>
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 8 }}>
|
||||
<Search
|
||||
ref={searchInputRef}
|
||||
placeholder="搜索..."
|
||||
onChange={onSearch}
|
||||
size="small"
|
||||
style={{ width: '100%' }}
|
||||
style={{ flex: 1, minWidth: 0 }}
|
||||
/>
|
||||
<Popover
|
||||
content={searchScopePopoverContent}
|
||||
@@ -3294,18 +3700,66 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
placement="bottomRight"
|
||||
open={isSearchScopePopoverOpen}
|
||||
onOpenChange={setIsSearchScopePopoverOpen}
|
||||
styles={{ body: { padding: 0, borderRadius: 18, overflow: 'hidden' } }}
|
||||
>
|
||||
<Tooltip title={`搜索范围:${searchScopeSummary}`}>
|
||||
<Button size="small" icon={<DownOutlined />} style={{ width: 86 }}>
|
||||
范围{searchScopes.includes('smart') ? '(智)' : `(${searchScopes.length})`}
|
||||
<Button
|
||||
size="small"
|
||||
style={{
|
||||
minWidth: 86,
|
||||
display: 'inline-flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
gap: 6,
|
||||
paddingInline: 10,
|
||||
borderRadius: 10,
|
||||
borderColor: darkMode ? 'rgba(255,255,255,0.12)' : 'rgba(16,24,40,0.12)',
|
||||
background: darkMode ? bgMain : 'rgba(255,255,255,0.92)',
|
||||
color: darkMode ? 'rgba(255,255,255,0.88)' : '#162033',
|
||||
boxShadow: isSearchScopePopoverOpen
|
||||
? (darkMode ? '0 0 0 1px rgba(255,214,102,0.22) inset' : '0 0 0 1px rgba(24,144,255,0.24) inset')
|
||||
: 'none',
|
||||
backdropFilter: darkMode ? 'blur(10px)' : 'none',
|
||||
flexShrink: 0,
|
||||
}}
|
||||
>
|
||||
<span style={{ display: 'inline-flex', alignItems: 'center', color: searchScopes.includes('smart') ? '#ffd666' : (darkMode ? 'rgba(255,255,255,0.72)' : 'rgba(22,32,51,0.72)') }}>
|
||||
<FilterOutlined />
|
||||
</span>
|
||||
<span style={{ fontWeight: 700, color: darkMode ? 'rgba(255,255,255,0.88)' : '#162033' }}>筛选</span>
|
||||
<span
|
||||
style={{
|
||||
minWidth: 18,
|
||||
height: 18,
|
||||
padding: '0 5px',
|
||||
borderRadius: 999,
|
||||
display: 'inline-flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
fontSize: 11,
|
||||
fontWeight: 700,
|
||||
lineHeight: 1,
|
||||
background: searchScopes.includes('smart')
|
||||
? (darkMode ? 'rgba(255,214,102,0.16)' : 'rgba(24,144,255,0.12)')
|
||||
: (darkMode ? 'rgba(118,169,250,0.18)' : 'rgba(24,144,255,0.12)'),
|
||||
color: searchScopes.includes('smart')
|
||||
? (darkMode ? '#ffd666' : '#1677ff')
|
||||
: (darkMode ? '#91caff' : '#1677ff'),
|
||||
}}
|
||||
>
|
||||
{searchScopes.includes('smart') ? '智' : searchScopes.length}
|
||||
</span>
|
||||
<span style={{ display: 'inline-flex', alignItems: 'center', color: darkMode ? 'rgba(255,255,255,0.48)' : 'rgba(22,32,51,0.4)', fontSize: 12 }}>
|
||||
<DownOutlined />
|
||||
</span>
|
||||
</Button>
|
||||
</Tooltip>
|
||||
</Popover>
|
||||
</Space.Compact>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Toolbar */}
|
||||
<div style={{ padding: '4px 8px', borderBottom: 'none', display: 'flex', flexWrap: 'wrap', gap: 4 }}>
|
||||
<div style={{ padding: '4px 10px', borderBottom: 'none', display: 'flex', flexWrap: 'wrap', gap: 4 }}>
|
||||
<Button
|
||||
size="small"
|
||||
icon={<FolderOpenOutlined />}
|
||||
@@ -3334,6 +3788,14 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
>
|
||||
批量操作库
|
||||
</Button>
|
||||
<Button
|
||||
size="small"
|
||||
icon={<FileAddOutlined />}
|
||||
onClick={handleOpenSQLFileFromToolbar}
|
||||
style={{ flex: '1 1 auto' }}
|
||||
>
|
||||
运行外部SQL文件
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div ref={treeContainerRef} style={{ flex: 1, overflow: 'hidden', minHeight: 0 }}>
|
||||
@@ -3373,8 +3835,14 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
)}
|
||||
|
||||
<Modal
|
||||
title={renameViewTarget?.type === 'tag' ? "编辑标签" : "新建组"}
|
||||
title={renderSidebarModalTitle(
|
||||
<FolderOpenOutlined />,
|
||||
renameViewTarget?.type === 'tag' ? "编辑标签" : "新建组",
|
||||
renameViewTarget?.type === 'tag' ? "调整分组名称和包含的连接。" : "为连接树创建一个更清晰的分组视图。"
|
||||
)}
|
||||
open={isCreateTagModalOpen}
|
||||
centered
|
||||
styles={{ content: modalPanelStyle, header: { background: 'transparent', borderBottom: 'none', paddingBottom: 10 }, body: { paddingTop: 8 }, footer: { background: 'transparent', borderTop: 'none', paddingTop: 12 } }}
|
||||
onOk={() => {
|
||||
createTagForm.validateFields().then(values => {
|
||||
if (renameViewTarget?.type === 'tag') {
|
||||
@@ -3409,20 +3877,24 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
onCancel={() => setIsCreateTagModalOpen(false)}
|
||||
>
|
||||
<Form form={createTagForm} layout="vertical">
|
||||
<Form.Item name="name" label="标签名称" rules={[{ required: true, message: '请输入标签名称' }]}>
|
||||
<Input />
|
||||
</Form.Item>
|
||||
<Form.Item name="connectionIds" label="选择连接">
|
||||
<Checkbox.Group style={{ width: '100%' }}>
|
||||
<Space direction="vertical" style={{ width: '100%', maxHeight: '400px', overflowY: 'auto' }}>
|
||||
{connections.map(conn => (
|
||||
<Checkbox key={conn.id} value={conn.id}>
|
||||
{conn.name} {conn.config.host ? `(${conn.config.host})` : ''}
|
||||
</Checkbox>
|
||||
))}
|
||||
</Space>
|
||||
</Checkbox.Group>
|
||||
</Form.Item>
|
||||
<div style={modalSectionStyle}>
|
||||
<Form.Item name="name" label="标签名称" rules={[{ required: true, message: '请输入标签名称' }]}>
|
||||
<Input placeholder="例如:线上环境 / 核心业务 / 临时调试" />
|
||||
</Form.Item>
|
||||
<Form.Item name="connectionIds" label="选择连接" style={{ marginBottom: 0 }}>
|
||||
<Checkbox.Group style={{ width: '100%' }}>
|
||||
<div style={modalScrollSectionStyle}>
|
||||
<Space direction="vertical" style={{ width: '100%' }}>
|
||||
{connections.map(conn => (
|
||||
<Checkbox key={conn.id} value={conn.id}>
|
||||
{conn.name} {conn.config.host ? `(${conn.config.host})` : ''}
|
||||
</Checkbox>
|
||||
))}
|
||||
</Space>
|
||||
</div>
|
||||
</Checkbox.Group>
|
||||
</Form.Item>
|
||||
</div>
|
||||
</Form>
|
||||
</Modal>
|
||||
|
||||
@@ -3492,16 +3964,27 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
</Modal>
|
||||
|
||||
<Modal
|
||||
title="批量操作表"
|
||||
title={renderSidebarModalTitle(<TableOutlined />, "批量操作表", "按对象批量导出结构、数据或完整备份。")}
|
||||
open={isBatchModalOpen}
|
||||
onCancel={() => setIsBatchModalOpen(false)}
|
||||
width={680}
|
||||
width={720}
|
||||
centered
|
||||
styles={{ content: modalPanelStyle, header: { background: 'transparent', borderBottom: 'none', paddingBottom: 10 }, body: { paddingTop: 8 }, footer: { background: 'transparent', borderTop: 'none', paddingTop: 12 } }}
|
||||
footer={
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', gap: 8, flexWrap: 'wrap' }}>
|
||||
<Button key="cancel" onClick={() => setIsBatchModalOpen(false)}>
|
||||
取消
|
||||
</Button>
|
||||
<Space size={8} wrap style={{ marginLeft: 'auto' }}>
|
||||
<Button
|
||||
key="clear"
|
||||
danger
|
||||
icon={<DeleteOutlined />}
|
||||
onClick={() => handleBatchClear()}
|
||||
disabled={checkedTableKeys.length === 0}
|
||||
>
|
||||
清空表
|
||||
</Button>
|
||||
<Button
|
||||
key="export-schema"
|
||||
icon={<ExportOutlined />}
|
||||
@@ -3531,7 +4014,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
</div>
|
||||
}
|
||||
>
|
||||
<div style={{ marginBottom: 16 }}>
|
||||
<div style={{ ...modalSectionStyle, marginBottom: 16 }}>
|
||||
<div style={{ marginBottom: 8 }}>
|
||||
<label style={{ display: 'block', marginBottom: 4, fontWeight: 500 }}>选择连接:</label>
|
||||
<Select
|
||||
@@ -3563,10 +4046,11 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
))}
|
||||
</Select>
|
||||
</div>
|
||||
<div style={modalHintTextStyle}>先选择连接与数据库,再决定导出范围和目标对象。</div>
|
||||
</div>
|
||||
|
||||
{batchTables.length > 0 && (
|
||||
<div style={{ marginBottom: 16 }}>
|
||||
<div style={{ ...modalSectionStyle, marginBottom: 16 }}>
|
||||
<Space wrap size={8} style={{ width: '100%' }}>
|
||||
<Input
|
||||
allowClear
|
||||
@@ -3604,7 +4088,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
|
||||
{batchTables.length > 0 && (
|
||||
<>
|
||||
<div style={{ marginBottom: 16 }}>
|
||||
<div style={{ ...modalSectionStyle, marginBottom: 16 }}>
|
||||
<Space>
|
||||
<Button
|
||||
size="small"
|
||||
@@ -3632,7 +4116,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
</span>
|
||||
</Space>
|
||||
</div>
|
||||
<div style={{ maxHeight: 400, overflow: 'auto', border: darkMode ? '1px solid #303030' : '1px solid #f0f0f0', borderRadius: 4, padding: 8 }}>
|
||||
<div style={modalScrollSectionStyle}>
|
||||
<Checkbox.Group
|
||||
value={checkedTableKeys}
|
||||
onChange={(values) => setCheckedTableKeys(values as string[])}
|
||||
@@ -3682,10 +4166,12 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
</Modal>
|
||||
|
||||
<Modal
|
||||
title="批量操作库"
|
||||
title={renderSidebarModalTitle(<DatabaseOutlined />, "批量操作库", "按数据库批量导出结构,或生成结构加数据的备份。")}
|
||||
open={isBatchDbModalOpen}
|
||||
onCancel={() => setIsBatchDbModalOpen(false)}
|
||||
width={600}
|
||||
width={640}
|
||||
centered
|
||||
styles={{ content: modalPanelStyle, header: { background: 'transparent', borderBottom: 'none', paddingBottom: 10 }, body: { paddingTop: 8 }, footer: { background: 'transparent', borderTop: 'none', paddingTop: 12 } }}
|
||||
footer={[
|
||||
<Button key="cancel" onClick={() => setIsBatchDbModalOpen(false)}>
|
||||
取消
|
||||
@@ -3709,8 +4195,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
</Button>
|
||||
]}
|
||||
>
|
||||
<div style={{ marginBottom: 16 }}>
|
||||
<label style={{ display: 'block', marginBottom: 4, fontWeight: 500 }}>选择连接:</label>
|
||||
<div style={{ ...modalSectionStyle, marginBottom: 16 }}>
|
||||
<label style={{ display: 'block', marginBottom: 4, fontWeight: 600, color: darkMode ? '#f5f7ff' : '#162033' }}>选择连接:</label>
|
||||
<Select
|
||||
value={selectedDbConnection}
|
||||
onChange={handleDbConnectionChange}
|
||||
@@ -3723,11 +4209,12 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
</Select.Option>
|
||||
))}
|
||||
</Select>
|
||||
<div style={{ ...modalHintTextStyle, marginTop: 10 }}>连接选定后会加载当前连接下可批量导出的数据库列表。</div>
|
||||
</div>
|
||||
|
||||
{batchDatabases.length > 0 && (
|
||||
<>
|
||||
<div style={{ marginBottom: 16 }}>
|
||||
<div style={{ ...modalSectionStyle, marginBottom: 16 }}>
|
||||
<Space>
|
||||
<Button
|
||||
size="small"
|
||||
@@ -3752,7 +4239,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
</span>
|
||||
</Space>
|
||||
</div>
|
||||
<div style={{ maxHeight: 400, overflow: 'auto', border: darkMode ? '1px solid #303030' : '1px solid #f0f0f0', borderRadius: 4, padding: 8 }}>
|
||||
<div style={modalScrollSectionStyle}>
|
||||
<Checkbox.Group
|
||||
value={checkedDbKeys}
|
||||
onChange={(values) => setCheckedDbKeys(values as string[])}
|
||||
@@ -3771,6 +4258,60 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
</>
|
||||
)}
|
||||
</Modal>
|
||||
|
||||
{/* SQL 文件流式执行进度 Modal */}
|
||||
<Modal
|
||||
title="运行外部SQL文件"
|
||||
open={sqlFileExecState.open}
|
||||
centered
|
||||
closable={sqlFileExecState.status !== 'running'}
|
||||
maskClosable={false}
|
||||
footer={sqlFileExecState.status === 'running' ? [
|
||||
<Button key="cancel" danger onClick={() => {
|
||||
CancelSQLFileExecution(sqlFileExecState.jobId);
|
||||
setSqlFileExecState(prev => ({ ...prev, status: 'cancelled' }));
|
||||
}}>
|
||||
取消执行
|
||||
</Button>
|
||||
] : [
|
||||
<Button key="close" type="primary" onClick={() => setSqlFileExecState(prev => ({ ...prev, open: false }))}>
|
||||
关闭
|
||||
</Button>
|
||||
]}
|
||||
onCancel={() => {
|
||||
if (sqlFileExecState.status !== 'running') {
|
||||
setSqlFileExecState(prev => ({ ...prev, open: false }));
|
||||
}
|
||||
}}
|
||||
styles={{ content: modalPanelStyle, header: { background: 'transparent', borderBottom: 'none' }, body: { paddingTop: 8 }, footer: { background: 'transparent', borderTop: 'none' } }}
|
||||
>
|
||||
<div style={{ marginBottom: 16 }}>
|
||||
<Progress
|
||||
percent={Math.round(sqlFileExecState.percent)}
|
||||
status={sqlFileExecState.status === 'error' ? 'exception' : sqlFileExecState.status === 'done' ? 'success' : 'active'}
|
||||
strokeColor={sqlFileExecState.status === 'cancelled' ? '#faad14' : undefined}
|
||||
/>
|
||||
</div>
|
||||
<div style={{ fontSize: 13, lineHeight: '22px', marginBottom: 8 }}>
|
||||
<div>文件大小:<strong>{sqlFileExecState.fileSizeMB} MB</strong></div>
|
||||
<div>状态:<strong>{
|
||||
sqlFileExecState.status === 'running' ? '执行中...' :
|
||||
sqlFileExecState.status === 'done' ? '✅ 完成' :
|
||||
sqlFileExecState.status === 'cancelled' ? '⚠️ 已取消' : '❌ 出错'
|
||||
}</strong></div>
|
||||
<div>已执行:<strong style={{ color: '#52c41a' }}>{sqlFileExecState.executed}</strong> 条 | 失败:<strong style={{ color: sqlFileExecState.failed > 0 ? '#ff4d4f' : undefined }}>{sqlFileExecState.failed}</strong> 条</div>
|
||||
</div>
|
||||
{sqlFileExecState.currentSQL && sqlFileExecState.status === 'running' && (
|
||||
<div style={{ fontSize: 12, color: 'rgba(128,128,128,0.8)', background: 'rgba(128,128,128,0.06)', borderRadius: 6, padding: '6px 10px', marginTop: 8, fontFamily: 'monospace', wordBreak: 'break-all', maxHeight: 60, overflow: 'hidden' }}>
|
||||
{sqlFileExecState.currentSQL}
|
||||
</div>
|
||||
)}
|
||||
{sqlFileExecState.resultMessage && sqlFileExecState.status !== 'running' && (
|
||||
<div style={{ fontSize: 12, marginTop: 12, maxHeight: 200, overflow: 'auto', whiteSpace: 'pre-wrap', background: 'rgba(128,128,128,0.06)', borderRadius: 6, padding: '8px 12px' }}>
|
||||
{sqlFileExecState.resultMessage}
|
||||
</div>
|
||||
)}
|
||||
</Modal>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -144,12 +144,8 @@ const TabManager: React.FC = () => {
|
||||
const items = useMemo(() => tabs.map((tab, index) => {
|
||||
const connectionName = connections.find((conn) => conn.id === tab.connectionId)?.name;
|
||||
const displayTitle = buildTabDisplayTitle(tab, connectionName);
|
||||
const keepMountedWhenInactive = tab.type === 'query' || tab.type === 'redis-command';
|
||||
const shouldRenderContent = activeTabId === tab.id || keepMountedWhenInactive;
|
||||
let content;
|
||||
if (!shouldRenderContent) {
|
||||
content = null;
|
||||
} else if (tab.type === 'query') {
|
||||
if (tab.type === 'query') {
|
||||
content = <QueryEditor tab={tab} />;
|
||||
} else if (tab.type === 'table') {
|
||||
content = <DataViewer tab={tab} />;
|
||||
@@ -203,7 +199,7 @@ const TabManager: React.FC = () => {
|
||||
key: tab.id,
|
||||
children: content,
|
||||
};
|
||||
}), [tabs, connections, activeTabId, closeOtherTabs, closeTabsToLeft, closeTabsToRight, closeAllTabs]);
|
||||
}), [tabs, connections, closeOtherTabs, closeTabsToLeft, closeTabsToRight, closeAllTabs]);
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -297,6 +293,7 @@ const TabManager: React.FC = () => {
|
||||
<Tabs
|
||||
className="main-tabs"
|
||||
type="editable-card"
|
||||
destroyInactiveTabPane={false}
|
||||
onChange={(newActiveKey) => {
|
||||
if (Date.now() < suppressClickUntilRef.current) return;
|
||||
onChange(newActiveKey);
|
||||
|
||||
@@ -2491,7 +2491,7 @@ END;`;
|
||||
okText="应用"
|
||||
cancelText="取消"
|
||||
width={640}
|
||||
destroyOnClose
|
||||
destroyOnHidden
|
||||
>
|
||||
<Input.TextArea
|
||||
value={commentEditorValue}
|
||||
|
||||
39
frontend/src/components/dataGridLayout.test.ts
Normal file
39
frontend/src/components/dataGridLayout.test.ts
Normal file
@@ -0,0 +1,39 @@
|
||||
import { calculateTableBodyBottomPadding } from './dataGridLayout';
|
||||
|
||||
const assertEqual = (actual: unknown, expected: unknown, message: string) => {
|
||||
if (actual !== expected) {
|
||||
throw new Error(`${message}\nactual: ${String(actual)}\nexpected: ${String(expected)}`);
|
||||
}
|
||||
};
|
||||
|
||||
assertEqual(
|
||||
calculateTableBodyBottomPadding({
|
||||
hasHorizontalOverflow: false,
|
||||
floatingScrollbarHeight: 10,
|
||||
floatingScrollbarGap: 6,
|
||||
}),
|
||||
0,
|
||||
'无横向滚动条时不应增加底部间距'
|
||||
);
|
||||
|
||||
assertEqual(
|
||||
calculateTableBodyBottomPadding({
|
||||
hasHorizontalOverflow: true,
|
||||
floatingScrollbarHeight: 10,
|
||||
floatingScrollbarGap: 6,
|
||||
}),
|
||||
28,
|
||||
'默认悬浮滚动条应预留滚动条高度、间距和额外安全区'
|
||||
);
|
||||
|
||||
assertEqual(
|
||||
calculateTableBodyBottomPadding({
|
||||
hasHorizontalOverflow: true,
|
||||
floatingScrollbarHeight: 14,
|
||||
floatingScrollbarGap: 4,
|
||||
}),
|
||||
30,
|
||||
'较粗滚动条场景下应同步放大底部安全区'
|
||||
);
|
||||
|
||||
console.log('dataGridLayout tests passed');
|
||||
23
frontend/src/components/dataGridLayout.ts
Normal file
23
frontend/src/components/dataGridLayout.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
export interface TableBodyBottomPaddingOptions {
|
||||
hasHorizontalOverflow: boolean;
|
||||
floatingScrollbarHeight: number;
|
||||
floatingScrollbarGap: number;
|
||||
}
|
||||
|
||||
const MIN_SCROLLBAR_CLEARANCE = 8;
|
||||
const FLOATING_SCROLLBAR_VISUAL_EXTRA = 4;
|
||||
|
||||
export const calculateTableBodyBottomPadding = ({
|
||||
hasHorizontalOverflow,
|
||||
floatingScrollbarHeight,
|
||||
floatingScrollbarGap,
|
||||
}: TableBodyBottomPaddingOptions): number => {
|
||||
if (!hasHorizontalOverflow) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const safeScrollbarHeight = Math.max(0, Math.ceil(floatingScrollbarHeight));
|
||||
const safeScrollbarGap = Math.max(0, Math.ceil(floatingScrollbarGap));
|
||||
|
||||
return safeScrollbarHeight + FLOATING_SCROLLBAR_VISUAL_EXTRA + safeScrollbarGap + MIN_SCROLLBAR_CLEARANCE;
|
||||
};
|
||||
105
frontend/src/components/redisViewerTree.test.ts
Normal file
105
frontend/src/components/redisViewerTree.test.ts
Normal file
@@ -0,0 +1,105 @@
|
||||
import type { RedisKeyInfo } from '../types';
|
||||
import {
|
||||
applyRenamedRedisKeyState,
|
||||
applyTreeNodeCheck,
|
||||
buildCheckedTreeNodeState,
|
||||
buildRedisKeyTree,
|
||||
isGroupFullyChecked,
|
||||
} from './redisViewerTree';
|
||||
|
||||
const assert = (condition: unknown, message: string) => {
|
||||
if (!condition) {
|
||||
throw new Error(message);
|
||||
}
|
||||
};
|
||||
|
||||
const assertEqual = (actual: unknown, expected: unknown, message: string) => {
|
||||
const actualText = JSON.stringify(actual);
|
||||
const expectedText = JSON.stringify(expected);
|
||||
if (actualText !== expectedText) {
|
||||
throw new Error(`${message}\nactual: ${actualText}\nexpected: ${expectedText}`);
|
||||
}
|
||||
};
|
||||
|
||||
const sampleKeys: RedisKeyInfo[] = [
|
||||
{ key: 'app:user:1', type: 'string', ttl: -1 },
|
||||
{ key: 'app:user:2', type: 'string', ttl: -1 },
|
||||
{ key: 'app:order:1', type: 'hash', ttl: 120 },
|
||||
{ key: 'misc', type: 'set', ttl: -1 },
|
||||
];
|
||||
|
||||
const tree = buildRedisKeyTree(sampleKeys, true);
|
||||
const appGroup = tree.treeData.find((node) => node.key === 'group:app');
|
||||
const userGroup = appGroup?.children?.find((node) => node.key === 'group:app:user');
|
||||
|
||||
assert(appGroup, '应生成 group:app 节点');
|
||||
assert(userGroup, '应生成 group:app:user 节点');
|
||||
assertEqual(
|
||||
appGroup?.descendantRawKeys,
|
||||
['app:order:1', 'app:user:1', 'app:user:2'],
|
||||
'app 分组应收集全部后代 key'
|
||||
);
|
||||
|
||||
const selectedAfterGroupCheck = applyTreeNodeCheck([], appGroup!, true);
|
||||
assertEqual(
|
||||
selectedAfterGroupCheck,
|
||||
['app:order:1', 'app:user:1', 'app:user:2'],
|
||||
'勾选分组应递归选中全部后代 key'
|
||||
);
|
||||
|
||||
const checkedState = buildCheckedTreeNodeState(selectedAfterGroupCheck, tree);
|
||||
assertEqual(
|
||||
checkedState.checked,
|
||||
['key:app:order:1', 'group:app:order', 'key:app:user:1', 'key:app:user:2', 'group:app:user', 'group:app'],
|
||||
'全部后代已选中时,父分组和叶子都应进入 checked'
|
||||
);
|
||||
assertEqual(checkedState.halfChecked, [], '全部后代已选中时不应有 halfChecked');
|
||||
assertEqual(isGroupFullyChecked(appGroup!, selectedAfterGroupCheck), true, '全部后代已选中时,分组应视为 fully checked');
|
||||
|
||||
const selectedAfterGroupUncheck = applyTreeNodeCheck(selectedAfterGroupCheck, appGroup!, false);
|
||||
assertEqual(selectedAfterGroupUncheck, [], '取消勾选分组应移除全部后代 key');
|
||||
assertEqual(isGroupFullyChecked(appGroup!, selectedAfterGroupUncheck), false, '取消后分组不应再是 fully checked');
|
||||
|
||||
const partialState = buildCheckedTreeNodeState(['app:user:1'], tree);
|
||||
assertEqual(
|
||||
partialState.halfChecked,
|
||||
['group:app:user', 'group:app'],
|
||||
'仅部分后代选中时,相关分组应进入 halfChecked'
|
||||
);
|
||||
assertEqual(isGroupFullyChecked(appGroup!, ['app:user:1']), false, '部分选中时分组不应是 fully checked');
|
||||
|
||||
const renamedState = applyRenamedRedisKeyState(
|
||||
{
|
||||
keys: sampleKeys,
|
||||
selectedKey: 'app:user:2',
|
||||
selectedKeys: ['app:user:1', 'app:user:2', 'misc'],
|
||||
},
|
||||
'app:user:2',
|
||||
'app:user:200'
|
||||
);
|
||||
|
||||
assertEqual(
|
||||
renamedState.keys.map((item) => item.key),
|
||||
['app:user:1', 'app:user:200', 'app:order:1', 'misc'],
|
||||
'重命名后 keys 列表应替换旧 key'
|
||||
);
|
||||
assertEqual(renamedState.selectedKey, 'app:user:200', '当前详情选中的 key 应切换为新 key');
|
||||
assertEqual(
|
||||
renamedState.selectedKeys,
|
||||
['app:user:1', 'app:user:200', 'misc'],
|
||||
'批量选中集合中的旧 key 应映射为新 key'
|
||||
);
|
||||
|
||||
const unrelatedRenameState = applyRenamedRedisKeyState(
|
||||
{
|
||||
keys: sampleKeys,
|
||||
selectedKey: 'misc',
|
||||
selectedKeys: ['app:user:1'],
|
||||
},
|
||||
'app:order:1',
|
||||
'app:order:9'
|
||||
);
|
||||
assertEqual(unrelatedRenameState.selectedKey, 'misc', '非当前详情 key 的重命名不应影响 selectedKey');
|
||||
assertEqual(unrelatedRenameState.selectedKeys, ['app:user:1'], '非已勾选 key 的重命名不应污染选中集合');
|
||||
|
||||
console.log('redisViewerTree tests passed');
|
||||
260
frontend/src/components/redisViewerTree.ts
Normal file
260
frontend/src/components/redisViewerTree.ts
Normal file
@@ -0,0 +1,260 @@
|
||||
import type { DataNode } from 'antd/es/tree';
|
||||
import type { RedisKeyInfo } from '../types';
|
||||
|
||||
const KEY_GROUP_DELIMITER = ':';
|
||||
const EMPTY_SEGMENT_LABEL = '(empty)';
|
||||
|
||||
type RedisKeyTreeLeaf = {
|
||||
keyInfo: RedisKeyInfo;
|
||||
label: string;
|
||||
};
|
||||
|
||||
type RedisKeyTreeGroup = {
|
||||
name: string;
|
||||
path: string;
|
||||
children: Map<string, RedisKeyTreeGroup>;
|
||||
leaves: RedisKeyTreeLeaf[];
|
||||
leafCount: number;
|
||||
};
|
||||
|
||||
export type RedisTreeDataNode = DataNode & {
|
||||
nodeType: 'group' | 'leaf';
|
||||
groupName?: string;
|
||||
groupLeafCount?: number;
|
||||
leafLabel?: string;
|
||||
rawKey?: string;
|
||||
keyType?: string;
|
||||
ttl?: number;
|
||||
descendantRawKeys?: string[];
|
||||
};
|
||||
|
||||
export type RedisKeyTreeResult = {
|
||||
treeData: RedisTreeDataNode[];
|
||||
groupKeys: string[];
|
||||
};
|
||||
|
||||
export type RedisTreeCheckedState = {
|
||||
checked: string[];
|
||||
halfChecked: string[];
|
||||
};
|
||||
|
||||
export type RenamedRedisKeyStateInput = {
|
||||
keys: RedisKeyInfo[];
|
||||
selectedKey: string | null;
|
||||
selectedKeys: string[];
|
||||
};
|
||||
|
||||
export type RenamedRedisKeyStateResult = {
|
||||
keys: RedisKeyInfo[];
|
||||
selectedKey: string | null;
|
||||
selectedKeys: string[];
|
||||
};
|
||||
|
||||
const normalizeKeySegment = (segment: string): string => {
|
||||
return segment === '' ? EMPTY_SEGMENT_LABEL : segment;
|
||||
};
|
||||
|
||||
const createTreeGroup = (name: string, path: string): RedisKeyTreeGroup => {
|
||||
return { name, path, children: new Map(), leaves: [], leafCount: 0 };
|
||||
};
|
||||
|
||||
const calculateGroupLeafCount = (group: RedisKeyTreeGroup): number => {
|
||||
let count = group.leaves.length;
|
||||
group.children.forEach((child) => {
|
||||
count += calculateGroupLeafCount(child);
|
||||
});
|
||||
group.leafCount = count;
|
||||
return count;
|
||||
};
|
||||
|
||||
export const buildLeafNodeKey = (rawKey: string): string => `key:${rawKey}`;
|
||||
|
||||
export const parseRawKeyFromNodeKey = (nodeKey: React.Key): string | null => {
|
||||
const keyText = String(nodeKey);
|
||||
if (!keyText.startsWith('key:')) {
|
||||
return null;
|
||||
}
|
||||
return keyText.slice(4);
|
||||
};
|
||||
|
||||
export const buildRedisKeyTree = (
|
||||
keys: RedisKeyInfo[],
|
||||
sortLeafNodes: boolean
|
||||
): RedisKeyTreeResult => {
|
||||
const root = createTreeGroup('__root__', '__root__');
|
||||
|
||||
keys.forEach((keyInfo) => {
|
||||
const segments = keyInfo.key.split(KEY_GROUP_DELIMITER);
|
||||
if (segments.length <= 1) {
|
||||
root.leaves.push({ keyInfo, label: keyInfo.key });
|
||||
return;
|
||||
}
|
||||
|
||||
const groupSegments = segments.slice(0, -1);
|
||||
const leafLabel = normalizeKeySegment(segments[segments.length - 1]);
|
||||
let current = root;
|
||||
const pathParts: string[] = [];
|
||||
|
||||
groupSegments.forEach((segment) => {
|
||||
const normalized = normalizeKeySegment(segment);
|
||||
pathParts.push(normalized);
|
||||
const groupPath = pathParts.join(KEY_GROUP_DELIMITER);
|
||||
let child = current.children.get(normalized);
|
||||
if (!child) {
|
||||
child = createTreeGroup(normalized, groupPath);
|
||||
current.children.set(normalized, child);
|
||||
}
|
||||
current = child;
|
||||
});
|
||||
|
||||
current.leaves.push({ keyInfo, label: leafLabel });
|
||||
});
|
||||
|
||||
calculateGroupLeafCount(root);
|
||||
const groupKeys: string[] = [];
|
||||
|
||||
const toTreeNodes = (group: RedisKeyTreeGroup): RedisTreeDataNode[] => {
|
||||
const childGroups = Array.from(group.children.values()).sort((a, b) => a.name.localeCompare(b.name));
|
||||
const childLeaves = sortLeafNodes
|
||||
? [...group.leaves].sort((a, b) => a.keyInfo.key.localeCompare(b.keyInfo.key))
|
||||
: group.leaves;
|
||||
|
||||
const groupNodes: RedisTreeDataNode[] = childGroups.map((child) => {
|
||||
const children = toTreeNodes(child);
|
||||
const descendantRawKeys = children.flatMap((node) => {
|
||||
if (node.nodeType === 'leaf') {
|
||||
return node.rawKey ? [node.rawKey] : [];
|
||||
}
|
||||
return node.descendantRawKeys || [];
|
||||
});
|
||||
const groupNodeKey = `group:${child.path}`;
|
||||
groupKeys.push(groupNodeKey);
|
||||
return {
|
||||
key: groupNodeKey,
|
||||
title: child.name,
|
||||
nodeType: 'group',
|
||||
groupName: child.name,
|
||||
groupLeafCount: child.leafCount,
|
||||
selectable: false,
|
||||
descendantRawKeys,
|
||||
children,
|
||||
};
|
||||
});
|
||||
|
||||
const leafNodes: RedisTreeDataNode[] = childLeaves.map((leaf) => {
|
||||
return {
|
||||
key: buildLeafNodeKey(leaf.keyInfo.key),
|
||||
isLeaf: true,
|
||||
title: leaf.label,
|
||||
nodeType: 'leaf',
|
||||
leafLabel: leaf.label,
|
||||
rawKey: leaf.keyInfo.key,
|
||||
keyType: leaf.keyInfo.type,
|
||||
ttl: leaf.keyInfo.ttl,
|
||||
};
|
||||
});
|
||||
|
||||
return [...groupNodes, ...leafNodes];
|
||||
};
|
||||
|
||||
return {
|
||||
treeData: toTreeNodes(root),
|
||||
groupKeys,
|
||||
};
|
||||
};
|
||||
|
||||
export const applyTreeNodeCheck = (
|
||||
selectedKeys: string[],
|
||||
node: RedisTreeDataNode,
|
||||
checked: boolean
|
||||
): string[] => {
|
||||
if (node.nodeType === 'leaf') {
|
||||
if (!node.rawKey) {
|
||||
return selectedKeys;
|
||||
}
|
||||
if (checked) {
|
||||
return Array.from(new Set([...selectedKeys, node.rawKey]));
|
||||
}
|
||||
return selectedKeys.filter((item) => item !== node.rawKey);
|
||||
}
|
||||
|
||||
const descendantRawKeys = node.descendantRawKeys || [];
|
||||
if (descendantRawKeys.length === 0) {
|
||||
return selectedKeys;
|
||||
}
|
||||
if (checked) {
|
||||
return Array.from(new Set([...selectedKeys, ...descendantRawKeys]));
|
||||
}
|
||||
const removeSet = new Set(descendantRawKeys);
|
||||
return selectedKeys.filter((item) => !removeSet.has(item));
|
||||
};
|
||||
|
||||
const walkGroupStates = (
|
||||
nodes: RedisTreeDataNode[],
|
||||
selectedKeySet: Set<string>,
|
||||
checked: string[],
|
||||
halfChecked: string[]
|
||||
) => {
|
||||
nodes.forEach((node) => {
|
||||
if (node.nodeType === 'leaf') {
|
||||
if (node.rawKey && selectedKeySet.has(node.rawKey)) {
|
||||
checked.push(String(node.key));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
walkGroupStates((node.children || []) as RedisTreeDataNode[], selectedKeySet, checked, halfChecked);
|
||||
const descendantRawKeys = node.descendantRawKeys || [];
|
||||
if (descendantRawKeys.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const selectedCount = descendantRawKeys.filter((rawKey) => selectedKeySet.has(rawKey)).length;
|
||||
if (selectedCount === descendantRawKeys.length) {
|
||||
checked.push(String(node.key));
|
||||
return;
|
||||
}
|
||||
if (selectedCount > 0) {
|
||||
halfChecked.push(String(node.key));
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
export const buildCheckedTreeNodeState = (
|
||||
selectedKeys: string[],
|
||||
keyTree: RedisKeyTreeResult
|
||||
): RedisTreeCheckedState => {
|
||||
const selectedKeySet = new Set(selectedKeys);
|
||||
const checked: string[] = [];
|
||||
const halfChecked: string[] = [];
|
||||
|
||||
walkGroupStates(keyTree.treeData, selectedKeySet, checked, halfChecked);
|
||||
return { checked, halfChecked };
|
||||
};
|
||||
|
||||
export const isGroupFullyChecked = (
|
||||
node: RedisTreeDataNode,
|
||||
selectedKeys: string[]
|
||||
): boolean => {
|
||||
if (node.nodeType !== 'group') {
|
||||
return false;
|
||||
}
|
||||
const descendantRawKeys = node.descendantRawKeys || [];
|
||||
if (descendantRawKeys.length === 0) {
|
||||
return false;
|
||||
}
|
||||
const selectedKeySet = new Set(selectedKeys);
|
||||
return descendantRawKeys.every((rawKey) => selectedKeySet.has(rawKey));
|
||||
};
|
||||
|
||||
export const applyRenamedRedisKeyState = (
|
||||
state: RenamedRedisKeyStateInput,
|
||||
oldKey: string,
|
||||
newKey: string
|
||||
): RenamedRedisKeyStateResult => {
|
||||
return {
|
||||
keys: state.keys.map((item) => (item.key === oldKey ? { ...item, key: newKey } : item)),
|
||||
selectedKey: state.selectedKey === oldKey ? newKey : state.selectedKey,
|
||||
selectedKeys: state.selectedKeys.map((item) => (item === oldKey ? newKey : item)),
|
||||
};
|
||||
};
|
||||
50
frontend/src/components/redisViewerWorkbenchTheme.test.ts
Normal file
50
frontend/src/components/redisViewerWorkbenchTheme.test.ts
Normal file
@@ -0,0 +1,50 @@
|
||||
import { buildRedisWorkbenchTheme } from './redisViewerWorkbenchTheme';
|
||||
|
||||
const assertEqual = (actual: unknown, expected: unknown, message: string) => {
|
||||
if (actual !== expected) {
|
||||
throw new Error(`${message}\nactual: ${String(actual)}\nexpected: ${String(expected)}`);
|
||||
}
|
||||
};
|
||||
|
||||
const assertNotEqual = (actual: unknown, expected: unknown, message: string) => {
|
||||
if (actual === expected) {
|
||||
throw new Error(`${message}\nactual: ${String(actual)}\nnotExpected: ${String(expected)}`);
|
||||
}
|
||||
};
|
||||
|
||||
const assertMatch = (value: string, pattern: RegExp, message: string) => {
|
||||
if (!pattern.test(value)) {
|
||||
throw new Error(`${message}\nactual: ${value}\npattern: ${String(pattern)}`);
|
||||
}
|
||||
};
|
||||
|
||||
const darkTheme = buildRedisWorkbenchTheme({
|
||||
darkMode: true,
|
||||
opacity: 0.72,
|
||||
blur: 14,
|
||||
});
|
||||
|
||||
assertEqual(darkTheme.isDark, true, 'dark 主题标记应为 true');
|
||||
assertMatch(darkTheme.panelBg, /^rgba\(/, 'dark 主题面板背景应为 rgba');
|
||||
assertMatch(darkTheme.toolbarPrimaryBg, /^linear-gradient\(/, '工具栏主按钮应使用渐变背景');
|
||||
assertNotEqual(darkTheme.actionDangerBg, darkTheme.actionSecondaryBg, '危险态按钮背景不应与普通按钮相同');
|
||||
assertNotEqual(darkTheme.treeSelectedBg, darkTheme.treeHoverBg, '树节点选中态与悬浮态不应相同');
|
||||
assertMatch(darkTheme.appBg, /rgba\(15, 15, 17,/, 'dark 背景应保持中性黑基底');
|
||||
assertMatch(darkTheme.panelBg, /rgba\(24, 24, 28,/, 'dark 面板背景应保持中性黑灰');
|
||||
assertMatch(darkTheme.panelBgStrong, /rgba\(31, 31, 36,/, 'dark 强面板背景应保持中性黑灰');
|
||||
assertEqual(darkTheme.backdropFilter, 'blur(14px)', 'blur 参数应映射为 backdropFilter');
|
||||
|
||||
const lightTheme = buildRedisWorkbenchTheme({
|
||||
darkMode: false,
|
||||
opacity: 1,
|
||||
blur: 0,
|
||||
});
|
||||
|
||||
assertEqual(lightTheme.isDark, false, 'light 主题标记应为 false');
|
||||
assertMatch(lightTheme.panelBg, /^rgba\(/, 'light 主题面板背景应为 rgba');
|
||||
assertMatch(lightTheme.contentEmptyBg, /^linear-gradient\(/, 'light 空状态背景应为渐变');
|
||||
assertNotEqual(lightTheme.textPrimary, lightTheme.textSecondary, '主次文本颜色应区分');
|
||||
assertNotEqual(lightTheme.statusTagBg, lightTheme.statusTagMutedBg, '状态 tag 应区分普通与弱化样式');
|
||||
assertEqual(lightTheme.backdropFilter, 'none', 'blur=0 时 backdropFilter 应为 none');
|
||||
|
||||
console.log('redisViewerWorkbenchTheme tests passed');
|
||||
129
frontend/src/components/redisViewerWorkbenchTheme.ts
Normal file
129
frontend/src/components/redisViewerWorkbenchTheme.ts
Normal file
@@ -0,0 +1,129 @@
|
||||
type RedisWorkbenchThemeInput = {
|
||||
darkMode: boolean;
|
||||
opacity: number;
|
||||
blur: number;
|
||||
};
|
||||
|
||||
type RedisWorkbenchTheme = {
|
||||
isDark: boolean;
|
||||
appBg: string;
|
||||
panelBg: string;
|
||||
panelBgStrong: string;
|
||||
panelBgSubtle: string;
|
||||
panelBorder: string;
|
||||
panelInset: string;
|
||||
toolbarPrimaryBg: string;
|
||||
contentEmptyBg: string;
|
||||
textPrimary: string;
|
||||
textSecondary: string;
|
||||
textMuted: string;
|
||||
accent: string;
|
||||
accentSoft: string;
|
||||
accentBorder: string;
|
||||
actionSecondaryBg: string;
|
||||
actionSecondaryBorder: string;
|
||||
actionDangerBg: string;
|
||||
actionDangerBorder: string;
|
||||
actionDangerText: string;
|
||||
statusTagBg: string;
|
||||
statusTagBorder: string;
|
||||
statusTagMutedBg: string;
|
||||
statusTagMutedBorder: string;
|
||||
treeHoverBg: string;
|
||||
treeSelectedBg: string;
|
||||
treeSelectedBorder: string;
|
||||
divider: string;
|
||||
shadow: string;
|
||||
backdropFilter: string;
|
||||
};
|
||||
|
||||
const clamp = (value: number, min: number, max: number) => Math.min(max, Math.max(min, value));
|
||||
|
||||
export const buildRedisWorkbenchTheme = ({
|
||||
darkMode,
|
||||
opacity,
|
||||
blur,
|
||||
}: RedisWorkbenchThemeInput): RedisWorkbenchTheme => {
|
||||
const normalizedOpacity = clamp(opacity, 0.1, 1);
|
||||
const normalizedBlur = Math.max(0, Math.round(blur));
|
||||
const isTranslucent = normalizedOpacity < 0.999 || normalizedBlur > 0;
|
||||
|
||||
if (darkMode) {
|
||||
const appTopAlpha = isTranslucent ? Math.max(0.08, Math.min(0.22, normalizedOpacity * 0.16)) : 0.92;
|
||||
const appBottomAlpha = isTranslucent ? Math.max(0.12, Math.min(0.28, normalizedOpacity * 0.22)) : 0.96;
|
||||
const panelAlpha = isTranslucent ? Math.max(0.06, Math.min(0.16, normalizedOpacity * 0.1)) : 0.34;
|
||||
const strongAlpha = isTranslucent ? Math.max(0.1, Math.min(0.22, normalizedOpacity * 0.16)) : 0.42;
|
||||
const subtleAlpha = isTranslucent ? Math.max(0.03, Math.min(0.08, normalizedOpacity * 0.05)) : 0.08;
|
||||
return {
|
||||
isDark: true,
|
||||
appBg: `linear-gradient(180deg, rgba(15, 15, 17, ${appTopAlpha}) 0%, rgba(11, 11, 13, ${appBottomAlpha}) 100%)`,
|
||||
panelBg: `rgba(24, 24, 28, ${panelAlpha})`,
|
||||
panelBgStrong: `rgba(31, 31, 36, ${strongAlpha})`,
|
||||
panelBgSubtle: `rgba(255, 255, 255, ${subtleAlpha})`,
|
||||
panelBorder: `1px solid rgba(255, 255, 255, ${isTranslucent ? Math.max(0.12, Math.min(0.24, normalizedOpacity * 0.2)) : 0.08})`,
|
||||
panelInset: `inset 0 1px 0 rgba(255,255,255,${isTranslucent ? Math.max(0.05, Math.min(0.12, normalizedOpacity * 0.1)) : 0.04})`,
|
||||
toolbarPrimaryBg: `linear-gradient(135deg, rgba(246,196,83,0.22) 0%, rgba(246,196,83,0.12) 100%)`,
|
||||
contentEmptyBg: `linear-gradient(180deg, rgba(255,255,255,0.03) 0%, rgba(255,255,255,0.015) 100%)`,
|
||||
textPrimary: 'rgba(245, 247, 251, 0.96)',
|
||||
textSecondary: 'rgba(218, 224, 235, 0.82)',
|
||||
textMuted: 'rgba(168, 177, 194, 0.72)',
|
||||
accent: '#f6c453',
|
||||
accentSoft: 'rgba(246, 196, 83, 0.18)',
|
||||
accentBorder: 'rgba(246, 196, 83, 0.3)',
|
||||
actionSecondaryBg: 'rgba(255, 255, 255, 0.04)',
|
||||
actionSecondaryBorder: 'rgba(255, 255, 255, 0.09)',
|
||||
actionDangerBg: 'rgba(255, 95, 95, 0.12)',
|
||||
actionDangerBorder: 'rgba(255, 95, 95, 0.28)',
|
||||
actionDangerText: '#ff8f8f',
|
||||
statusTagBg: 'rgba(25, 106, 255, 0.16)',
|
||||
statusTagBorder: 'rgba(25, 106, 255, 0.28)',
|
||||
statusTagMutedBg: 'rgba(255, 255, 255, 0.04)',
|
||||
statusTagMutedBorder: 'rgba(255, 255, 255, 0.08)',
|
||||
treeHoverBg: 'rgba(255, 255, 255, 0.045)',
|
||||
treeSelectedBg: 'linear-gradient(90deg, rgba(246,196,83,0.2) 0%, rgba(246,196,83,0.08) 100%)',
|
||||
treeSelectedBorder: 'rgba(246, 196, 83, 0.24)',
|
||||
divider: 'rgba(255, 255, 255, 0.07)',
|
||||
shadow: '0 20px 48px rgba(0, 0, 0, 0.26)',
|
||||
backdropFilter: normalizedBlur > 0 ? `blur(${normalizedBlur}px)` : 'none',
|
||||
};
|
||||
}
|
||||
|
||||
const appTopAlpha = isTranslucent ? Math.max(0.16, Math.min(0.36, normalizedOpacity * 0.24)) : 0.98;
|
||||
const appBottomAlpha = isTranslucent ? Math.max(0.22, Math.min(0.44, normalizedOpacity * 0.32)) : 0.96;
|
||||
const panelAlpha = isTranslucent ? Math.max(0.18, Math.min(0.4, normalizedOpacity * 0.26)) : 0.94;
|
||||
const strongAlpha = isTranslucent ? Math.max(0.26, Math.min(0.52, normalizedOpacity * 0.34)) : 0.98;
|
||||
return {
|
||||
isDark: false,
|
||||
appBg: `linear-gradient(180deg, rgba(248, 250, 252, ${appTopAlpha}) 0%, rgba(242, 245, 248, ${appBottomAlpha}) 100%)`,
|
||||
panelBg: `rgba(255, 255, 255, ${panelAlpha})`,
|
||||
panelBgStrong: `rgba(255, 255, 255, ${strongAlpha})`,
|
||||
panelBgSubtle: 'rgba(15, 23, 42, 0.03)',
|
||||
panelBorder: `1px solid rgba(15, 23, 42, ${isTranslucent ? Math.max(0.1, Math.min(0.18, normalizedOpacity * 0.12)) : 0.08})`,
|
||||
panelInset: `inset 0 1px 0 rgba(255,255,255,${isTranslucent ? 0.38 : 0.72})`,
|
||||
toolbarPrimaryBg: 'linear-gradient(135deg, rgba(22,119,255,0.12) 0%, rgba(22,119,255,0.06) 100%)',
|
||||
contentEmptyBg: 'linear-gradient(180deg, rgba(15,23,42,0.02) 0%, rgba(15,23,42,0.01) 100%)',
|
||||
textPrimary: 'rgba(15, 23, 42, 0.92)',
|
||||
textSecondary: 'rgba(51, 65, 85, 0.82)',
|
||||
textMuted: 'rgba(100, 116, 139, 0.76)',
|
||||
accent: '#1677ff',
|
||||
accentSoft: 'rgba(22, 119, 255, 0.12)',
|
||||
accentBorder: 'rgba(22, 119, 255, 0.22)',
|
||||
actionSecondaryBg: 'rgba(255, 255, 255, 0.72)',
|
||||
actionSecondaryBorder: 'rgba(15, 23, 42, 0.08)',
|
||||
actionDangerBg: 'rgba(255, 77, 79, 0.08)',
|
||||
actionDangerBorder: 'rgba(255, 77, 79, 0.24)',
|
||||
actionDangerText: '#cf1322',
|
||||
statusTagBg: 'rgba(22, 119, 255, 0.1)',
|
||||
statusTagBorder: 'rgba(22, 119, 255, 0.16)',
|
||||
statusTagMutedBg: 'rgba(15, 23, 42, 0.04)',
|
||||
statusTagMutedBorder: 'rgba(15, 23, 42, 0.08)',
|
||||
treeHoverBg: 'rgba(15, 23, 42, 0.035)',
|
||||
treeSelectedBg: 'linear-gradient(90deg, rgba(22,119,255,0.12) 0%, rgba(22,119,255,0.05) 100%)',
|
||||
treeSelectedBorder: 'rgba(22, 119, 255, 0.18)',
|
||||
divider: 'rgba(15, 23, 42, 0.08)',
|
||||
shadow: '0 22px 52px rgba(15, 23, 42, 0.08)',
|
||||
backdropFilter: normalizedBlur > 0 ? `blur(${normalizedBlur}px)` : 'none',
|
||||
};
|
||||
};
|
||||
|
||||
export type { RedisWorkbenchTheme, RedisWorkbenchThemeInput };
|
||||
@@ -9,6 +9,36 @@ import { loader } from '@monaco-editor/react'
|
||||
import * as monaco from 'monaco-editor'
|
||||
loader.config({ monaco })
|
||||
|
||||
if (typeof window !== 'undefined' && !(window as any).go) {
|
||||
(window as any).go = {
|
||||
app: {
|
||||
App: {
|
||||
CheckUpdate: async () => ({ success: false }),
|
||||
DownloadUpdate: async () => ({ success: false }),
|
||||
GetSavedConnections: async () => [],
|
||||
SaveConnection: async () => null,
|
||||
DeleteConnection: async () => null,
|
||||
OpenConnection: async () => null,
|
||||
CloseConnection: async () => null,
|
||||
GetDatabases: async () => [],
|
||||
GetTables: async () => [],
|
||||
GetTableData: async () => ({ columns: [], rows: [], total: 0 }),
|
||||
GetTableColumns: async () => [],
|
||||
ExecuteQuery: async () => ({ columns: [], rows: [], time: 0 }),
|
||||
GetSavedQueries: async () => [],
|
||||
SaveQuery: async () => null,
|
||||
DeleteQuery: async () => null,
|
||||
GetAppInfo: async () => ({}),
|
||||
CheckForUpdates: async () => ({ success: false }),
|
||||
OpenDownloadedUpdateDirectory: async () => ({ success: false }),
|
||||
InstallUpdateAndRestart: async () => ({ success: false }),
|
||||
ImportConfigFile: async () => ({ success: false }),
|
||||
ExportData: async () => ({ success: false }),
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// 全局注册透明主题,避免每个 Editor 组件 beforeMount 中重复定义
|
||||
monaco.editor.defineTheme('transparent-dark', {
|
||||
base: 'vs-dark', inherit: true, rules: [],
|
||||
|
||||
@@ -10,7 +10,7 @@ import {
|
||||
sanitizeShortcutOptions,
|
||||
} from './utils/shortcuts';
|
||||
|
||||
const DEFAULT_APPEARANCE = { opacity: 1.0, blur: 0 };
|
||||
const DEFAULT_APPEARANCE = { enabled: true, opacity: 1.0, blur: 0 };
|
||||
const DEFAULT_UI_SCALE = 1.0;
|
||||
const MIN_UI_SCALE = 0.8;
|
||||
const MAX_UI_SCALE = 1.25;
|
||||
@@ -25,7 +25,7 @@ const MAX_HOST_ENTRY_LENGTH = 512;
|
||||
const MAX_HOST_ENTRIES = 64;
|
||||
const DEFAULT_TIMEOUT_SECONDS = 30;
|
||||
const MAX_TIMEOUT_SECONDS = 3600;
|
||||
const PERSIST_VERSION = 5;
|
||||
const PERSIST_VERSION = 6;
|
||||
const DEFAULT_CONNECTION_TYPE = 'mysql';
|
||||
const DEFAULT_GLOBAL_PROXY: GlobalProxyConfig = {
|
||||
enabled: false,
|
||||
@@ -231,6 +231,18 @@ const sanitizeConnectionConfig = (value: unknown): ConnectionConfig => {
|
||||
user: toTrimmedString(proxyRaw.user),
|
||||
password: toTrimmedString(proxyRaw.password),
|
||||
};
|
||||
const httpTunnelRaw = (raw.httpTunnel && typeof raw.httpTunnel === 'object')
|
||||
? raw.httpTunnel as Record<string, unknown>
|
||||
: ((raw.HTTPTunnel && typeof raw.HTTPTunnel === 'object') ? raw.HTTPTunnel as Record<string, unknown> : {});
|
||||
const httpTunnel = {
|
||||
host: toTrimmedString(httpTunnelRaw.host ?? raw.httpTunnelHost),
|
||||
port: normalizePort(httpTunnelRaw.port ?? raw.httpTunnelPort, 8080),
|
||||
user: toTrimmedString(httpTunnelRaw.user ?? raw.httpTunnelUser),
|
||||
password: toTrimmedString(httpTunnelRaw.password ?? raw.httpTunnelPassword),
|
||||
};
|
||||
const supportsNetworkTunnel = type !== 'sqlite' && type !== 'duckdb';
|
||||
const useHttpTunnel = supportsNetworkTunnel && (raw.useHttpTunnel === true || raw.UseHTTPTunnel === true);
|
||||
const useProxy = supportsNetworkTunnel && !!raw.useProxy && !useHttpTunnel;
|
||||
|
||||
const safeConfig: ConnectionConfig & Record<string, unknown> = {
|
||||
...raw,
|
||||
@@ -247,8 +259,10 @@ const sanitizeConnectionConfig = (value: unknown): ConnectionConfig => {
|
||||
sslKeyPath: sslCapable ? toTrimmedString(raw.sslKeyPath) : '',
|
||||
useSSH: !!raw.useSSH,
|
||||
ssh,
|
||||
useProxy: !!raw.useProxy,
|
||||
useProxy,
|
||||
proxy,
|
||||
useHttpTunnel,
|
||||
httpTunnel,
|
||||
uri: toTrimmedString(raw.uri).slice(0, MAX_URI_LENGTH),
|
||||
hosts: sanitizeAddressList(raw.hosts),
|
||||
topology: raw.topology === 'replica' ? 'replica' : (raw.topology === 'cluster' ? 'cluster' : 'single'),
|
||||
@@ -391,7 +405,7 @@ interface AppState {
|
||||
activeContext: { connectionId: string; dbName: string } | null;
|
||||
savedQueries: SavedQuery[];
|
||||
theme: 'light' | 'dark';
|
||||
appearance: { opacity: number; blur: number };
|
||||
appearance: { enabled: boolean; opacity: number; blur: number };
|
||||
uiScale: number;
|
||||
fontSize: number;
|
||||
startupFullscreen: boolean;
|
||||
@@ -402,6 +416,10 @@ interface AppState {
|
||||
sqlLogs: SqlLog[];
|
||||
tableAccessCount: Record<string, number>;
|
||||
tableSortPreference: Record<string, 'name' | 'frequency'>;
|
||||
tableColumnOrders: Record<string, string[]>;
|
||||
enableColumnOrderMemory: boolean;
|
||||
tableHiddenColumns: Record<string, string[]>;
|
||||
enableHiddenColumnMemory: boolean;
|
||||
|
||||
addConnection: (conn: SavedConnection) => void;
|
||||
updateConnection: (conn: SavedConnection) => void;
|
||||
@@ -429,7 +447,7 @@ interface AppState {
|
||||
deleteQuery: (id: string) => void;
|
||||
|
||||
setTheme: (theme: 'light' | 'dark') => void;
|
||||
setAppearance: (appearance: Partial<{ opacity: number; blur: number }>) => void;
|
||||
setAppearance: (appearance: Partial<{ enabled: boolean; opacity: number; blur: number }>) => void;
|
||||
setUiScale: (scale: number) => void;
|
||||
setFontSize: (size: number) => void;
|
||||
setStartupFullscreen: (enabled: boolean) => void;
|
||||
@@ -444,6 +462,13 @@ interface AppState {
|
||||
|
||||
recordTableAccess: (connectionId: string, dbName: string, tableName: string) => void;
|
||||
setTableSortPreference: (connectionId: string, dbName: string, sortBy: 'name' | 'frequency') => void;
|
||||
setTableColumnOrder: (connectionId: string, dbName: string, tableName: string, order: string[]) => void;
|
||||
setEnableColumnOrderMemory: (enabled: boolean) => void;
|
||||
clearTableColumnOrder: (connectionId: string, dbName: string, tableName: string) => void;
|
||||
|
||||
setTableHiddenColumns: (connectionId: string, dbName: string, tableName: string, hiddenColumns: string[]) => void;
|
||||
setEnableHiddenColumnMemory: (enabled: boolean) => void;
|
||||
clearTableHiddenColumns: (connectionId: string, dbName: string, tableName: string) => void;
|
||||
}
|
||||
|
||||
const sanitizeSavedQueries = (value: unknown): SavedQuery[] => {
|
||||
@@ -507,14 +532,37 @@ const sanitizeTableSortPreference = (value: unknown): Record<string, 'name' | 'f
|
||||
return result;
|
||||
};
|
||||
|
||||
const sanitizeTableColumnOrders = (value: unknown): Record<string, string[]> => {
|
||||
const raw = (value && typeof value === 'object') ? value as Record<string, unknown> : {};
|
||||
const result: Record<string, string[]> = {};
|
||||
Object.entries(raw).forEach(([key, orderArray]) => {
|
||||
if (Array.isArray(orderArray)) {
|
||||
result[key] = orderArray.map(col => String(col));
|
||||
}
|
||||
});
|
||||
return result;
|
||||
};
|
||||
|
||||
const sanitizeTableHiddenColumns = (value: unknown): Record<string, string[]> => {
|
||||
const raw = (value && typeof value === 'object') ? value as Record<string, unknown> : {};
|
||||
const result: Record<string, string[]> = {};
|
||||
Object.entries(raw).forEach(([key, hiddenArray]) => {
|
||||
if (Array.isArray(hiddenArray)) {
|
||||
result[key] = hiddenArray.map(col => String(col));
|
||||
}
|
||||
});
|
||||
return result;
|
||||
};
|
||||
|
||||
const sanitizeAppearance = (
|
||||
appearance: Partial<{ opacity: number; blur: number }> | undefined,
|
||||
appearance: Partial<{ enabled: boolean; opacity: number; blur: number }> | undefined,
|
||||
version: number
|
||||
): { opacity: number; blur: number } => {
|
||||
): { enabled: boolean; opacity: number; blur: number } => {
|
||||
if (!appearance || typeof appearance !== 'object') {
|
||||
return { ...DEFAULT_APPEARANCE };
|
||||
}
|
||||
const nextAppearance = {
|
||||
enabled: typeof appearance.enabled === 'boolean' ? appearance.enabled : DEFAULT_APPEARANCE.enabled,
|
||||
opacity: typeof appearance.opacity === 'number' ? appearance.opacity : DEFAULT_APPEARANCE.opacity,
|
||||
blur: typeof appearance.blur === 'number' ? appearance.blur : DEFAULT_APPEARANCE.blur,
|
||||
};
|
||||
@@ -583,6 +631,10 @@ export const useStore = create<AppState>()(
|
||||
sqlLogs: [],
|
||||
tableAccessCount: {},
|
||||
tableSortPreference: {},
|
||||
tableColumnOrders: {},
|
||||
enableColumnOrderMemory: true,
|
||||
tableHiddenColumns: {},
|
||||
enableHiddenColumnMemory: true,
|
||||
|
||||
addConnection: (conn) => set((state) => ({ connections: [...state.connections, conn] })),
|
||||
updateConnection: (conn) => set((state) => ({
|
||||
@@ -785,6 +837,44 @@ export const useStore = create<AppState>()(
|
||||
}
|
||||
};
|
||||
}),
|
||||
|
||||
setTableColumnOrder: (connectionId, dbName, tableName, order) => set((state) => {
|
||||
const key = `${connectionId}-${dbName}-${tableName}`;
|
||||
return {
|
||||
tableColumnOrders: {
|
||||
...state.tableColumnOrders,
|
||||
[key]: order
|
||||
}
|
||||
};
|
||||
}),
|
||||
|
||||
clearTableColumnOrder: (connectionId, dbName, tableName) => set((state) => {
|
||||
const key = `${connectionId}-${dbName}-${tableName}`;
|
||||
const newOrders = { ...state.tableColumnOrders };
|
||||
delete newOrders[key];
|
||||
return { tableColumnOrders: newOrders };
|
||||
}),
|
||||
|
||||
setEnableColumnOrderMemory: (enabled) => set({ enableColumnOrderMemory: !!enabled }),
|
||||
|
||||
setTableHiddenColumns: (connectionId, dbName, tableName, hiddenColumns) => set((state) => {
|
||||
const key = `${connectionId}-${dbName}-${tableName}`;
|
||||
return {
|
||||
tableHiddenColumns: {
|
||||
...state.tableHiddenColumns,
|
||||
[key]: hiddenColumns
|
||||
}
|
||||
};
|
||||
}),
|
||||
|
||||
clearTableHiddenColumns: (connectionId, dbName, tableName) => set((state) => {
|
||||
const key = `${connectionId}-${dbName}-${tableName}`;
|
||||
const newHidden = { ...state.tableHiddenColumns };
|
||||
delete newHidden[key];
|
||||
return { tableHiddenColumns: newHidden };
|
||||
}),
|
||||
|
||||
setEnableHiddenColumnMemory: (enabled) => set({ enableHiddenColumnMemory: !!enabled }),
|
||||
}),
|
||||
{
|
||||
name: 'lite-db-storage', // name of the item in the storage (must be unique)
|
||||
@@ -810,6 +900,13 @@ export const useStore = create<AppState>()(
|
||||
nextState.shortcutOptions = sanitizeShortcutOptions(state.shortcutOptions);
|
||||
nextState.tableAccessCount = sanitizeTableAccessCount(state.tableAccessCount);
|
||||
nextState.tableSortPreference = sanitizeTableSortPreference(state.tableSortPreference);
|
||||
// 新增的列排序记忆状态不需要做版本特殊兼容,直接做基本的类型保护
|
||||
const safeOrders = sanitizeTableColumnOrders(state.tableColumnOrders);
|
||||
nextState.tableColumnOrders = safeOrders;
|
||||
nextState.enableColumnOrderMemory = state.enableColumnOrderMemory !== false;
|
||||
const safeHidden = sanitizeTableHiddenColumns(state.tableHiddenColumns);
|
||||
nextState.tableHiddenColumns = safeHidden;
|
||||
nextState.enableHiddenColumnMemory = state.enableHiddenColumnMemory !== false;
|
||||
return nextState as AppState;
|
||||
},
|
||||
merge: (persistedState, currentState) => {
|
||||
@@ -826,11 +923,16 @@ export const useStore = create<AppState>()(
|
||||
fontSize: sanitizeFontSize(state.fontSize),
|
||||
startupFullscreen: sanitizeStartupFullscreen(state.startupFullscreen),
|
||||
globalProxy: sanitizeGlobalProxy(state.globalProxy),
|
||||
tableSortPreference: sanitizeTableSortPreference(state.tableSortPreference),
|
||||
tableColumnOrders: sanitizeTableColumnOrders(state.tableColumnOrders),
|
||||
enableColumnOrderMemory: state.enableColumnOrderMemory !== false,
|
||||
tableHiddenColumns: sanitizeTableHiddenColumns(state.tableHiddenColumns),
|
||||
enableHiddenColumnMemory: state.enableHiddenColumnMemory !== false,
|
||||
|
||||
sqlFormatOptions: sanitizeSqlFormatOptions(state.sqlFormatOptions),
|
||||
queryOptions: sanitizeQueryOptions(state.queryOptions),
|
||||
shortcutOptions: sanitizeShortcutOptions(state.shortcutOptions),
|
||||
tableAccessCount: sanitizeTableAccessCount(state.tableAccessCount),
|
||||
tableSortPreference: sanitizeTableSortPreference(state.tableSortPreference),
|
||||
};
|
||||
},
|
||||
partialize: (state) => ({
|
||||
@@ -847,7 +949,11 @@ export const useStore = create<AppState>()(
|
||||
queryOptions: state.queryOptions,
|
||||
shortcutOptions: state.shortcutOptions,
|
||||
tableAccessCount: state.tableAccessCount,
|
||||
tableSortPreference: state.tableSortPreference
|
||||
tableSortPreference: state.tableSortPreference,
|
||||
tableColumnOrders: state.tableColumnOrders,
|
||||
enableColumnOrderMemory: state.enableColumnOrderMemory,
|
||||
tableHiddenColumns: state.tableHiddenColumns,
|
||||
enableHiddenColumnMemory: state.enableHiddenColumnMemory
|
||||
}), // Don't persist logs
|
||||
}
|
||||
)
|
||||
|
||||
@@ -14,6 +14,13 @@ export interface ProxyConfig {
|
||||
password?: string;
|
||||
}
|
||||
|
||||
export interface HTTPTunnelConfig {
|
||||
host: string;
|
||||
port: number;
|
||||
user?: string;
|
||||
password?: string;
|
||||
}
|
||||
|
||||
export interface ConnectionConfig {
|
||||
type: string;
|
||||
host: string;
|
||||
@@ -30,6 +37,8 @@ export interface ConnectionConfig {
|
||||
ssh?: SSHConfig;
|
||||
useProxy?: boolean;
|
||||
proxy?: ProxyConfig;
|
||||
useHttpTunnel?: boolean;
|
||||
httpTunnel?: HTTPTunnelConfig;
|
||||
driver?: string;
|
||||
dsn?: string;
|
||||
timeout?: number;
|
||||
@@ -119,6 +128,7 @@ export interface TabData {
|
||||
viewName?: string; // View name for view definition tabs
|
||||
routineName?: string; // Routine name for function/procedure definition tabs
|
||||
routineType?: string; // 'FUNCTION' or 'PROCEDURE'
|
||||
savedQueryId?: string; // Saved query identity for quick-save behavior
|
||||
}
|
||||
|
||||
export interface DatabaseNode {
|
||||
|
||||
@@ -10,6 +10,22 @@ const WINDOWS_BLUR_FACTOR = 1.00;
|
||||
|
||||
const clamp = (value: number, min: number, max: number) => Math.min(max, Math.max(min, value));
|
||||
|
||||
export interface AppearanceSettingsLike {
|
||||
enabled?: boolean;
|
||||
opacity?: number;
|
||||
blur?: number;
|
||||
}
|
||||
|
||||
export const resolveAppearanceValues = (appearance: AppearanceSettingsLike | undefined): { opacity: number; blur: number } => {
|
||||
if (!appearance || appearance.enabled !== false) {
|
||||
return {
|
||||
opacity: appearance?.opacity ?? DEFAULT_OPACITY,
|
||||
blur: appearance?.blur ?? 0,
|
||||
};
|
||||
}
|
||||
return { opacity: DEFAULT_OPACITY, blur: 0 };
|
||||
};
|
||||
|
||||
export const isMacLikePlatform = (): boolean => {
|
||||
if (typeof navigator === 'undefined') {
|
||||
return false;
|
||||
|
||||
27
frontend/src/utils/overlayWorkbenchTheme.test.ts
Normal file
27
frontend/src/utils/overlayWorkbenchTheme.test.ts
Normal file
@@ -0,0 +1,27 @@
|
||||
import { buildOverlayWorkbenchTheme } from './overlayWorkbenchTheme';
|
||||
|
||||
const assertEqual = (actual: unknown, expected: unknown, message: string) => {
|
||||
if (actual !== expected) {
|
||||
throw new Error(`${message}\nactual: ${String(actual)}\nexpected: ${String(expected)}`);
|
||||
}
|
||||
};
|
||||
|
||||
const assertMatch = (value: string, pattern: RegExp, message: string) => {
|
||||
if (!pattern.test(value)) {
|
||||
throw new Error(`${message}\nactual: ${value}\npattern: ${String(pattern)}`);
|
||||
}
|
||||
};
|
||||
|
||||
const darkTheme = buildOverlayWorkbenchTheme(true);
|
||||
assertEqual(darkTheme.isDark, true, 'dark 主题标记应为 true');
|
||||
assertMatch(darkTheme.shellBg, /rgba\(15, 15, 17,/, 'dark 弹层背景应保持中性黑');
|
||||
assertMatch(darkTheme.sectionBg, /rgba\(255,?\s*255,?\s*255,?\s*0\.03\)/, 'dark section 背景透明度应匹配');
|
||||
assertEqual(darkTheme.iconColor, '#ffd666', 'dark 图标色应为金色强调');
|
||||
|
||||
const lightTheme = buildOverlayWorkbenchTheme(false);
|
||||
assertEqual(lightTheme.isDark, false, 'light 主题标记应为 false');
|
||||
assertMatch(lightTheme.shellBg, /rgba\(255,255,255,0\.98\)/, 'light 弹层背景透明度应匹配');
|
||||
assertMatch(lightTheme.sectionBg, /rgba\(255,?\s*255,?\s*255,?\s*0\.84\)/, 'light section 背景透明度应匹配');
|
||||
assertEqual(lightTheme.iconColor, '#1677ff', 'light 图标色应为蓝色强调');
|
||||
|
||||
console.log('overlayWorkbenchTheme tests passed');
|
||||
59
frontend/src/utils/overlayWorkbenchTheme.ts
Normal file
59
frontend/src/utils/overlayWorkbenchTheme.ts
Normal file
@@ -0,0 +1,59 @@
|
||||
type OverlayWorkbenchTheme = {
|
||||
isDark: boolean;
|
||||
shellBg: string;
|
||||
shellBorder: string;
|
||||
shellShadow: string;
|
||||
shellBackdropFilter: string;
|
||||
sectionBg: string;
|
||||
sectionBorder: string;
|
||||
mutedText: string;
|
||||
titleText: string;
|
||||
iconBg: string;
|
||||
iconColor: string;
|
||||
hoverBg: string;
|
||||
selectedBg: string;
|
||||
selectedText: string;
|
||||
divider: string;
|
||||
};
|
||||
|
||||
export const buildOverlayWorkbenchTheme = (darkMode: boolean): OverlayWorkbenchTheme => {
|
||||
if (darkMode) {
|
||||
return {
|
||||
isDark: true,
|
||||
shellBg: 'linear-gradient(180deg, rgba(15, 15, 17, 0.96) 0%, rgba(11, 11, 13, 0.98) 100%)',
|
||||
shellBorder: '1px solid rgba(255,255,255,0.08)',
|
||||
shellShadow: '0 24px 56px rgba(0,0,0,0.34)',
|
||||
shellBackdropFilter: 'blur(18px)',
|
||||
sectionBg: 'rgba(255,255,255,0.03)',
|
||||
sectionBorder: '1px solid rgba(255,255,255,0.08)',
|
||||
mutedText: 'rgba(255,255,255,0.5)',
|
||||
titleText: '#f5f7ff',
|
||||
iconBg: 'rgba(255,214,102,0.12)',
|
||||
iconColor: '#ffd666',
|
||||
hoverBg: 'rgba(255,214,102,0.10)',
|
||||
selectedBg: 'rgba(255,214,102,0.14)',
|
||||
selectedText: '#ffd666',
|
||||
divider: 'rgba(255,255,255,0.08)',
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
isDark: false,
|
||||
shellBg: 'linear-gradient(180deg, rgba(255,255,255,0.98) 0%, rgba(246,248,252,0.98) 100%)',
|
||||
shellBorder: '1px solid rgba(16,24,40,0.08)',
|
||||
shellShadow: '0 18px 42px rgba(15,23,42,0.12)',
|
||||
shellBackdropFilter: 'none',
|
||||
sectionBg: 'rgba(255,255,255,0.84)',
|
||||
sectionBorder: '1px solid rgba(16,24,40,0.08)',
|
||||
mutedText: 'rgba(16,24,40,0.55)',
|
||||
titleText: '#162033',
|
||||
iconBg: 'rgba(24,144,255,0.1)',
|
||||
iconColor: '#1677ff',
|
||||
hoverBg: 'rgba(24,144,255,0.08)',
|
||||
selectedBg: 'rgba(24,144,255,0.12)',
|
||||
selectedText: '#1677ff',
|
||||
divider: 'rgba(16,24,40,0.08)',
|
||||
};
|
||||
};
|
||||
|
||||
export type { OverlayWorkbenchTheme };
|
||||
@@ -50,6 +50,11 @@ export const quoteIdentPart = (dbType: string, ident: string) => {
|
||||
return raw;
|
||||
}
|
||||
|
||||
// SQL Server 使用 [bracket] 标识符
|
||||
if (dbTypeLower === 'sqlserver' || dbTypeLower === 'mssql') {
|
||||
return `[${raw.replace(/]/g, ']]')}]`;
|
||||
}
|
||||
|
||||
// 其他数据库默认加双引号
|
||||
return `"${raw.replace(/"/g, '""')}"`;
|
||||
};
|
||||
@@ -134,6 +139,42 @@ export const buildOrderBySQL = (
|
||||
return '';
|
||||
};
|
||||
|
||||
export const buildPaginatedSelectSQL = (
|
||||
dbType: string,
|
||||
baseSql: string,
|
||||
orderBySQL: string,
|
||||
limit: number,
|
||||
offset: number,
|
||||
) => {
|
||||
const normalizedType = String(dbType || '').trim().toLowerCase();
|
||||
const safeLimit = Math.max(0, Math.floor(Number(limit) || 0));
|
||||
const safeOffset = Math.max(0, Math.floor(Number(offset) || 0));
|
||||
const base = String(baseSql || '').trim();
|
||||
const orderBy = String(orderBySQL || '');
|
||||
|
||||
if (!base || safeLimit <= 0) {
|
||||
return `${base}${orderBy}`;
|
||||
}
|
||||
|
||||
switch (normalizedType) {
|
||||
case 'oracle': {
|
||||
const orderedSql = `${base}${orderBy}`;
|
||||
const upperBound = safeOffset + safeLimit;
|
||||
if (safeOffset <= 0) {
|
||||
return `SELECT * FROM (${orderedSql}) WHERE ROWNUM <= ${upperBound}`;
|
||||
}
|
||||
return `SELECT * FROM (SELECT "__gonavi_page__".*, ROWNUM "__gonavi_rn__" FROM (${orderedSql}) "__gonavi_page__" WHERE ROWNUM <= ${upperBound}) WHERE "__gonavi_rn__" > ${safeOffset}`;
|
||||
}
|
||||
case 'sqlserver':
|
||||
case 'mssql': {
|
||||
const effectiveOrderBy = orderBy.trim() ? orderBy : ' ORDER BY (SELECT NULL)';
|
||||
return `${base}${effectiveOrderBy} OFFSET ${safeOffset} ROWS FETCH NEXT ${safeLimit} ROWS ONLY`;
|
||||
}
|
||||
default:
|
||||
return `${base}${orderBy} LIMIT ${safeLimit} OFFSET ${safeOffset}`;
|
||||
}
|
||||
};
|
||||
|
||||
export const parseListValues = (val: string) => {
|
||||
const raw = (val || '').trim();
|
||||
if (!raw) return [];
|
||||
|
||||
2
frontend/vite.config.d.ts
vendored
Normal file
2
frontend/vite.config.d.ts
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
declare const _default: import("vite").UserConfig;
|
||||
export default _default;
|
||||
15
frontend/vite.config.js
Normal file
15
frontend/vite.config.js
Normal file
@@ -0,0 +1,15 @@
|
||||
import { defineConfig } from 'vite';
|
||||
import react from '@vitejs/plugin-react';
|
||||
// https://vitejs.dev/config/
|
||||
export default defineConfig({
|
||||
plugins: [react()],
|
||||
server: {
|
||||
host: '127.0.0.1',
|
||||
port: 5173,
|
||||
strictPort: true,
|
||||
},
|
||||
build: {
|
||||
outDir: 'dist', // Standard Wails output directory
|
||||
emptyOutDir: true,
|
||||
}
|
||||
});
|
||||
@@ -5,6 +5,7 @@ import react from '@vitejs/plugin-react'
|
||||
export default defineConfig({
|
||||
plugins: [react()],
|
||||
server: {
|
||||
host: '127.0.0.1',
|
||||
port: 5173,
|
||||
strictPort: true,
|
||||
},
|
||||
|
||||
10
frontend/wailsjs/go/app/App.d.ts
vendored
10
frontend/wailsjs/go/app/App.d.ts
vendored
@@ -9,6 +9,8 @@ export function ApplyChanges(arg1:connection.ConnectionConfig,arg2:string,arg3:s
|
||||
|
||||
export function CancelQuery(arg1:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function CancelSQLFileExecution(arg1:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function CheckDriverNetworkStatus():Promise<connection.QueryResult>;
|
||||
|
||||
export function CheckForUpdates():Promise<connection.QueryResult>;
|
||||
@@ -41,6 +43,8 @@ export function DBQuery(arg1:connection.ConnectionConfig,arg2:string,arg3:string
|
||||
|
||||
export function DBQueryIsolated(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function DBQueryMulti(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function DBQueryWithCancel(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function DBShowCreateTable(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise<connection.QueryResult>;
|
||||
@@ -63,6 +67,8 @@ export function DropTable(arg1:connection.ConnectionConfig,arg2:string,arg3:stri
|
||||
|
||||
export function DropView(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function ExecuteSQLFile(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function ExportData(arg1:Array<Record<string, any>>,arg2:Array<string>,arg3:string,arg4:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function ExportDatabaseSQL(arg1:connection.ConnectionConfig,arg2:string,arg3:boolean):Promise<connection.QueryResult>;
|
||||
@@ -131,6 +137,8 @@ export function RedisGetServerInfo(arg1:connection.ConnectionConfig):Promise<con
|
||||
|
||||
export function RedisGetValue(arg1:connection.ConnectionConfig,arg2:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisKeyExists(arg1:connection.ConnectionConfig,arg2:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisListPush(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisListSet(arg1:connection.ConnectionConfig,arg2:string,arg3:number,arg4:string):Promise<connection.QueryResult>;
|
||||
@@ -188,3 +196,5 @@ export function SelectSSHKeyFile(arg1:string):Promise<connection.QueryResult>;
|
||||
export function SetWindowTranslucency(arg1:number,arg2:number):Promise<void>;
|
||||
|
||||
export function TestConnection(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;
|
||||
|
||||
export function TruncateTables(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>):Promise<connection.QueryResult>;
|
||||
|
||||
@@ -10,6 +10,10 @@ export function CancelQuery(arg1) {
|
||||
return window['go']['app']['App']['CancelQuery'](arg1);
|
||||
}
|
||||
|
||||
export function CancelSQLFileExecution(arg1) {
|
||||
return window['go']['app']['App']['CancelSQLFileExecution'](arg1);
|
||||
}
|
||||
|
||||
export function CheckDriverNetworkStatus() {
|
||||
return window['go']['app']['App']['CheckDriverNetworkStatus']();
|
||||
}
|
||||
@@ -74,6 +78,10 @@ export function DBQueryIsolated(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['DBQueryIsolated'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function DBQueryMulti(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['DBQueryMulti'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function DBQueryWithCancel(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['DBQueryWithCancel'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
@@ -118,6 +126,10 @@ export function DropView(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['DropView'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function ExecuteSQLFile(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['ExecuteSQLFile'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function ExportData(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['ExportData'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
@@ -254,6 +266,10 @@ export function RedisGetValue(arg1, arg2) {
|
||||
return window['go']['app']['App']['RedisGetValue'](arg1, arg2);
|
||||
}
|
||||
|
||||
export function RedisKeyExists(arg1, arg2) {
|
||||
return window['go']['app']['App']['RedisKeyExists'](arg1, arg2);
|
||||
}
|
||||
|
||||
export function RedisListPush(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['RedisListPush'](arg1, arg2, arg3);
|
||||
}
|
||||
@@ -369,3 +385,7 @@ export function SetWindowTranslucency(arg1, arg2) {
|
||||
export function TestConnection(arg1) {
|
||||
return window['go']['app']['App']['TestConnection'](arg1);
|
||||
}
|
||||
|
||||
export function TruncateTables(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['TruncateTables'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
@@ -48,6 +48,24 @@ export namespace connection {
|
||||
return a;
|
||||
}
|
||||
}
|
||||
export class HTTPTunnelConfig {
|
||||
host: string;
|
||||
port: number;
|
||||
user?: string;
|
||||
password?: string;
|
||||
|
||||
static createFrom(source: any = {}) {
|
||||
return new HTTPTunnelConfig(source);
|
||||
}
|
||||
|
||||
constructor(source: any = {}) {
|
||||
if ('string' === typeof source) source = JSON.parse(source);
|
||||
this.host = source["host"];
|
||||
this.port = source["port"];
|
||||
this.user = source["user"];
|
||||
this.password = source["password"];
|
||||
}
|
||||
}
|
||||
export class ProxyConfig {
|
||||
type: string;
|
||||
host: string;
|
||||
@@ -104,6 +122,8 @@ export namespace connection {
|
||||
ssh: SSHConfig;
|
||||
useProxy?: boolean;
|
||||
proxy?: ProxyConfig;
|
||||
useHttpTunnel?: boolean;
|
||||
httpTunnel?: HTTPTunnelConfig;
|
||||
driver?: string;
|
||||
dsn?: string;
|
||||
timeout?: number;
|
||||
@@ -142,6 +162,8 @@ export namespace connection {
|
||||
this.ssh = this.convertValues(source["ssh"], SSHConfig);
|
||||
this.useProxy = source["useProxy"];
|
||||
this.proxy = this.convertValues(source["proxy"], ProxyConfig);
|
||||
this.useHttpTunnel = source["useHttpTunnel"];
|
||||
this.httpTunnel = this.convertValues(source["httpTunnel"], HTTPTunnelConfig);
|
||||
this.driver = source["driver"];
|
||||
this.dsn = source["dsn"];
|
||||
this.timeout = source["timeout"];
|
||||
@@ -179,6 +201,7 @@ export namespace connection {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
export class QueryResult {
|
||||
success: boolean;
|
||||
message: string;
|
||||
@@ -254,6 +277,9 @@ export namespace sync {
|
||||
mode: string;
|
||||
jobId?: string;
|
||||
autoAddColumns?: boolean;
|
||||
targetTableStrategy?: string;
|
||||
createIndexes?: boolean;
|
||||
mongoCollectionName?: string;
|
||||
tableOptions?: Record<string, TableOptions>;
|
||||
|
||||
static createFrom(source: any = {}) {
|
||||
@@ -269,6 +295,9 @@ export namespace sync {
|
||||
this.mode = source["mode"];
|
||||
this.jobId = source["jobId"];
|
||||
this.autoAddColumns = source["autoAddColumns"];
|
||||
this.targetTableStrategy = source["targetTableStrategy"];
|
||||
this.createIndexes = source["createIndexes"];
|
||||
this.mongoCollectionName = source["mongoCollectionName"];
|
||||
this.tableOptions = this.convertValues(source["tableOptions"], TableOptions, true);
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -96,6 +98,9 @@ func normalizeCacheKeyConfig(config connection.ConnectionConfig) connection.Conn
|
||||
if !normalized.UseProxy {
|
||||
normalized.Proxy = connection.ProxyConfig{}
|
||||
}
|
||||
if !normalized.UseHTTPTunnel {
|
||||
normalized.HTTPTunnel = connection.HTTPTunnelConfig{}
|
||||
}
|
||||
|
||||
if isFileDatabaseType(normalized.Type) {
|
||||
dsn := strings.TrimSpace(normalized.Host)
|
||||
@@ -124,6 +129,8 @@ func normalizeCacheKeyConfig(config connection.ConnectionConfig) connection.Conn
|
||||
normalized.MongoAuthMechanism = ""
|
||||
normalized.MongoReplicaUser = ""
|
||||
normalized.MongoReplicaPassword = ""
|
||||
normalized.UseHTTPTunnel = false
|
||||
normalized.HTTPTunnel = connection.HTTPTunnelConfig{}
|
||||
}
|
||||
|
||||
return normalized
|
||||
@@ -213,6 +220,7 @@ func wrapConnectError(config connection.ConnectionConfig, err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
err = sanitizeMongoConnectErrorLabel(config, err)
|
||||
|
||||
var netErr net.Error
|
||||
if errors.Is(err, context.DeadlineExceeded) || (errors.As(err, &netErr) && netErr.Timeout()) {
|
||||
@@ -226,6 +234,73 @@ func wrapConnectError(config connection.ConnectionConfig, err error) error {
|
||||
return withLogHint{err: err, logPath: logger.Path()}
|
||||
}
|
||||
|
||||
type errorMessageOverride struct {
|
||||
message string
|
||||
cause error
|
||||
}
|
||||
|
||||
func (e errorMessageOverride) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
func (e errorMessageOverride) Unwrap() error {
|
||||
return e.cause
|
||||
}
|
||||
|
||||
func sanitizeMongoConnectErrorLabel(config connection.ConnectionConfig, err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if strings.ToLower(strings.TrimSpace(config.Type)) != "mongodb" {
|
||||
return err
|
||||
}
|
||||
if mongoConnectUsesTLS(config) {
|
||||
return err
|
||||
}
|
||||
original := err.Error()
|
||||
rewritten := strings.ReplaceAll(original, "SSL 主库凭据", "主库凭据")
|
||||
rewritten = strings.ReplaceAll(rewritten, "SSL 从库凭据", "从库凭据")
|
||||
if rewritten == original {
|
||||
return err
|
||||
}
|
||||
return errorMessageOverride{
|
||||
message: rewritten,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func mongoConnectUsesTLS(config connection.ConnectionConfig) bool {
|
||||
if config.UseSSL {
|
||||
return true
|
||||
}
|
||||
uriText := strings.TrimSpace(config.URI)
|
||||
if uriText == "" {
|
||||
return false
|
||||
}
|
||||
parsed, err := url.Parse(uriText)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, key := range []string{"tls", "ssl"} {
|
||||
if enabled, known := parseMongoBool(parsed.Query().Get(key)); known {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(parsed.Scheme), "mongodb+srv")
|
||||
}
|
||||
|
||||
func parseMongoBool(raw string) (enabled bool, known bool) {
|
||||
value := strings.ToLower(strings.TrimSpace(raw))
|
||||
switch value {
|
||||
case "1", "true", "t", "yes", "y", "on", "required":
|
||||
return true, true
|
||||
case "0", "false", "f", "no", "n", "off", "disable", "disabled":
|
||||
return false, true
|
||||
default:
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
|
||||
type withLogHint struct {
|
||||
err error
|
||||
logPath string
|
||||
@@ -233,10 +308,15 @@ type withLogHint struct {
|
||||
|
||||
func (e withLogHint) Error() string {
|
||||
message := normalizeErrorMessage(e.err)
|
||||
if strings.TrimSpace(e.logPath) == "" {
|
||||
path := strings.TrimSpace(e.logPath)
|
||||
if path == "" {
|
||||
return message
|
||||
}
|
||||
return fmt.Sprintf("%s(详细日志:%s)", message, e.logPath)
|
||||
info, statErr := os.Stat(path)
|
||||
if statErr != nil || info.IsDir() || info.Size() <= 0 {
|
||||
return message
|
||||
}
|
||||
return fmt.Sprintf("%s(详细日志:%s)", message, path)
|
||||
}
|
||||
|
||||
func (e withLogHint) Unwrap() error {
|
||||
@@ -303,6 +383,12 @@ func formatConnSummary(config connection.ConnectionConfig) string {
|
||||
b.WriteString(" 代理认证=已配置")
|
||||
}
|
||||
}
|
||||
if config.UseHTTPTunnel {
|
||||
b.WriteString(fmt.Sprintf(" HTTP隧道=%s:%d", strings.TrimSpace(config.HTTPTunnel.Host), config.HTTPTunnel.Port))
|
||||
if strings.TrimSpace(config.HTTPTunnel.User) != "" {
|
||||
b.WriteString(" HTTP隧道认证=已配置")
|
||||
}
|
||||
}
|
||||
|
||||
if config.Type == "custom" {
|
||||
driver := strings.TrimSpace(config.Driver)
|
||||
|
||||
84
internal/app/app_connect_error_test.go
Normal file
84
internal/app/app_connect_error_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
func TestWrapConnectError_MongoNoSSL_RemovesMisleadingSSLLabel(t *testing.T) {
|
||||
config := connection.ConnectionConfig{
|
||||
Type: "mongodb",
|
||||
UseSSL: false,
|
||||
}
|
||||
sourceErr := errors.New("MongoDB 连接失败:SSL 主库凭据验证失败: mock error")
|
||||
|
||||
wrapped := wrapConnectError(config, sourceErr)
|
||||
text := wrapped.Error()
|
||||
if strings.Contains(text, "SSL 主库凭据") {
|
||||
t.Fatalf("expected ssl label to be removed when TLS disabled, got: %s", text)
|
||||
}
|
||||
if !strings.Contains(text, "主库凭据验证失败") {
|
||||
t.Fatalf("expected auth label to remain, got: %s", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapConnectError_MongoURIForcesTLS_KeepsSSLLabel(t *testing.T) {
|
||||
config := connection.ConnectionConfig{
|
||||
Type: "mongodb",
|
||||
UseSSL: false,
|
||||
URI: "mongodb://user:pass@127.0.0.1:27017/admin?tls=true",
|
||||
}
|
||||
sourceErr := errors.New("MongoDB 连接失败:SSL 主库凭据验证失败: mock error")
|
||||
|
||||
wrapped := wrapConnectError(config, sourceErr)
|
||||
text := wrapped.Error()
|
||||
if !strings.Contains(text, "SSL 主库凭据") {
|
||||
t.Fatalf("expected ssl label to remain when URI enables TLS, got: %s", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapConnectError_MongoSRVDefaultTLS_KeepsSSLLabel(t *testing.T) {
|
||||
config := connection.ConnectionConfig{
|
||||
Type: "mongodb",
|
||||
UseSSL: false,
|
||||
URI: "mongodb+srv://user:pass@cluster0.example.com/admin",
|
||||
}
|
||||
sourceErr := errors.New("MongoDB 连接失败:SSL 主库凭据验证失败: mock error")
|
||||
|
||||
wrapped := wrapConnectError(config, sourceErr)
|
||||
text := wrapped.Error()
|
||||
if !strings.Contains(text, "SSL 主库凭据") {
|
||||
t.Fatalf("expected ssl label to remain for mongodb+srv default TLS, got: %s", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithLogHintError_OmitEmptyLogPath(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
logPath := filepath.Join(dir, "gonavi.log")
|
||||
if err := os.WriteFile(logPath, nil, 0o644); err != nil {
|
||||
t.Fatalf("write empty log failed: %v", err)
|
||||
}
|
||||
err := withLogHint{err: errors.New("连接失败"), logPath: logPath}
|
||||
text := err.Error()
|
||||
if strings.Contains(text, "详细日志:") {
|
||||
t.Fatalf("expected no log hint for empty file, got: %s", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithLogHintError_IncludeNonEmptyLogPath(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
logPath := filepath.Join(dir, "gonavi.log")
|
||||
if err := os.WriteFile(logPath, []byte("log entry\n"), 0o644); err != nil {
|
||||
t.Fatalf("write log failed: %v", err)
|
||||
}
|
||||
err := withLogHint{err: errors.New("连接失败"), logPath: logPath}
|
||||
text := err.Error()
|
||||
if !strings.Contains(text, "详细日志:"+logPath) {
|
||||
t.Fatalf("expected log hint with path, got: %s", text)
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
@@ -20,6 +21,11 @@ func normalizeRunConfig(config connection.ConnectionConfig, dbName string) conne
|
||||
case "dameng":
|
||||
// 达梦使用 schema 参数,沿用现有行为:dbName 表示 schema。
|
||||
runConfig.Database = name
|
||||
case "redis":
|
||||
runConfig.Database = name
|
||||
if idx, err := strconv.Atoi(name); err == nil && idx >= 0 && idx <= 15 {
|
||||
runConfig.RedisDB = idx
|
||||
}
|
||||
default:
|
||||
// oracle: dbName 表示 schema/owner,不能覆盖 config.Database(服务名)
|
||||
// sqlite: 无需设置 Database
|
||||
|
||||
@@ -12,8 +12,35 @@ import (
|
||||
|
||||
func resolveDialConfigWithProxy(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) {
|
||||
config := raw
|
||||
if config.UseHTTPTunnel {
|
||||
if config.UseProxy {
|
||||
return connection.ConnectionConfig{}, fmt.Errorf("HTTP 隧道与普通代理不能同时启用")
|
||||
}
|
||||
tunnelHost := strings.TrimSpace(config.HTTPTunnel.Host)
|
||||
if tunnelHost == "" {
|
||||
return connection.ConnectionConfig{}, fmt.Errorf("HTTP 隧道主机不能为空")
|
||||
}
|
||||
tunnelPort := config.HTTPTunnel.Port
|
||||
if tunnelPort <= 0 {
|
||||
tunnelPort = 8080
|
||||
}
|
||||
if tunnelPort > 65535 {
|
||||
return connection.ConnectionConfig{}, fmt.Errorf("HTTP 隧道端口无效:%d", config.HTTPTunnel.Port)
|
||||
}
|
||||
|
||||
config.UseProxy = true
|
||||
config.Proxy = connection.ProxyConfig{
|
||||
Type: "http",
|
||||
Host: tunnelHost,
|
||||
Port: tunnelPort,
|
||||
User: strings.TrimSpace(config.HTTPTunnel.User),
|
||||
Password: config.HTTPTunnel.Password,
|
||||
}
|
||||
}
|
||||
if !config.UseProxy {
|
||||
config.Proxy = connection.ProxyConfig{}
|
||||
config.UseHTTPTunnel = false
|
||||
config.HTTPTunnel = connection.HTTPTunnelConfig{}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@@ -22,6 +49,8 @@ func resolveDialConfigWithProxy(raw connection.ConnectionConfig) (connection.Con
|
||||
return connection.ConnectionConfig{}, err
|
||||
}
|
||||
config.Proxy = normalizedProxy
|
||||
config.UseHTTPTunnel = false
|
||||
config.HTTPTunnel = connection.HTTPTunnelConfig{}
|
||||
|
||||
if config.UseSSH {
|
||||
sshPort := config.SSH.Port
|
||||
@@ -44,8 +73,8 @@ func resolveDialConfigWithProxy(raw connection.ConnectionConfig) (connection.Con
|
||||
// 文件型/自定义 DSN 类型不走标准 host:port,不在此层改写。
|
||||
return config, nil
|
||||
}
|
||||
if normalizedType == "mongodb" && config.MongoSRV {
|
||||
// Mongo SRV 由驱动侧 Dialer 处理代理,避免破坏 DNS SRV 拓扑发现。
|
||||
if normalizedType == "mongodb" {
|
||||
// MongoDB 统一由驱动侧 Dialer 处理代理,保留原始目标地址,避免将连接目标改写为本地转发地址。
|
||||
return config, nil
|
||||
}
|
||||
|
||||
|
||||
64
internal/app/db_proxy_test.go
Normal file
64
internal/app/db_proxy_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
func TestResolveDialConfigWithProxy_MongoKeepsTargetAddress(t *testing.T) {
|
||||
hosts := []string{"10.20.30.40:27017", "10.20.30.41:27017"}
|
||||
raw := connection.ConnectionConfig{
|
||||
Type: "mongodb",
|
||||
Host: "10.20.30.40",
|
||||
Port: 27017,
|
||||
UseProxy: true,
|
||||
Proxy: connection.ProxyConfig{
|
||||
Type: "socks5",
|
||||
Host: "127.0.0.1",
|
||||
Port: 1080,
|
||||
},
|
||||
Hosts: hosts,
|
||||
}
|
||||
|
||||
got, err := resolveDialConfigWithProxy(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveDialConfigWithProxy returned error: %v", err)
|
||||
}
|
||||
if got.Host != raw.Host || got.Port != raw.Port {
|
||||
t.Fatalf("mongo target address should be kept, got=%s:%d want=%s:%d", got.Host, got.Port, raw.Host, raw.Port)
|
||||
}
|
||||
if !got.UseProxy {
|
||||
t.Fatalf("mongo should keep UseProxy=true for driver-level dialer")
|
||||
}
|
||||
if !reflect.DeepEqual(got.Hosts, hosts) {
|
||||
t.Fatalf("mongo hosts should be kept, got=%v want=%v", got.Hosts, hosts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveDialConfigWithProxy_MongoSRVKeepsTargetAddress(t *testing.T) {
|
||||
raw := connection.ConnectionConfig{
|
||||
Type: "mongodb",
|
||||
Host: "cluster0.example.com",
|
||||
Port: 27017,
|
||||
MongoSRV: true,
|
||||
UseProxy: true,
|
||||
Proxy: connection.ProxyConfig{
|
||||
Type: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 7890,
|
||||
},
|
||||
}
|
||||
|
||||
got, err := resolveDialConfigWithProxy(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveDialConfigWithProxy returned error: %v", err)
|
||||
}
|
||||
if got.Host != raw.Host || got.Port != raw.Port {
|
||||
t.Fatalf("mongo SRV target address should be kept, got=%s:%d want=%s:%d", got.Host, got.Port, raw.Host, raw.Port)
|
||||
}
|
||||
if !got.UseProxy {
|
||||
t.Fatalf("mongo SRV should keep UseProxy=true for driver-level dialer")
|
||||
}
|
||||
}
|
||||
@@ -72,25 +72,30 @@ func setGlobalProxyConfig(enabled bool, proxyConfig connection.ProxyConfig) (glo
|
||||
}
|
||||
|
||||
func (a *App) ConfigureGlobalProxy(enabled bool, proxyConfig connection.ProxyConfig) connection.QueryResult {
|
||||
before := currentGlobalProxyConfig()
|
||||
snapshot, err := setGlobalProxyConfig(enabled, proxyConfig)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
if snapshot.Enabled {
|
||||
authState := ""
|
||||
if strings.TrimSpace(snapshot.Proxy.User) != "" {
|
||||
authState = "(认证:已配置)"
|
||||
// 前端可能在同一配置下重复触发同步(例如严格模式或状态回放),
|
||||
// 这里做幂等日志,避免重复刷屏。
|
||||
if !globalProxySnapshotEqual(before, snapshot) {
|
||||
if snapshot.Enabled {
|
||||
authState := ""
|
||||
if strings.TrimSpace(snapshot.Proxy.User) != "" {
|
||||
authState = "(认证:已配置)"
|
||||
}
|
||||
logger.Infof(
|
||||
"全局代理已启用:%s://%s:%d%s",
|
||||
strings.ToLower(strings.TrimSpace(snapshot.Proxy.Type)),
|
||||
strings.TrimSpace(snapshot.Proxy.Host),
|
||||
snapshot.Proxy.Port,
|
||||
authState,
|
||||
)
|
||||
} else {
|
||||
logger.Infof("全局代理已关闭")
|
||||
}
|
||||
logger.Infof(
|
||||
"全局代理已启用:%s://%s:%d%s",
|
||||
strings.ToLower(strings.TrimSpace(snapshot.Proxy.Type)),
|
||||
strings.TrimSpace(snapshot.Proxy.Host),
|
||||
snapshot.Proxy.Port,
|
||||
authState,
|
||||
)
|
||||
} else {
|
||||
logger.Infof("全局代理已关闭")
|
||||
}
|
||||
|
||||
return connection.QueryResult{
|
||||
@@ -100,6 +105,24 @@ func (a *App) ConfigureGlobalProxy(enabled bool, proxyConfig connection.ProxyCon
|
||||
}
|
||||
}
|
||||
|
||||
func globalProxySnapshotEqual(a, b globalProxySnapshot) bool {
|
||||
if a.Enabled != b.Enabled {
|
||||
return false
|
||||
}
|
||||
if !a.Enabled {
|
||||
return true
|
||||
}
|
||||
return proxyConfigEqual(a.Proxy, b.Proxy)
|
||||
}
|
||||
|
||||
func proxyConfigEqual(a, b connection.ProxyConfig) bool {
|
||||
return strings.EqualFold(strings.TrimSpace(a.Type), strings.TrimSpace(b.Type)) &&
|
||||
strings.TrimSpace(a.Host) == strings.TrimSpace(b.Host) &&
|
||||
a.Port == b.Port &&
|
||||
strings.TrimSpace(a.User) == strings.TrimSpace(b.User) &&
|
||||
a.Password == b.Password
|
||||
}
|
||||
|
||||
func (a *App) GetGlobalProxyConfig() connection.QueryResult {
|
||||
return connection.QueryResult{
|
||||
Success: true,
|
||||
@@ -110,7 +133,7 @@ func (a *App) GetGlobalProxyConfig() connection.QueryResult {
|
||||
|
||||
func applyGlobalProxyToConnection(config connection.ConnectionConfig) connection.ConnectionConfig {
|
||||
effective := config
|
||||
if effective.UseProxy {
|
||||
if effective.UseProxy || effective.UseHTTPTunnel {
|
||||
return effective
|
||||
}
|
||||
if isFileDatabaseType(effective.Type) {
|
||||
|
||||
@@ -3,6 +3,7 @@ package app
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -12,6 +13,16 @@ import (
|
||||
"GoNavi-Wails/internal/utils"
|
||||
)
|
||||
|
||||
const testConnectionTimeoutUpperBoundSeconds = 12
|
||||
|
||||
func normalizeTestConnectionConfig(config connection.ConnectionConfig) connection.ConnectionConfig {
|
||||
normalized := config
|
||||
if normalized.Timeout <= 0 || normalized.Timeout > testConnectionTimeoutUpperBoundSeconds {
|
||||
normalized.Timeout = testConnectionTimeoutUpperBoundSeconds
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
// Generic DB Methods
|
||||
|
||||
func (a *App) DBConnect(config connection.ConnectionConfig) connection.QueryResult {
|
||||
@@ -27,13 +38,16 @@ func (a *App) DBConnect(config connection.ConnectionConfig) connection.QueryResu
|
||||
}
|
||||
|
||||
func (a *App) TestConnection(config connection.ConnectionConfig) connection.QueryResult {
|
||||
_, err := a.getDatabaseForcePing(config)
|
||||
testConfig := normalizeTestConnectionConfig(config)
|
||||
started := time.Now()
|
||||
logger.Infof("TestConnection 开始:%s", formatConnSummary(testConfig))
|
||||
_, err := a.getDatabaseForcePing(testConfig)
|
||||
if err != nil {
|
||||
logger.Error(err, "TestConnection 连接测试失败:%s", formatConnSummary(config))
|
||||
logger.Error(err, "TestConnection 连接测试失败:耗时=%s %s", time.Since(started).Round(time.Millisecond), formatConnSummary(testConfig))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
logger.Infof("TestConnection 连接测试成功:%s", formatConnSummary(config))
|
||||
logger.Infof("TestConnection 连接测试成功:耗时=%s %s", time.Since(started).Round(time.Millisecond), formatConnSummary(testConfig))
|
||||
return connection.QueryResult{Success: true, Message: "连接成功"}
|
||||
}
|
||||
|
||||
@@ -102,7 +116,7 @@ func (a *App) CreateDatabase(config connection.ConnectionConfig, dbName string)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "Database created successfully"}
|
||||
return connection.QueryResult{Success: true, Message: "数据库创建成功"}
|
||||
}
|
||||
|
||||
func resolveDDLDBType(config connection.ConnectionConfig) string {
|
||||
@@ -416,12 +430,7 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
|
||||
a.queryMu.Unlock()
|
||||
}()
|
||||
|
||||
lowerQuery := strings.TrimSpace(strings.ToLower(query))
|
||||
isReadQuery := strings.HasPrefix(lowerQuery, "select") || strings.HasPrefix(lowerQuery, "show") || strings.HasPrefix(lowerQuery, "describe") || strings.HasPrefix(lowerQuery, "explain")
|
||||
// MongoDB JSON 命令中的 find/count/aggregate 也属于读查询
|
||||
if !isReadQuery && strings.ToLower(strings.TrimSpace(runConfig.Type)) == "mongodb" && strings.HasPrefix(strings.TrimSpace(query), "{") {
|
||||
isReadQuery = true
|
||||
}
|
||||
isReadQuery := isReadOnlySQLQuery(runConfig.Type, query)
|
||||
|
||||
runReadQuery := func(inst db.Database) ([]map[string]interface{}, []string, error) {
|
||||
if q, ok := inst.(interface {
|
||||
@@ -478,6 +487,151 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
|
||||
}
|
||||
}
|
||||
|
||||
// DBQueryMulti 执行可能包含多条 SQL 语句的查询,返回多个结果集。
|
||||
// 如果底层驱动支持 MultiResultQuerier,一次性执行所有语句;
|
||||
// 否则按分号拆分后逐条执行,模拟多结果集。
|
||||
func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, query string, queryID string) connection.QueryResult {
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
|
||||
if queryID == "" {
|
||||
queryID = generateQueryID()
|
||||
}
|
||||
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQueryMulti 获取连接失败:%s", formatConnSummary(runConfig))
|
||||
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
|
||||
}
|
||||
|
||||
query = sanitizeSQLForPgLike(runConfig.Type, query)
|
||||
timeoutSeconds := runConfig.Timeout
|
||||
if timeoutSeconds <= 0 {
|
||||
timeoutSeconds = 30
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(time.Duration(timeoutSeconds) * time.Second)
|
||||
defer cancel()
|
||||
|
||||
a.queryMu.Lock()
|
||||
a.runningQueries[queryID] = queryContext{
|
||||
cancel: cancel,
|
||||
started: time.Now(),
|
||||
}
|
||||
a.queryMu.Unlock()
|
||||
defer func() {
|
||||
a.queryMu.Lock()
|
||||
delete(a.runningQueries, queryID)
|
||||
a.queryMu.Unlock()
|
||||
}()
|
||||
|
||||
// 尝试使用驱动原生多结果集支持
|
||||
runMultiQuery := func(inst db.Database) ([]connection.ResultSetData, error) {
|
||||
if q, ok := inst.(db.MultiResultQuerierContext); ok {
|
||||
return q.QueryMultiContext(ctx, query)
|
||||
}
|
||||
if q, ok := inst.(db.MultiResultQuerier); ok {
|
||||
return q.QueryMulti(query)
|
||||
}
|
||||
return nil, nil // 返回 nil 表示不支持
|
||||
}
|
||||
|
||||
results, err := runMultiQuery(dbInst)
|
||||
if err != nil && shouldRefreshCachedConnection(err) {
|
||||
if a.invalidateCachedDatabase(runConfig, err) {
|
||||
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
|
||||
if retryErr != nil {
|
||||
logger.Error(retryErr, "DBQueryMulti 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: retryErr.Error(), QueryID: queryID}
|
||||
}
|
||||
results, err = runMultiQuery(retryInst)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQueryMulti 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
|
||||
}
|
||||
|
||||
// 驱动支持多结果集,直接返回
|
||||
if results != nil {
|
||||
return connection.QueryResult{Success: true, Data: results, QueryID: queryID}
|
||||
}
|
||||
|
||||
// 驱动不支持多结果集,回退到逐条执行
|
||||
statements := splitSQLStatements(query)
|
||||
if len(statements) == 0 {
|
||||
return connection.QueryResult{
|
||||
Success: true,
|
||||
Data: []connection.ResultSetData{},
|
||||
QueryID: queryID,
|
||||
}
|
||||
}
|
||||
|
||||
var resultSets []connection.ResultSetData
|
||||
for idx, stmt := range statements {
|
||||
stmt = strings.TrimSpace(stmt)
|
||||
if stmt == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if isReadOnlySQLQuery(runConfig.Type, stmt) {
|
||||
var data []map[string]interface{}
|
||||
var columns []string
|
||||
if q, ok := dbInst.(interface {
|
||||
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
|
||||
}); ok {
|
||||
data, columns, err = q.QueryContext(ctx, stmt)
|
||||
} else {
|
||||
data, columns, err = dbInst.Query(stmt)
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQueryMulti 逐条查询失败(第 %d/%d 条):%s SQL片段=%q", idx+1, len(statements), formatConnSummary(runConfig), sqlSnippet(stmt))
|
||||
errMsg := fmt.Sprintf("第 %d 条语句执行失败: %v", idx+1, err)
|
||||
if len(resultSets) > 0 {
|
||||
errMsg += fmt.Sprintf("(前 %d 条已执行成功)", len(resultSets))
|
||||
}
|
||||
return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID}
|
||||
}
|
||||
if data == nil {
|
||||
data = make([]map[string]interface{}, 0)
|
||||
}
|
||||
if columns == nil {
|
||||
columns = []string{}
|
||||
}
|
||||
resultSets = append(resultSets, connection.ResultSetData{Rows: data, Columns: columns})
|
||||
} else {
|
||||
var affected int64
|
||||
if e, ok := dbInst.(interface {
|
||||
ExecContext(context.Context, string) (int64, error)
|
||||
}); ok {
|
||||
affected, err = e.ExecContext(ctx, stmt)
|
||||
} else {
|
||||
affected, err = dbInst.Exec(stmt)
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQueryMulti 逐条执行失败(第 %d/%d 条):%s SQL片段=%q", idx+1, len(statements), formatConnSummary(runConfig), sqlSnippet(stmt))
|
||||
errMsg := fmt.Sprintf("第 %d 条语句执行失败: %v", idx+1, err)
|
||||
if len(resultSets) > 0 {
|
||||
errMsg += fmt.Sprintf("(前 %d 条已执行成功)", len(resultSets))
|
||||
}
|
||||
return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID}
|
||||
}
|
||||
resultSets = append(resultSets, connection.ResultSetData{
|
||||
Rows: []map[string]interface{}{{"affectedRows": affected}},
|
||||
Columns: []string{"affectedRows"},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if resultSets == nil {
|
||||
resultSets = []connection.ResultSetData{}
|
||||
}
|
||||
// 回退到逐条执行且有多条语句时,附加提示信息
|
||||
var fallbackMsg string
|
||||
if len(statements) > 1 {
|
||||
fallbackMsg = fmt.Sprintf("当前数据源(%s)不支持原生多语句执行,已自动拆分为 %d 条语句逐条执行。", runConfig.Type, len(statements))
|
||||
}
|
||||
return connection.QueryResult{Success: true, Data: resultSets, QueryID: queryID, Message: fallbackMsg}
|
||||
}
|
||||
|
||||
func (a *App) DBQueryIsolated(config connection.ConnectionConfig, dbName string, query string) connection.QueryResult {
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
|
||||
@@ -500,11 +654,7 @@ func (a *App) DBQueryIsolated(config connection.ConnectionConfig, dbName string,
|
||||
ctx, cancel := utils.ContextWithTimeout(time.Duration(timeoutSeconds) * time.Second)
|
||||
defer cancel()
|
||||
|
||||
lowerQuery := strings.TrimSpace(strings.ToLower(query))
|
||||
isReadQuery := strings.HasPrefix(lowerQuery, "select") || strings.HasPrefix(lowerQuery, "show") || strings.HasPrefix(lowerQuery, "describe") || strings.HasPrefix(lowerQuery, "explain")
|
||||
if !isReadQuery && strings.ToLower(strings.TrimSpace(runConfig.Type)) == "mongodb" && strings.HasPrefix(strings.TrimSpace(query), "{") {
|
||||
isReadQuery = true
|
||||
}
|
||||
isReadQuery := isReadOnlySQLQuery(runConfig.Type, query)
|
||||
|
||||
if isReadQuery {
|
||||
var data []map[string]interface{}
|
||||
@@ -547,8 +697,33 @@ func sqlSnippet(query string) string {
|
||||
return q[:max] + "..."
|
||||
}
|
||||
|
||||
func ensureNonNilSlice[T any](items []T) []T {
|
||||
if items == nil {
|
||||
return make([]T, 0)
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func (a *App) DBGetDatabases(config connection.ConnectionConfig) connection.QueryResult {
|
||||
runConfig := normalizeRunConfig(config, "")
|
||||
if strings.EqualFold(strings.TrimSpace(runConfig.Type), "redis") {
|
||||
runConfig.Type = "redis"
|
||||
client, err := a.getRedisClient(runConfig)
|
||||
if err != nil {
|
||||
logger.Error(err, "DBGetDatabases 获取 Redis 连接失败:%s", formatConnSummary(runConfig))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
dbs, err := client.GetDatabases()
|
||||
if err != nil {
|
||||
logger.Error(err, "DBGetDatabases 获取 Redis 库列表失败:%s", formatConnSummary(runConfig))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
resData := make([]map[string]string, 0, len(dbs))
|
||||
for _, item := range dbs {
|
||||
resData = append(resData, map[string]string{"Database": strconv.Itoa(item.Index)})
|
||||
}
|
||||
return connection.QueryResult{Success: true, Data: resData}
|
||||
}
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
logger.Error(err, "DBGetDatabases 获取连接失败:%s", formatConnSummary(runConfig))
|
||||
@@ -571,7 +746,7 @@ func (a *App) DBGetDatabases(config connection.ConnectionConfig) connection.Quer
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
var resData []map[string]string
|
||||
resData := make([]map[string]string, 0, len(dbs))
|
||||
for _, name := range dbs {
|
||||
resData = append(resData, map[string]string{"Database": name})
|
||||
}
|
||||
@@ -581,6 +756,48 @@ func (a *App) DBGetDatabases(config connection.ConnectionConfig) connection.Quer
|
||||
|
||||
func (a *App) DBGetTables(config connection.ConnectionConfig, dbName string) connection.QueryResult {
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
if strings.EqualFold(strings.TrimSpace(runConfig.Type), "redis") {
|
||||
runConfig.Type = "redis"
|
||||
client, err := a.getRedisClient(runConfig)
|
||||
if err != nil {
|
||||
logger.Error(err, "DBGetTables 获取 Redis 连接失败:%s", formatConnSummary(runConfig))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
cursor := uint64(0)
|
||||
tables := make([]string, 0, 128)
|
||||
seen := make(map[string]struct{}, 128)
|
||||
for {
|
||||
result, err := client.ScanKeys("*", cursor, 1000)
|
||||
if err != nil {
|
||||
logger.Error(err, "DBGetTables 扫描 Redis Key 失败:%s", formatConnSummary(runConfig))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
for _, item := range result.Keys {
|
||||
key := strings.TrimSpace(item.Key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
tables = append(tables, key)
|
||||
}
|
||||
if strings.TrimSpace(result.Cursor) == "" || strings.TrimSpace(result.Cursor) == "0" {
|
||||
break
|
||||
}
|
||||
next, err := strconv.ParseUint(strings.TrimSpace(result.Cursor), 10, 64)
|
||||
if err != nil || next == cursor {
|
||||
break
|
||||
}
|
||||
cursor = next
|
||||
}
|
||||
resData := make([]map[string]string, 0, len(tables))
|
||||
for _, name := range tables {
|
||||
resData = append(resData, map[string]string{"Table": name})
|
||||
}
|
||||
return connection.QueryResult{Success: true, Data: resData}
|
||||
}
|
||||
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
@@ -604,7 +821,7 @@ func (a *App) DBGetTables(config connection.ConnectionConfig, dbName string) con
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
var resData []map[string]string
|
||||
resData := make([]map[string]string, 0, len(tables))
|
||||
for _, name := range tables {
|
||||
resData = append(resData, map[string]string{"Table": name})
|
||||
}
|
||||
@@ -786,7 +1003,7 @@ func (a *App) DBGetColumns(config connection.ConnectionConfig, dbName string, ta
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Data: columns}
|
||||
return connection.QueryResult{Success: true, Data: ensureNonNilSlice(columns)}
|
||||
}
|
||||
|
||||
func (a *App) DBGetIndexes(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult {
|
||||
@@ -803,7 +1020,7 @@ func (a *App) DBGetIndexes(config connection.ConnectionConfig, dbName string, ta
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Data: indexes}
|
||||
return connection.QueryResult{Success: true, Data: ensureNonNilSlice(indexes)}
|
||||
}
|
||||
|
||||
func (a *App) DBGetForeignKeys(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult {
|
||||
@@ -820,7 +1037,7 @@ func (a *App) DBGetForeignKeys(config connection.ConnectionConfig, dbName string
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Data: fks}
|
||||
return connection.QueryResult{Success: true, Data: ensureNonNilSlice(fks)}
|
||||
}
|
||||
|
||||
func (a *App) DBGetTriggers(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult {
|
||||
@@ -837,7 +1054,7 @@ func (a *App) DBGetTriggers(config connection.ConnectionConfig, dbName string, t
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Data: triggers}
|
||||
return connection.QueryResult{Success: true, Data: ensureNonNilSlice(triggers)}
|
||||
}
|
||||
|
||||
func (a *App) DropView(config connection.ConnectionConfig, dbName string, viewName string) connection.QueryResult {
|
||||
@@ -975,5 +1192,5 @@ func (a *App) DBGetAllColumns(config connection.ConnectionConfig, dbName string)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Data: cols}
|
||||
return connection.QueryResult{Success: true, Data: ensureNonNilSlice(cols)}
|
||||
}
|
||||
|
||||
112
internal/app/methods_db_conn_test.go
Normal file
112
internal/app/methods_db_conn_test.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
func TestNormalizeTestConnectionConfig_CapsTimeout(t *testing.T) {
|
||||
cfg := connection.ConnectionConfig{Timeout: 60}
|
||||
got := normalizeTestConnectionConfig(cfg)
|
||||
if got.Timeout != testConnectionTimeoutUpperBoundSeconds {
|
||||
t.Fatalf("timeout 应被限制为 %d, got=%d", testConnectionTimeoutUpperBoundSeconds, got.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeTestConnectionConfig_KeepSmallTimeout(t *testing.T) {
|
||||
cfg := connection.ConnectionConfig{Timeout: 5}
|
||||
got := normalizeTestConnectionConfig(cfg)
|
||||
if got.Timeout != 5 {
|
||||
t.Fatalf("timeout 不应被修改, got=%d", got.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeTestConnectionConfig_ZeroTimeout(t *testing.T) {
|
||||
cfg := connection.ConnectionConfig{Timeout: 0}
|
||||
got := normalizeTestConnectionConfig(cfg)
|
||||
if got.Timeout != testConnectionTimeoutUpperBoundSeconds {
|
||||
t.Fatalf("零值 timeout 应被修正, got=%d", got.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatConnSummary_BasicMySQL(t *testing.T) {
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "mysql",
|
||||
Host: "127.0.0.1",
|
||||
Port: 3306,
|
||||
User: "root",
|
||||
Database: "test_db",
|
||||
Timeout: 30,
|
||||
}
|
||||
got := formatConnSummary(cfg)
|
||||
for _, want := range []string{"类型=mysql", "127.0.0.1:3306", "test_db", "root"} {
|
||||
if !strings.Contains(got, want) {
|
||||
t.Fatalf("formatConnSummary 应包含 %q, got=%q", want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatConnSummary_SQLitePath(t *testing.T) {
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "sqlite",
|
||||
Host: "/data/test.db",
|
||||
}
|
||||
got := formatConnSummary(cfg)
|
||||
if !strings.Contains(got, "类型=sqlite") {
|
||||
t.Fatalf("formatConnSummary 缺少类型, got=%q", got)
|
||||
}
|
||||
if !strings.Contains(got, "/data/test.db") {
|
||||
t.Fatalf("formatConnSummary 缺少路径, got=%q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatConnSummary_SSH(t *testing.T) {
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "mysql",
|
||||
Host: "db.internal",
|
||||
Port: 3306,
|
||||
User: "app",
|
||||
UseSSH: true,
|
||||
SSH: connection.SSHConfig{
|
||||
Host: "jump.server",
|
||||
Port: 22,
|
||||
User: "admin",
|
||||
},
|
||||
}
|
||||
got := formatConnSummary(cfg)
|
||||
if !strings.Contains(got, "SSH=jump.server:22") {
|
||||
t.Fatalf("formatConnSummary 应包含 SSH 信息, got=%q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatConnSummary_Proxy(t *testing.T) {
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "mysql",
|
||||
Host: "db.internal",
|
||||
Port: 3306,
|
||||
UseProxy: true,
|
||||
Proxy: connection.ProxyConfig{
|
||||
Type: "socks5",
|
||||
Host: "proxy.local",
|
||||
Port: 1080,
|
||||
},
|
||||
}
|
||||
got := formatConnSummary(cfg)
|
||||
if !strings.Contains(got, "代理=socks5://proxy.local:1080") {
|
||||
t.Fatalf("formatConnSummary 应包含代理信息, got=%q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatConnSummary_DefaultTimeout(t *testing.T) {
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "mysql",
|
||||
Host: "localhost",
|
||||
Port: 3306,
|
||||
}
|
||||
got := formatConnSummary(cfg)
|
||||
if !strings.Contains(got, "超时=30s") {
|
||||
t.Fatalf("formatConnSummary 默认超时应为30s, got=%q", got)
|
||||
}
|
||||
}
|
||||
31
internal/app/methods_db_timeout_test.go
Normal file
31
internal/app/methods_db_timeout_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
func TestNormalizeTestConnectionConfig_DefaultToUpperBound(t *testing.T) {
|
||||
config := connection.ConnectionConfig{Type: "mongodb", Timeout: 0}
|
||||
got := normalizeTestConnectionConfig(config)
|
||||
if got.Timeout != testConnectionTimeoutUpperBoundSeconds {
|
||||
t.Fatalf("expected timeout=%d, got=%d", testConnectionTimeoutUpperBoundSeconds, got.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeTestConnectionConfig_KeepSmallerTimeout(t *testing.T) {
|
||||
config := connection.ConnectionConfig{Type: "mongodb", Timeout: 6}
|
||||
got := normalizeTestConnectionConfig(config)
|
||||
if got.Timeout != 6 {
|
||||
t.Fatalf("expected timeout=6, got=%d", got.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeTestConnectionConfig_ClampLargeTimeout(t *testing.T) {
|
||||
config := connection.ConnectionConfig{Type: "mongodb", Timeout: 60}
|
||||
got := normalizeTestConnectionConfig(config)
|
||||
if got.Timeout != testConnectionTimeoutUpperBoundSeconds {
|
||||
t.Fatalf("expected timeout=%d, got=%d", testConnectionTimeoutUpperBoundSeconds, got.Timeout)
|
||||
}
|
||||
}
|
||||
@@ -353,7 +353,7 @@ func (a *App) SelectDriverDownloadDirectory(currentDir string) connection.QueryR
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
if strings.TrimSpace(selection) == "" {
|
||||
return connection.QueryResult{Success: false, Message: "Cancelled"}
|
||||
return connection.QueryResult{Success: false, Message: "已取消"}
|
||||
}
|
||||
|
||||
resolved, err := resolveDriverDownloadDirectory(selection)
|
||||
@@ -392,7 +392,7 @@ func (a *App) SelectDriverPackageFile(currentPath string) connection.QueryResult
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
if strings.TrimSpace(selection) == "" {
|
||||
return connection.QueryResult{Success: false, Message: "Cancelled"}
|
||||
return connection.QueryResult{Success: false, Message: "已取消"}
|
||||
}
|
||||
|
||||
if abs, err := filepath.Abs(selection); err == nil {
|
||||
@@ -423,7 +423,7 @@ func (a *App) SelectDriverPackageDirectory(currentPath string) connection.QueryR
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
if strings.TrimSpace(selection) == "" {
|
||||
return connection.QueryResult{Success: false, Message: "Cancelled"}
|
||||
return connection.QueryResult{Success: false, Message: "已取消"}
|
||||
}
|
||||
if abs, err := filepath.Abs(selection); err == nil {
|
||||
selection = abs
|
||||
@@ -2536,6 +2536,9 @@ func installOptionalDriverAgentFromLocalPath(definition driverDefinition, filePa
|
||||
return installedDriverPackage{}, fmt.Errorf("导入本地驱动代理失败:%w", copyErr)
|
||||
}
|
||||
}
|
||||
if validateErr := db.ValidateOptionalDriverAgentExecutable(driverType, executablePath); validateErr != nil {
|
||||
return installedDriverPackage{}, validateErr
|
||||
}
|
||||
|
||||
hash, hashErr := hashFileSHA256(executablePath)
|
||||
if hashErr != nil {
|
||||
@@ -2789,15 +2792,19 @@ func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, execut
|
||||
driverType := normalizeDriverType(definition.Type)
|
||||
displayName := resolveDriverDisplayName(definition)
|
||||
forceSourceBuild := shouldForceSourceBuildForVersion(driverType, selectedVersion)
|
||||
preferSourceBuildBeforeDownload := shouldPreferSourceBuildBeforeDownload(driverType, selectedVersion)
|
||||
skipReuseCandidate := shouldSkipReusableAgentCandidate(driverType, selectedVersion)
|
||||
|
||||
info, err := os.Stat(executablePath)
|
||||
if err == nil && !info.IsDir() {
|
||||
hash, hashErr := hashFileSHA256(executablePath)
|
||||
if hashErr != nil {
|
||||
return "", "", fmt.Errorf("读取已安装 %s 驱动代理摘要失败:%w", displayName, hashErr)
|
||||
if validateErr := db.ValidateOptionalDriverAgentExecutable(driverType, executablePath); validateErr != nil {
|
||||
_ = os.Remove(executablePath)
|
||||
} else {
|
||||
// 用户点击“安装/重装”时应强制刷新驱动代理,避免沿用旧二进制导致修复不生效。
|
||||
if removeErr := os.Remove(executablePath); removeErr != nil {
|
||||
return "", "", fmt.Errorf("清理已安装 %s 驱动代理失败:%w", displayName, removeErr)
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("local://existing/%s-driver-agent", driverType), hash, nil
|
||||
}
|
||||
if err == nil && info.IsDir() {
|
||||
return "", "", fmt.Errorf("%s 驱动代理路径被目录占用:%s", displayName, executablePath)
|
||||
@@ -2814,6 +2821,10 @@ func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, execut
|
||||
if copyErr := copyAgentBinary(sourcePath, executablePath); copyErr != nil {
|
||||
return "", "", fmt.Errorf("复制预置 %s 驱动代理失败:%w", displayName, copyErr)
|
||||
}
|
||||
if validateErr := db.ValidateOptionalDriverAgentExecutable(driverType, executablePath); validateErr != nil {
|
||||
_ = os.Remove(executablePath)
|
||||
return "", "", validateErr
|
||||
}
|
||||
hash, hashErr := hashFileSHA256(executablePath)
|
||||
if hashErr != nil {
|
||||
return "", "", fmt.Errorf("计算预置 %s 驱动代理摘要失败:%w", displayName, hashErr)
|
||||
@@ -2823,6 +2834,22 @@ func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, execut
|
||||
}
|
||||
|
||||
var downloadErrs []string
|
||||
var sourceBuildAttempted bool
|
||||
var sourceBuildErr error
|
||||
|
||||
if !forceSourceBuild && preferSourceBuildBeforeDownload {
|
||||
sourceBuildAttempted = true
|
||||
if a != nil {
|
||||
a.emitDriverDownloadProgress(driverType, "downloading", 16, 100, fmt.Sprintf("优先使用本地源码构建 %s 驱动代理", displayName))
|
||||
}
|
||||
hash, buildErr := buildOptionalDriverAgentFromSource(definition, executablePath, selectedVersion)
|
||||
if buildErr == nil {
|
||||
return fmt.Sprintf("local://go-build/%s-driver-agent", driverType), hash, nil
|
||||
}
|
||||
sourceBuildErr = buildErr
|
||||
logger.Warnf("预先本地构建 %s 驱动代理失败,将继续尝试下载预编译包:%v", displayName, buildErr)
|
||||
}
|
||||
|
||||
if !forceSourceBuild {
|
||||
downloadURLs := resolveOptionalDriverAgentDownloadURLs(definition, downloadURL)
|
||||
if len(downloadURLs) > 0 {
|
||||
@@ -2855,9 +2882,15 @@ func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, execut
|
||||
a.emitDriverDownloadProgress(driverType, "downloading", 92, 100, "未命中预编译包,尝试开发态本地构建")
|
||||
}
|
||||
|
||||
hash, buildErr := buildOptionalDriverAgentFromSource(definition, executablePath, selectedVersion)
|
||||
if buildErr == nil {
|
||||
return fmt.Sprintf("local://go-build/%s-driver-agent", driverType), hash, nil
|
||||
var buildErr error
|
||||
if sourceBuildAttempted {
|
||||
buildErr = sourceBuildErr
|
||||
} else {
|
||||
hash, runErr := buildOptionalDriverAgentFromSource(definition, executablePath, selectedVersion)
|
||||
buildErr = runErr
|
||||
if buildErr == nil {
|
||||
return fmt.Sprintf("local://go-build/%s-driver-agent", driverType), hash, nil
|
||||
}
|
||||
}
|
||||
|
||||
var parts []string
|
||||
@@ -2901,6 +2934,10 @@ func downloadOptionalDriverAgentBinary(a *App, definition driverDefinition, urlT
|
||||
if chmodErr := os.Chmod(executablePath, 0o755); chmodErr != nil && stdRuntime.GOOS != "windows" {
|
||||
return "", fmt.Errorf("设置代理权限失败:%w", chmodErr)
|
||||
}
|
||||
if validateErr := db.ValidateOptionalDriverAgentExecutable(driverType, executablePath); validateErr != nil {
|
||||
_ = os.Remove(executablePath)
|
||||
return "", validateErr
|
||||
}
|
||||
return hash, nil
|
||||
}
|
||||
|
||||
@@ -3009,6 +3046,10 @@ func downloadOptionalDriverAgentFromBundle(a *App, definition driverDefinition,
|
||||
if chmodErr := os.Chmod(executablePath, 0o755); chmodErr != nil && stdRuntime.GOOS != "windows" {
|
||||
return "", "", fmt.Errorf("设置驱动代理权限失败:%w", chmodErr)
|
||||
}
|
||||
if validateErr := db.ValidateOptionalDriverAgentExecutable(driverType, executablePath); validateErr != nil {
|
||||
_ = os.Remove(executablePath)
|
||||
return "", "", validateErr
|
||||
}
|
||||
hash, err := hashFileSHA256(executablePath)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("计算驱动代理摘要失败:%w", err)
|
||||
@@ -3067,12 +3108,25 @@ func shouldForceSourceBuildForVersion(driverType string, selectedVersion string)
|
||||
return resolveMongoDriverMajorFromVersion(selectedVersion) == 1
|
||||
}
|
||||
|
||||
func shouldSkipReusableAgentCandidate(driverType string, selectedVersion string) bool {
|
||||
if normalizeDriverType(driverType) != "mongodb" {
|
||||
func shouldPreferSourceBuildBeforeDownload(driverType string, selectedVersion string) bool {
|
||||
_ = selectedVersion
|
||||
switch normalizeDriverType(driverType) {
|
||||
case "kingbase":
|
||||
// 金仓迭代期优先本地源码构建,避免下载到旧版本预编译代理导致修复不生效。
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func shouldSkipReusableAgentCandidate(driverType string, selectedVersion string) bool {
|
||||
_ = selectedVersion
|
||||
return true
|
||||
switch normalizeDriverType(driverType) {
|
||||
case "mongodb", "kingbase":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func optionalDriverBuildTag(driverType string, selectedVersion string) (string, error) {
|
||||
@@ -3334,6 +3388,7 @@ func resolveOptionalDriverAgentDownloadURLs(definition driverDefinition, rawURL
|
||||
}
|
||||
|
||||
func findExistingOptionalDriverAgentCandidate(definition driverDefinition, targetPath string) (string, bool) {
|
||||
driverType := normalizeDriverType(definition.Type)
|
||||
targetAbs, _ := filepath.Abs(targetPath)
|
||||
candidates := resolveOptionalDriverAgentCandidatePaths(definition)
|
||||
for _, candidate := range candidates {
|
||||
@@ -3349,9 +3404,13 @@ func findExistingOptionalDriverAgentCandidate(definition driverDefinition, targe
|
||||
continue
|
||||
}
|
||||
info, statErr := os.Stat(absPath)
|
||||
if statErr == nil && !info.IsDir() {
|
||||
return absPath, true
|
||||
if statErr != nil || info.IsDir() {
|
||||
continue
|
||||
}
|
||||
if validateErr := db.ValidateOptionalDriverAgentExecutable(driverType, absPath); validateErr != nil {
|
||||
continue
|
||||
}
|
||||
return absPath, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -48,7 +49,28 @@ func (a *App) OpenSQLFile() connection.QueryResult {
|
||||
}
|
||||
|
||||
if selection == "" {
|
||||
return connection.QueryResult{Success: false, Message: "Cancelled"}
|
||||
return connection.QueryResult{Success: false, Message: "已取消"}
|
||||
}
|
||||
|
||||
// 检查文件大小
|
||||
const maxSQLFileSize int64 = 50 * 1024 * 1024 // 50MB
|
||||
fi, err := os.Stat(selection)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: fmt.Sprintf("无法读取文件信息: %v", err)}
|
||||
}
|
||||
|
||||
// 大文件:只返回文件路径和大小,不读取内容
|
||||
if fi.Size() > maxSQLFileSize {
|
||||
sizeMB := float64(fi.Size()) / (1024 * 1024)
|
||||
return connection.QueryResult{
|
||||
Success: true,
|
||||
Data: map[string]interface{}{
|
||||
"isLargeFile": true,
|
||||
"filePath": selection,
|
||||
"fileSize": fi.Size(),
|
||||
"fileSizeMB": fmt.Sprintf("%.1f", sizeMB),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(selection)
|
||||
@@ -59,6 +81,184 @@ func (a *App) OpenSQLFile() connection.QueryResult {
|
||||
return connection.QueryResult{Success: true, Data: string(content)}
|
||||
}
|
||||
|
||||
// ExecuteSQLFile 在后端流式读取并执行大 SQL 文件,通过事件推送进度。
|
||||
// 前端通过 EventsOn("sqlfile:progress", ...) 监听进度。
|
||||
func (a *App) ExecuteSQLFile(config connection.ConnectionConfig, dbName string, filePath string, jobID string) connection.QueryResult {
|
||||
if strings.TrimSpace(filePath) == "" {
|
||||
return connection.QueryResult{Success: false, Message: "文件路径为空"}
|
||||
}
|
||||
if strings.TrimSpace(jobID) == "" {
|
||||
jobID = fmt.Sprintf("sqlfile-%d", time.Now().UnixMilli())
|
||||
}
|
||||
|
||||
logger.Warnf("ExecuteSQLFile 开始:file=%s db=%s jobID=%s", filePath, dbName, jobID)
|
||||
|
||||
// 获取数据库连接
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
logger.Error(err, "ExecuteSQLFile 获取连接失败:%s", formatConnSummary(runConfig))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
// 打开文件
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: fmt.Sprintf("无法打开文件: %v", err)}
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// 获取文件大小用于计算进度
|
||||
fi, _ := f.Stat()
|
||||
totalSize := fi.Size()
|
||||
|
||||
// 设置取消上下文
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
a.queryMu.Lock()
|
||||
a.runningQueries[jobID] = queryContext{
|
||||
cancel: cancel,
|
||||
started: time.Now(),
|
||||
}
|
||||
a.queryMu.Unlock()
|
||||
defer func() {
|
||||
a.queryMu.Lock()
|
||||
delete(a.runningQueries, jobID)
|
||||
a.queryMu.Unlock()
|
||||
}()
|
||||
|
||||
// 发送进度事件的辅助函数
|
||||
emitProgress := func(status string, executed, failed, total int, bytesRead int64, currentSQL string, errMsg string) {
|
||||
percent := 0.0
|
||||
if totalSize > 0 {
|
||||
percent = float64(bytesRead) / float64(totalSize) * 100
|
||||
if percent > 100 {
|
||||
percent = 100
|
||||
}
|
||||
}
|
||||
runtime.EventsEmit(a.ctx, "sqlfile:progress", map[string]interface{}{
|
||||
"jobId": jobID,
|
||||
"status": status,
|
||||
"executed": executed,
|
||||
"failed": failed,
|
||||
"total": total,
|
||||
"percent": percent,
|
||||
"bytesRead": bytesRead,
|
||||
"totalBytes": totalSize,
|
||||
"currentSQL": currentSQL,
|
||||
"error": errMsg,
|
||||
})
|
||||
}
|
||||
|
||||
emitProgress("running", 0, 0, 0, 0, "", "")
|
||||
|
||||
// 使用 countingReader 追踪已读取字节数
|
||||
cr := &countingReader{r: f}
|
||||
|
||||
var executedCount int
|
||||
var failedCount int
|
||||
var errorLogs []string
|
||||
startTime := time.Now()
|
||||
|
||||
_, streamErr := streamSQLFile(cr, func(index int, stmt string) error {
|
||||
// 检查是否已取消
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("已取消")
|
||||
default:
|
||||
}
|
||||
|
||||
// 执行语句
|
||||
_, execErr := dbInst.Exec(stmt)
|
||||
if execErr != nil {
|
||||
failedCount++
|
||||
snippet := stmt
|
||||
if len(snippet) > 200 {
|
||||
snippet = snippet[:200] + "..."
|
||||
}
|
||||
errLog := fmt.Sprintf("第 %d 条语句执行失败: %v\n SQL: %s", index+1, execErr, snippet)
|
||||
errorLogs = append(errorLogs, errLog)
|
||||
logger.Warnf("ExecuteSQLFile %s", errLog)
|
||||
} else {
|
||||
executedCount++
|
||||
}
|
||||
|
||||
// 每条语句执行后推送进度(但限频:每 100 条或每秒推一次)
|
||||
total := executedCount + failedCount
|
||||
if total%100 == 0 || total <= 10 {
|
||||
snippet := stmt
|
||||
if len(snippet) > 100 {
|
||||
snippet = snippet[:100] + "..."
|
||||
}
|
||||
emitProgress("running", executedCount, failedCount, total, cr.n, snippet, "")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
duration := time.Since(startTime)
|
||||
|
||||
if streamErr != nil && streamErr.Error() == "已取消" {
|
||||
emitProgress("cancelled", executedCount, failedCount, executedCount+failedCount, cr.n, "", "用户取消执行")
|
||||
logger.Warnf("ExecuteSQLFile 已取消:executed=%d failed=%d duration=%v", executedCount, failedCount, duration)
|
||||
return connection.QueryResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("执行已取消。已执行 %d 条,失败 %d 条,耗时 %v。", executedCount, failedCount, duration.Round(time.Millisecond)),
|
||||
}
|
||||
}
|
||||
|
||||
if streamErr != nil {
|
||||
emitProgress("error", executedCount, failedCount, executedCount+failedCount, cr.n, "", streamErr.Error())
|
||||
return connection.QueryResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("文件读取错误: %v。已执行 %d 条。", streamErr, executedCount),
|
||||
}
|
||||
}
|
||||
|
||||
emitProgress("done", executedCount, failedCount, executedCount+failedCount, totalSize, "", "")
|
||||
|
||||
summary := fmt.Sprintf("执行完成。成功 %d 条,失败 %d 条,耗时 %v。", executedCount, failedCount, duration.Round(time.Millisecond))
|
||||
if len(errorLogs) > 0 {
|
||||
maxShow := 20
|
||||
if len(errorLogs) < maxShow {
|
||||
maxShow = len(errorLogs)
|
||||
}
|
||||
summary += "\n\n错误详情(前 " + fmt.Sprintf("%d", maxShow) + " 条):\n" + strings.Join(errorLogs[:maxShow], "\n")
|
||||
if len(errorLogs) > maxShow {
|
||||
summary += fmt.Sprintf("\n...还有 %d 条错误未显示", len(errorLogs)-maxShow)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Warnf("ExecuteSQLFile 完成:executed=%d failed=%d duration=%v", executedCount, failedCount, duration)
|
||||
return connection.QueryResult{Success: failedCount == 0, Message: summary}
|
||||
}
|
||||
|
||||
// CancelSQLFileExecution 取消正在执行的 SQL 文件任务。
|
||||
func (a *App) CancelSQLFileExecution(jobID string) connection.QueryResult {
|
||||
a.queryMu.Lock()
|
||||
defer a.queryMu.Unlock()
|
||||
|
||||
if ctx, exists := a.runningQueries[jobID]; exists {
|
||||
ctx.cancel()
|
||||
delete(a.runningQueries, jobID)
|
||||
return connection.QueryResult{Success: true, Message: "已发送取消请求"}
|
||||
}
|
||||
return connection.QueryResult{Success: false, Message: "未找到该任务"}
|
||||
}
|
||||
|
||||
// countingReader 包装 io.Reader,追踪已读取的字节数。
|
||||
type countingReader struct {
|
||||
r io.Reader
|
||||
n int64
|
||||
}
|
||||
|
||||
func (cr *countingReader) Read(p []byte) (int, error) {
|
||||
n, err := cr.r.Read(p)
|
||||
cr.n += int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (a *App) ImportConfigFile() connection.QueryResult {
|
||||
selection, err := runtime.OpenFileDialog(a.ctx, runtime.OpenDialogOptions{
|
||||
Title: "Select Config File",
|
||||
@@ -75,7 +275,7 @@ func (a *App) ImportConfigFile() connection.QueryResult {
|
||||
}
|
||||
|
||||
if selection == "" {
|
||||
return connection.QueryResult{Success: false, Message: "Cancelled"}
|
||||
return connection.QueryResult{Success: false, Message: "已取消"}
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(selection)
|
||||
@@ -120,7 +320,7 @@ func (a *App) SelectSSHKeyFile(currentPath string) connection.QueryResult {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
if strings.TrimSpace(selection) == "" {
|
||||
return connection.QueryResult{Success: false, Message: "Cancelled"}
|
||||
return connection.QueryResult{Success: false, Message: "已取消"}
|
||||
}
|
||||
if abs, err := filepath.Abs(selection); err == nil {
|
||||
selection = abs
|
||||
@@ -192,7 +392,7 @@ func (a *App) SelectDatabaseFile(currentPath string, driverType string) connecti
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
if strings.TrimSpace(selection) == "" {
|
||||
return connection.QueryResult{Success: false, Message: "Cancelled"}
|
||||
return connection.QueryResult{Success: false, Message: "已取消"}
|
||||
}
|
||||
if abs, err := filepath.Abs(selection); err == nil {
|
||||
selection = abs
|
||||
@@ -203,7 +403,7 @@ func (a *App) SelectDatabaseFile(currentPath string, driverType string) connecti
|
||||
// PreviewImportFile 解析导入文件,返回字段列表、总行数、前 5 行预览数据
|
||||
func (a *App) PreviewImportFile(filePath string) connection.QueryResult {
|
||||
if filePath == "" {
|
||||
return connection.QueryResult{Success: false, Message: "File path required"}
|
||||
return connection.QueryResult{Success: false, Message: "文件路径不能为空"}
|
||||
}
|
||||
|
||||
rows, columns, err := parseImportFile(filePath)
|
||||
@@ -243,7 +443,7 @@ func (a *App) ImportData(config connection.ConnectionConfig, dbName, tableName s
|
||||
}
|
||||
|
||||
if selection == "" {
|
||||
return connection.QueryResult{Success: false, Message: "Cancelled"}
|
||||
return connection.QueryResult{Success: false, Message: "已取消"}
|
||||
}
|
||||
|
||||
// 返回文件路径供前端预览
|
||||
@@ -492,7 +692,7 @@ func (a *App) ImportDataWithProgress(config connection.ConnectionConfig, dbName,
|
||||
}
|
||||
|
||||
if len(rows) == 0 {
|
||||
return connection.QueryResult{Success: true, Message: "No data to import"}
|
||||
return connection.QueryResult{Success: true, Message: "无可导入数据"}
|
||||
}
|
||||
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
@@ -584,7 +784,7 @@ func (a *App) ExportTable(config connection.ConnectionConfig, dbName string, tab
|
||||
})
|
||||
|
||||
if err != nil || filename == "" {
|
||||
return connection.QueryResult{Success: false, Message: "Cancelled"}
|
||||
return connection.QueryResult{Success: false, Message: "已取消"}
|
||||
}
|
||||
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
@@ -616,7 +816,7 @@ func (a *App) ExportTable(config connection.ConnectionConfig, dbName string, tab
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "Export successful"}
|
||||
return connection.QueryResult{Success: true, Message: "导出完成"}
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(runConfig.Type, tableName))
|
||||
@@ -632,10 +832,10 @@ func (a *App) ExportTable(config connection.ConnectionConfig, dbName string, tab
|
||||
}
|
||||
defer f.Close()
|
||||
if err := writeRowsToFile(f, data, columns, format); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
|
||||
return connection.QueryResult{Success: false, Message: "写入失败:" + err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "Export successful"}
|
||||
return connection.QueryResult{Success: true, Message: "导出完成"}
|
||||
}
|
||||
|
||||
func (a *App) ExportTablesSQL(config connection.ConnectionConfig, dbName string, tableNames []string, includeData bool) connection.QueryResult {
|
||||
@@ -648,7 +848,7 @@ func (a *App) ExportTablesDataSQL(config connection.ConnectionConfig, dbName str
|
||||
|
||||
func (a *App) exportTablesSQL(config connection.ConnectionConfig, dbName string, tableNames []string, includeSchema bool, includeData bool) connection.QueryResult {
|
||||
if !includeSchema && !includeData {
|
||||
return connection.QueryResult{Success: false, Message: "invalid export mode"}
|
||||
return connection.QueryResult{Success: false, Message: "无效的导出模式"}
|
||||
}
|
||||
|
||||
safeDbName := strings.TrimSpace(dbName)
|
||||
@@ -671,7 +871,7 @@ func (a *App) exportTablesSQL(config connection.ConnectionConfig, dbName string,
|
||||
DefaultFilename: defaultFilename,
|
||||
})
|
||||
if err != nil || filename == "" {
|
||||
return connection.QueryResult{Success: false, Message: "Cancelled"}
|
||||
return connection.QueryResult{Success: false, Message: "已取消"}
|
||||
}
|
||||
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
@@ -717,13 +917,13 @@ func (a *App) exportTablesSQL(config connection.ConnectionConfig, dbName string,
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "Export successful"}
|
||||
return connection.QueryResult{Success: true, Message: "导出完成"}
|
||||
}
|
||||
|
||||
func (a *App) ExportDatabaseSQL(config connection.ConnectionConfig, dbName string, includeData bool) connection.QueryResult {
|
||||
safeDbName := strings.TrimSpace(dbName)
|
||||
if safeDbName == "" {
|
||||
return connection.QueryResult{Success: false, Message: "dbName required"}
|
||||
return connection.QueryResult{Success: false, Message: "数据库名称不能为空"}
|
||||
}
|
||||
suffix := "schema"
|
||||
if includeData {
|
||||
@@ -735,7 +935,7 @@ func (a *App) ExportDatabaseSQL(config connection.ConnectionConfig, dbName strin
|
||||
DefaultFilename: fmt.Sprintf("%s_%s.sql", safeDbName, suffix),
|
||||
})
|
||||
if err != nil || filename == "" {
|
||||
return connection.QueryResult{Success: false, Message: "Cancelled"}
|
||||
return connection.QueryResult{Success: false, Message: "已取消"}
|
||||
}
|
||||
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
@@ -772,7 +972,92 @@ func (a *App) ExportDatabaseSQL(config connection.ConnectionConfig, dbName strin
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "Export successful"}
|
||||
return connection.QueryResult{Success: true, Message: "导出完成"}
|
||||
}
|
||||
|
||||
// TruncateTables 清空指定表的数据(针对 MySQL 使用 TRUNCATE,MongoDB 使用 delete,否则使用 DELETE)。
|
||||
// 注意:MySQL 的 TRUNCATE TABLE 是 DDL 操作,无法事务回滚;批量清空为逐表执行,
|
||||
// 如果中途失败,已清空的表无法恢复。错误结果会附带已执行的 SQL 列表供排查。
|
||||
func (a *App) TruncateTables(config connection.ConnectionConfig, dbName string, tableNames []string) connection.QueryResult {
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
|
||||
// 参数校验
|
||||
if len(tableNames) == 0 {
|
||||
return connection.QueryResult{Success: false, Message: "未指定要清空的表"}
|
||||
}
|
||||
|
||||
objects := make([]string, 0, len(tableNames))
|
||||
seen := make(map[string]struct{}, len(tableNames))
|
||||
for _, t := range tableNames {
|
||||
tt := strings.TrimSpace(t)
|
||||
if tt == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[tt]; ok {
|
||||
continue
|
||||
}
|
||||
seen[tt] = struct{}{}
|
||||
objects = append(objects, tt)
|
||||
}
|
||||
|
||||
if len(objects) == 0 {
|
||||
return connection.QueryResult{Success: false, Message: "未指定要清空的表"}
|
||||
}
|
||||
const maxBatchSize = 200
|
||||
if len(objects) > maxBatchSize {
|
||||
return connection.QueryResult{Success: false, Message: fmt.Sprintf("单次最多清空 %d 张表,当前选中 %d 张", maxBatchSize, len(objects))}
|
||||
}
|
||||
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
// 审计日志:记录清空操作的发起
|
||||
logger.Warnf("TruncateTables 开始:%s db=%s tables=%v(共 %d 张)", formatConnSummary(runConfig), dbName, objects, len(objects))
|
||||
|
||||
dbType := strings.ToLower(strings.TrimSpace(runConfig.Type))
|
||||
var executedSQLs []string
|
||||
for i, objectName := range objects {
|
||||
var sql string
|
||||
if dbType == "mysql" || dbType == "mariadb" {
|
||||
sql = fmt.Sprintf("TRUNCATE TABLE %s", quoteQualifiedIdentByType(runConfig.Type, objectName))
|
||||
} else if dbType == "mongodb" {
|
||||
// MongoDB 使用 delete 命令清空集合中的所有文档
|
||||
// deletes 的 limit 为 0 表示删除所有匹配的文档
|
||||
sql = fmt.Sprintf(`{"delete":"%s","deletes":[{"q":{},"limit":0}]}`, objectName)
|
||||
} else {
|
||||
sql = fmt.Sprintf("DELETE FROM %s", quoteQualifiedIdentByType(runConfig.Type, objectName))
|
||||
}
|
||||
|
||||
if _, err := dbInst.Exec(sql); err != nil {
|
||||
logger.Warnf("TruncateTables 第 %d/%d 张表失败:%s table=%s err=%v(已成功清空 %d 张)", i+1, len(objects), formatConnSummary(runConfig), objectName, err, len(executedSQLs))
|
||||
errMsg := fmt.Sprintf("清空 %s 失败: %v", objectName, err)
|
||||
if len(executedSQLs) > 0 {
|
||||
errMsg += fmt.Sprintf("(注意:前 %d 张表已清空且无法恢复)", len(executedSQLs))
|
||||
}
|
||||
return connection.QueryResult{
|
||||
Success: false,
|
||||
Message: errMsg,
|
||||
Data: map[string]interface{}{
|
||||
"executedSQLs": executedSQLs,
|
||||
"count": len(executedSQLs),
|
||||
},
|
||||
}
|
||||
}
|
||||
executedSQLs = append(executedSQLs, sql)
|
||||
}
|
||||
|
||||
logger.Warnf("TruncateTables 完成:%s db=%s 共清空 %d 张表", formatConnSummary(runConfig), dbName, len(executedSQLs))
|
||||
|
||||
return connection.QueryResult{
|
||||
Success: true,
|
||||
Message: "清空成功",
|
||||
Data: map[string]interface{}{
|
||||
"executedSQLs": executedSQLs,
|
||||
"count": len(executedSQLs),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func quoteIdentByType(dbType string, ident string) string {
|
||||
@@ -1471,7 +1756,7 @@ func (a *App) ExportData(data []map[string]interface{}, columns []string, defaul
|
||||
|
||||
if err != nil || filename == "" {
|
||||
logger.Infof("ExportData 已取消或未选择文件:err=%v", err)
|
||||
return connection.QueryResult{Success: false, Message: "Cancelled"}
|
||||
return connection.QueryResult{Success: false, Message: "已取消"}
|
||||
}
|
||||
logger.Infof("ExportData 选定文件:%s", filename)
|
||||
|
||||
@@ -1482,11 +1767,11 @@ func (a *App) ExportData(data []map[string]interface{}, columns []string, defaul
|
||||
defer f.Close()
|
||||
if err := writeRowsToFile(f, data, columns, format); err != nil {
|
||||
logger.Warnf("ExportData 写入失败:file=%s err=%v", filename, err)
|
||||
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
|
||||
return connection.QueryResult{Success: false, Message: "写入失败:" + err.Error()}
|
||||
}
|
||||
|
||||
logger.Infof("ExportData 完成:file=%s rows=%d", filename, len(data))
|
||||
return connection.QueryResult{Success: true, Message: "Export successful"}
|
||||
return connection.QueryResult{Success: true, Message: "导出完成"}
|
||||
}
|
||||
|
||||
// ExportQuery exports by executing the provided SELECT query on backend side.
|
||||
@@ -1494,7 +1779,7 @@ func (a *App) ExportData(data []map[string]interface{}, columns []string, defaul
|
||||
func (a *App) ExportQuery(config connection.ConnectionConfig, dbName string, query string, defaultName string, format string) connection.QueryResult {
|
||||
query = strings.TrimSpace(query)
|
||||
if query == "" {
|
||||
return connection.QueryResult{Success: false, Message: "query required"}
|
||||
return connection.QueryResult{Success: false, Message: "查询语句不能为空"}
|
||||
}
|
||||
|
||||
if defaultName == "" {
|
||||
@@ -1507,7 +1792,7 @@ func (a *App) ExportQuery(config connection.ConnectionConfig, dbName string, que
|
||||
})
|
||||
if err != nil || filename == "" {
|
||||
logger.Infof("ExportQuery 已取消或未选择文件:err=%v", err)
|
||||
return connection.QueryResult{Success: false, Message: "Cancelled"}
|
||||
return connection.QueryResult{Success: false, Message: "已取消"}
|
||||
}
|
||||
logger.Infof("ExportQuery 开始:type=%s db=%s format=%s file=%s sql=%q", strings.TrimSpace(config.Type), strings.TrimSpace(dbName), strings.ToLower(strings.TrimSpace(format)), filename, sqlSnippet(query))
|
||||
|
||||
@@ -1520,7 +1805,7 @@ func (a *App) ExportQuery(config connection.ConnectionConfig, dbName string, que
|
||||
query = sanitizeSQLForPgLike(runConfig.Type, query)
|
||||
lowerQuery := strings.ToLower(strings.TrimSpace(query))
|
||||
if !(strings.HasPrefix(lowerQuery, "select") || strings.HasPrefix(lowerQuery, "with")) {
|
||||
return connection.QueryResult{Success: false, Message: "Only SELECT/WITH queries are supported"}
|
||||
return connection.QueryResult{Success: false, Message: "仅支持 SELECT/WITH 查询导出"}
|
||||
}
|
||||
|
||||
data, columns, err := queryDataForExport(dbInst, runConfig, query)
|
||||
@@ -1537,11 +1822,11 @@ func (a *App) ExportQuery(config connection.ConnectionConfig, dbName string, que
|
||||
|
||||
if err := writeRowsToFile(f, data, columns, format); err != nil {
|
||||
logger.Warnf("ExportQuery 写入失败:file=%s err=%v", filename, err)
|
||||
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
|
||||
return connection.QueryResult{Success: false, Message: "写入失败:" + err.Error()}
|
||||
}
|
||||
|
||||
logger.Infof("ExportQuery 完成:file=%s rows=%d cols=%d", filename, len(data), len(columns))
|
||||
return connection.QueryResult{Success: true, Message: "Export successful"}
|
||||
return connection.QueryResult{Success: true, Message: "导出完成"}
|
||||
}
|
||||
|
||||
func queryDataForExport(dbInst db.Database, config connection.ConnectionConfig, query string) ([]map[string]interface{}, []string, error) {
|
||||
|
||||
@@ -23,12 +23,20 @@ var (
|
||||
|
||||
// getRedisClient gets or creates a Redis client from cache
|
||||
func (a *App) getRedisClient(config connection.ConnectionConfig) (redis.RedisClient, error) {
|
||||
key := getRedisClientCacheKey(config)
|
||||
effectiveConfig := applyGlobalProxyToConnection(config)
|
||||
connectConfig, proxyErr := resolveDialConfigWithProxy(effectiveConfig)
|
||||
if proxyErr != nil {
|
||||
wrapped := wrapConnectError(effectiveConfig, proxyErr)
|
||||
logger.Error(wrapped, "Redis 代理准备失败:%s", formatRedisConnSummary(effectiveConfig))
|
||||
return nil, wrapped
|
||||
}
|
||||
|
||||
key := getRedisClientCacheKey(connectConfig)
|
||||
shortKey := key
|
||||
if len(shortKey) > 12 {
|
||||
shortKey = shortKey[:12]
|
||||
}
|
||||
logger.Infof("获取 Redis 连接:%s 缓存Key=%s", formatRedisConnSummary(config), shortKey)
|
||||
logger.Infof("获取 Redis 连接:%s 缓存Key=%s", formatRedisConnSummary(effectiveConfig), shortKey)
|
||||
|
||||
redisCacheMu.Lock()
|
||||
defer redisCacheMu.Unlock()
|
||||
@@ -47,21 +55,20 @@ func (a *App) getRedisClient(config connection.ConnectionConfig) (redis.RedisCli
|
||||
|
||||
logger.Infof("创建 Redis 客户端实例:缓存Key=%s", shortKey)
|
||||
client := redis.NewRedisClient()
|
||||
if err := client.Connect(config); err != nil {
|
||||
logger.Error(err, "Redis 连接失败:%s 缓存Key=%s", formatRedisConnSummary(config), shortKey)
|
||||
return nil, err
|
||||
if err := client.Connect(connectConfig); err != nil {
|
||||
wrapped := wrapConnectError(effectiveConfig, err)
|
||||
logger.Error(wrapped, "Redis 连接失败:%s 缓存Key=%s", formatRedisConnSummary(effectiveConfig), shortKey)
|
||||
return nil, wrapped
|
||||
}
|
||||
|
||||
redisCache[key] = client
|
||||
logger.Infof("Redis 连接成功并写入缓存:%s 缓存Key=%s", formatRedisConnSummary(config), shortKey)
|
||||
logger.Infof("Redis 连接成功并写入缓存:%s 缓存Key=%s", formatRedisConnSummary(effectiveConfig), shortKey)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func getRedisClientCacheKey(config connection.ConnectionConfig) string {
|
||||
if !config.UseSSH {
|
||||
config.SSH = connection.SSHConfig{}
|
||||
}
|
||||
b, _ := json.Marshal(config)
|
||||
normalized := normalizeCacheKeyConfig(config)
|
||||
b, _ := json.Marshal(normalized)
|
||||
sum := sha256.Sum256(b)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
@@ -91,6 +98,26 @@ func formatRedisConnSummary(config connection.ConnectionConfig) string {
|
||||
b.WriteString(" 用户=")
|
||||
b.WriteString(config.SSH.User)
|
||||
}
|
||||
if config.UseProxy {
|
||||
b.WriteString(" 代理=")
|
||||
b.WriteString(strings.ToLower(strings.TrimSpace(config.Proxy.Type)))
|
||||
b.WriteString("://")
|
||||
b.WriteString(config.Proxy.Host)
|
||||
b.WriteString(":")
|
||||
b.WriteString(strconv.Itoa(config.Proxy.Port))
|
||||
if strings.TrimSpace(config.Proxy.User) != "" {
|
||||
b.WriteString(" 代理认证=已配置")
|
||||
}
|
||||
}
|
||||
if config.UseHTTPTunnel {
|
||||
b.WriteString(" HTTP隧道=")
|
||||
b.WriteString(strings.TrimSpace(config.HTTPTunnel.Host))
|
||||
b.WriteString(":")
|
||||
b.WriteString(strconv.Itoa(config.HTTPTunnel.Port))
|
||||
if strings.TrimSpace(config.HTTPTunnel.User) != "" {
|
||||
b.WriteString(" HTTP隧道认证=已配置")
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
@@ -426,6 +453,23 @@ func (a *App) RedisRenameKey(config connection.ConnectionConfig, oldKey, newKey
|
||||
return connection.QueryResult{Success: true, Message: "重命名成功"}
|
||||
}
|
||||
|
||||
// RedisKeyExists checks whether a key already exists
|
||||
func (a *App) RedisKeyExists(config connection.ConnectionConfig, key string) connection.QueryResult {
|
||||
config.Type = "redis"
|
||||
client, err := a.getRedisClient(config)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
exists, err := client.KeyExists(key)
|
||||
if err != nil {
|
||||
logger.Error(err, "RedisKeyExists 检查失败:key=%s", key)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Data: map[string]bool{"exists": exists}}
|
||||
}
|
||||
|
||||
// RedisDeleteHashField deletes fields from a hash
|
||||
func (a *App) RedisDeleteHashField(config connection.ConnectionConfig, key string, fields []string) connection.QueryResult {
|
||||
config.Type = "redis"
|
||||
|
||||
@@ -5,6 +5,66 @@ import (
|
||||
"unicode"
|
||||
)
|
||||
|
||||
func leadingSQLKeyword(query string) string {
|
||||
text := strings.TrimSpace(query)
|
||||
for len(text) > 0 {
|
||||
trimmed := strings.TrimLeft(text, " \t\r\n")
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
text = trimmed
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(text, "--"):
|
||||
if idx := strings.IndexByte(text, '\n'); idx >= 0 {
|
||||
text = text[idx+1:]
|
||||
continue
|
||||
}
|
||||
return ""
|
||||
case strings.HasPrefix(text, "#"):
|
||||
if idx := strings.IndexByte(text, '\n'); idx >= 0 {
|
||||
text = text[idx+1:]
|
||||
continue
|
||||
}
|
||||
return ""
|
||||
case strings.HasPrefix(text, "/*"):
|
||||
if idx := strings.Index(text, "*/"); idx >= 0 {
|
||||
text = text[idx+2:]
|
||||
continue
|
||||
}
|
||||
return ""
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
for i, r := range text {
|
||||
if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' {
|
||||
continue
|
||||
}
|
||||
if i == 0 {
|
||||
return ""
|
||||
}
|
||||
return strings.ToLower(text[:i])
|
||||
}
|
||||
return strings.ToLower(text)
|
||||
}
|
||||
|
||||
func isReadOnlySQLQuery(dbType string, query string) bool {
|
||||
if strings.ToLower(strings.TrimSpace(dbType)) == "mongodb" && strings.HasPrefix(strings.TrimSpace(query), "{") {
|
||||
return true
|
||||
}
|
||||
|
||||
switch leadingSQLKeyword(query) {
|
||||
case "select", "with", "show", "describe", "desc", "explain", "pragma", "values":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeSQLForPgLike(dbType string, query string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(dbType)) {
|
||||
case "postgres", "kingbase", "highgo", "vastbase":
|
||||
|
||||
175
internal/app/sql_split.go
Normal file
175
internal/app/sql_split.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package app
|
||||
|
||||
import "strings"
|
||||
|
||||
// splitSQLStatements 按分号拆分 SQL 文本为独立语句。
|
||||
// 正确处理单引号/双引号/反引号字符串、行注释(-- / #)、块注释(/* */)和
|
||||
// PostgreSQL/Kingbase 的 $$...$$ dollar-quoting,避免在这些上下文中错误拆分。
|
||||
// 同时支持 SQL 标准的转义单引号(两个连续单引号 '' 表示字面量引号)。
|
||||
func splitSQLStatements(sql string) []string {
|
||||
text := strings.ReplaceAll(sql, "\r\n", "\n")
|
||||
var statements []string
|
||||
|
||||
var cur strings.Builder
|
||||
inSingle := false
|
||||
inDouble := false
|
||||
inBacktick := false
|
||||
escaped := false
|
||||
inLineComment := false
|
||||
inBlockComment := false
|
||||
var dollarTag string // postgres/kingbase: $$...$$ or $tag$...$tag$
|
||||
|
||||
push := func() {
|
||||
s := strings.TrimSpace(cur.String())
|
||||
if s != "" {
|
||||
statements = append(statements, s)
|
||||
}
|
||||
cur.Reset()
|
||||
}
|
||||
|
||||
for i := 0; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
next := byte(0)
|
||||
if i+1 < len(text) {
|
||||
next = text[i+1]
|
||||
}
|
||||
|
||||
// 行注释
|
||||
if inLineComment {
|
||||
if ch == '\n' {
|
||||
inLineComment = false
|
||||
}
|
||||
cur.WriteByte(ch)
|
||||
continue
|
||||
}
|
||||
|
||||
// 块注释
|
||||
if inBlockComment {
|
||||
cur.WriteByte(ch)
|
||||
if ch == '*' && next == '/' {
|
||||
cur.WriteByte('/')
|
||||
i++
|
||||
inBlockComment = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Dollar-quoting
|
||||
if dollarTag != "" {
|
||||
if strings.HasPrefix(text[i:], dollarTag) {
|
||||
cur.WriteString(dollarTag)
|
||||
i += len(dollarTag) - 1
|
||||
dollarTag = ""
|
||||
} else {
|
||||
cur.WriteByte(ch)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 转义字符(反斜杠转义,MySQL 风格)
|
||||
if escaped {
|
||||
escaped = false
|
||||
cur.WriteByte(ch)
|
||||
continue
|
||||
}
|
||||
if (inSingle || inDouble) && ch == '\\' {
|
||||
escaped = true
|
||||
cur.WriteByte(ch)
|
||||
continue
|
||||
}
|
||||
|
||||
// 字符串开闭
|
||||
if !inDouble && !inBacktick && ch == '\'' {
|
||||
if inSingle && next == '\'' {
|
||||
// SQL 标准转义:两个连续单引号 '' 表示字面量引号,保持在引号内
|
||||
cur.WriteByte(ch)
|
||||
cur.WriteByte(next)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
inSingle = !inSingle
|
||||
cur.WriteByte(ch)
|
||||
continue
|
||||
}
|
||||
if !inSingle && !inBacktick && ch == '"' {
|
||||
inDouble = !inDouble
|
||||
cur.WriteByte(ch)
|
||||
continue
|
||||
}
|
||||
if !inSingle && !inDouble && ch == '`' {
|
||||
inBacktick = !inBacktick
|
||||
cur.WriteByte(ch)
|
||||
continue
|
||||
}
|
||||
|
||||
// 在引号/反引号内部不做任何判断
|
||||
if inSingle || inDouble || inBacktick {
|
||||
cur.WriteByte(ch)
|
||||
continue
|
||||
}
|
||||
|
||||
// 行注释开始
|
||||
if ch == '-' && next == '-' {
|
||||
inLineComment = true
|
||||
cur.WriteByte(ch)
|
||||
continue
|
||||
}
|
||||
if ch == '#' {
|
||||
inLineComment = true
|
||||
cur.WriteByte(ch)
|
||||
continue
|
||||
}
|
||||
|
||||
// 块注释开始
|
||||
if ch == '/' && next == '*' {
|
||||
inBlockComment = true
|
||||
cur.WriteString("/*")
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
// Dollar-quoting 开始
|
||||
if ch == '$' {
|
||||
if tag := parseSQLDollarTag(text[i:]); tag != "" {
|
||||
dollarTag = tag
|
||||
cur.WriteString(tag)
|
||||
i += len(tag) - 1
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 分号分隔(支持全角分号";")
|
||||
if ch == ';' {
|
||||
push()
|
||||
continue
|
||||
}
|
||||
// 全角分号 UTF-8 序列: 0xEF 0xBC 0x9B
|
||||
if ch == 0xEF && i+2 < len(text) && text[i+1] == 0xBC && text[i+2] == 0x9B {
|
||||
push()
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
|
||||
cur.WriteByte(ch)
|
||||
}
|
||||
|
||||
push()
|
||||
return statements
|
||||
}
|
||||
|
||||
// parseSQLDollarTag 解析 PostgreSQL/Kingbase 的 dollar-quoting 标签。
|
||||
func parseSQLDollarTag(s string) string {
|
||||
if len(s) < 2 || s[0] != '$' {
|
||||
return ""
|
||||
}
|
||||
for i := 1; i < len(s); i++ {
|
||||
c := s[i]
|
||||
if c == '$' {
|
||||
return s[:i+1]
|
||||
}
|
||||
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
209
internal/app/sql_split_stream.go
Normal file
209
internal/app/sql_split_stream.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// sqlStreamSplitter 是一个流式 SQL 语句拆分器,适用于处理大文件。
|
||||
// 调用方通过 Feed(chunk) 逐块喂入数据,通过 Flush() 获取最后一条残余语句。
|
||||
// 内部维护与 splitSQLStatements 完全一致的状态机逻辑。
|
||||
type sqlStreamSplitter struct {
|
||||
cur strings.Builder
|
||||
inSingle bool
|
||||
inDouble bool
|
||||
inBacktick bool
|
||||
escaped bool
|
||||
inLineComment bool
|
||||
inBlockComment bool
|
||||
dollarTag string
|
||||
}
|
||||
|
||||
// Feed 将一个 chunk 喂入拆分器,返回在此 chunk 中完成的 SQL 语句列表。
|
||||
func (s *sqlStreamSplitter) Feed(chunk []byte) []string {
|
||||
var statements []string
|
||||
text := string(chunk)
|
||||
|
||||
for i := 0; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
next := byte(0)
|
||||
if i+1 < len(text) {
|
||||
next = text[i+1]
|
||||
}
|
||||
|
||||
// 行注释
|
||||
if s.inLineComment {
|
||||
if ch == '\n' {
|
||||
s.inLineComment = false
|
||||
}
|
||||
s.cur.WriteByte(ch)
|
||||
continue
|
||||
}
|
||||
|
||||
// 块注释
|
||||
if s.inBlockComment {
|
||||
s.cur.WriteByte(ch)
|
||||
if ch == '*' && next == '/' {
|
||||
s.cur.WriteByte('/')
|
||||
i++
|
||||
s.inBlockComment = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Dollar-quoting
|
||||
if s.dollarTag != "" {
|
||||
if strings.HasPrefix(text[i:], s.dollarTag) {
|
||||
s.cur.WriteString(s.dollarTag)
|
||||
i += len(s.dollarTag) - 1
|
||||
s.dollarTag = ""
|
||||
} else {
|
||||
s.cur.WriteByte(ch)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 转义字符
|
||||
if s.escaped {
|
||||
s.escaped = false
|
||||
s.cur.WriteByte(ch)
|
||||
continue
|
||||
}
|
||||
if (s.inSingle || s.inDouble) && ch == '\\' {
|
||||
s.escaped = true
|
||||
s.cur.WriteByte(ch)
|
||||
continue
|
||||
}
|
||||
|
||||
// 字符串开闭
|
||||
if !s.inDouble && !s.inBacktick && ch == '\'' {
|
||||
if s.inSingle && next == '\'' {
|
||||
// SQL 标准转义:两个连续单引号
|
||||
s.cur.WriteByte(ch)
|
||||
s.cur.WriteByte(next)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
s.inSingle = !s.inSingle
|
||||
s.cur.WriteByte(ch)
|
||||
continue
|
||||
}
|
||||
if !s.inSingle && !s.inBacktick && ch == '"' {
|
||||
s.inDouble = !s.inDouble
|
||||
s.cur.WriteByte(ch)
|
||||
continue
|
||||
}
|
||||
if !s.inSingle && !s.inDouble && ch == '`' {
|
||||
s.inBacktick = !s.inBacktick
|
||||
s.cur.WriteByte(ch)
|
||||
continue
|
||||
}
|
||||
|
||||
// 在引号/反引号内部不做任何判断
|
||||
if s.inSingle || s.inDouble || s.inBacktick {
|
||||
s.cur.WriteByte(ch)
|
||||
continue
|
||||
}
|
||||
|
||||
// 行注释开始
|
||||
if ch == '-' && next == '-' {
|
||||
s.inLineComment = true
|
||||
s.cur.WriteByte(ch)
|
||||
continue
|
||||
}
|
||||
if ch == '#' {
|
||||
s.inLineComment = true
|
||||
s.cur.WriteByte(ch)
|
||||
continue
|
||||
}
|
||||
|
||||
// 块注释开始
|
||||
if ch == '/' && next == '*' {
|
||||
s.inBlockComment = true
|
||||
s.cur.WriteString("/*")
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
// Dollar-quoting 开始
|
||||
if ch == '$' {
|
||||
if tag := parseSQLDollarTag(text[i:]); tag != "" {
|
||||
s.dollarTag = tag
|
||||
s.cur.WriteString(tag)
|
||||
i += len(tag) - 1
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 分号分隔
|
||||
if ch == ';' {
|
||||
stmt := strings.TrimSpace(s.cur.String())
|
||||
if stmt != "" {
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
s.cur.Reset()
|
||||
continue
|
||||
}
|
||||
// 全角分号
|
||||
if ch == 0xEF && i+2 < len(text) && text[i+1] == 0xBC && text[i+2] == 0x9B {
|
||||
stmt := strings.TrimSpace(s.cur.String())
|
||||
if stmt != "" {
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
s.cur.Reset()
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
|
||||
s.cur.WriteByte(ch)
|
||||
}
|
||||
|
||||
return statements
|
||||
}
|
||||
|
||||
// Flush 返回缓冲区中剩余的不完整语句(文件结束时调用)。
|
||||
func (s *sqlStreamSplitter) Flush() string {
|
||||
stmt := strings.TrimSpace(s.cur.String())
|
||||
s.cur.Reset()
|
||||
return stmt
|
||||
}
|
||||
|
||||
// streamSQLFile 从 reader 中流式读取 SQL 并逐条回调。
|
||||
// onStatement 返回 error 时停止读取并返回该 error。
|
||||
// 返回总处理语句数和可能的错误。
|
||||
func streamSQLFile(reader io.Reader, onStatement func(index int, stmt string) error) (int, error) {
|
||||
splitter := &sqlStreamSplitter{}
|
||||
scanner := bufio.NewScanner(reader)
|
||||
// 设置最大 token 为 4MB,处理超长单行
|
||||
const maxLineSize = 4 * 1024 * 1024
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
count := 0
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
// 保持换行符,因为行注释依赖 \n 来结束
|
||||
lineWithNewline := append(line, '\n')
|
||||
stmts := splitter.Feed(lineWithNewline)
|
||||
for _, stmt := range stmts {
|
||||
if err := onStatement(count, stmt); err != nil {
|
||||
return count, err
|
||||
}
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return count, err
|
||||
}
|
||||
|
||||
// 处理文件末尾不以分号结尾的最后一条语句
|
||||
if last := splitter.Flush(); last != "" {
|
||||
if err := onStatement(count, last); err != nil {
|
||||
return count, err
|
||||
}
|
||||
count++
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
113
internal/app/sql_split_test.go
Normal file
113
internal/app/sql_split_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSplitSQLStatements_BasicSplit(t *testing.T) {
|
||||
input := "SELECT 1; SELECT 2; SELECT 3"
|
||||
got := splitSQLStatements(input)
|
||||
want := []string{"SELECT 1", "SELECT 2", "SELECT 3"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitSQLStatements_QuotedSemicolon(t *testing.T) {
|
||||
input := `SELECT 'hello;world'; SELECT 2`
|
||||
got := splitSQLStatements(input)
|
||||
want := []string{`SELECT 'hello;world'`, "SELECT 2"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitSQLStatements_LineComment(t *testing.T) {
|
||||
input := "SELECT 1; -- this is a comment;\nSELECT 2"
|
||||
got := splitSQLStatements(input)
|
||||
want := []string{"SELECT 1", "-- this is a comment;\nSELECT 2"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitSQLStatements_BlockComment(t *testing.T) {
|
||||
input := "SELECT /* ; */ 1; SELECT 2"
|
||||
got := splitSQLStatements(input)
|
||||
want := []string{"SELECT /* ; */ 1", "SELECT 2"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitSQLStatements_EmptyInput(t *testing.T) {
|
||||
got := splitSQLStatements("")
|
||||
if len(got) != 0 {
|
||||
t.Errorf("splitSQLStatements(\"\") = %v, want empty slice", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitSQLStatements_SingleStatement(t *testing.T) {
|
||||
input := "SELECT * FROM users WHERE id = 1"
|
||||
got := splitSQLStatements(input)
|
||||
want := []string{"SELECT * FROM users WHERE id = 1"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitSQLStatements_DollarQuoting(t *testing.T) {
|
||||
input := "SELECT $tag$hello;world$tag$; SELECT 2"
|
||||
got := splitSQLStatements(input)
|
||||
want := []string{"SELECT $tag$hello;world$tag$", "SELECT 2"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitSQLStatements_FullWidthSemicolon(t *testing.T) {
|
||||
input := "SELECT 1;SELECT 2"
|
||||
got := splitSQLStatements(input)
|
||||
want := []string{"SELECT 1", "SELECT 2"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitSQLStatements_Backtick(t *testing.T) {
|
||||
input := "SELECT `col;name` FROM t; SELECT 2"
|
||||
got := splitSQLStatements(input)
|
||||
want := []string{"SELECT `col;name` FROM t", "SELECT 2"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitSQLStatements_TrailingSemicolon(t *testing.T) {
|
||||
input := "SELECT 1; SELECT 2;"
|
||||
got := splitSQLStatements(input)
|
||||
want := []string{"SELECT 1", "SELECT 2"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitSQLStatements_SQLEscapedQuote(t *testing.T) {
|
||||
input := "SELECT 'it''s a test'; SELECT 2"
|
||||
got := splitSQLStatements(input)
|
||||
want := []string{"SELECT 'it''s a test'", "SELECT 2"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitSQLStatements_SQLEscapedQuoteMultiple(t *testing.T) {
|
||||
input := "INSERT INTO t VALUES ('O''Brien', 'it''s OK'); SELECT 1"
|
||||
got := splitSQLStatements(input)
|
||||
want := []string{"INSERT INTO t VALUES ('O''Brien', 'it''s OK')", "SELECT 1"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("splitSQLStatements(%q) = %v, want %v", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package connection
|
||||
|
||||
// SSHConfig holds SSH connection details
|
||||
// SSHConfig 存储 SSH 隧道连接配置。
|
||||
type SSHConfig struct {
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
@@ -9,7 +9,7 @@ type SSHConfig struct {
|
||||
KeyPath string `json:"keyPath"`
|
||||
}
|
||||
|
||||
// ProxyConfig holds proxy connection details
|
||||
// ProxyConfig 存储代理连接配置。
|
||||
type ProxyConfig struct {
|
||||
Type string `json:"type"` // socks5 | http
|
||||
Host string `json:"host"`
|
||||
@@ -18,42 +18,58 @@ type ProxyConfig struct {
|
||||
Password string `json:"password,omitempty"`
|
||||
}
|
||||
|
||||
// ConnectionConfig holds database connection details including SSH
|
||||
type ConnectionConfig struct {
|
||||
Type string `json:"type"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
User string `json:"user"`
|
||||
Password string `json:"password"`
|
||||
SavePassword bool `json:"savePassword,omitempty"` // Persist password in saved connection
|
||||
Database string `json:"database"`
|
||||
UseSSL bool `json:"useSSL,omitempty"` // MySQL-like SSL/TLS switch
|
||||
SSLMode string `json:"sslMode,omitempty"` // preferred | required | skip-verify | disable
|
||||
SSLCertPath string `json:"sslCertPath,omitempty"` // TLS client certificate path (e.g., Dameng)
|
||||
SSLKeyPath string `json:"sslKeyPath,omitempty"` // TLS client private key path (e.g., Dameng)
|
||||
UseSSH bool `json:"useSSH"`
|
||||
SSH SSHConfig `json:"ssh"`
|
||||
UseProxy bool `json:"useProxy,omitempty"`
|
||||
Proxy ProxyConfig `json:"proxy,omitempty"`
|
||||
Driver string `json:"driver,omitempty"` // For custom connection
|
||||
DSN string `json:"dsn,omitempty"` // For custom connection
|
||||
Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds (default: 30)
|
||||
RedisDB int `json:"redisDB,omitempty"` // Redis database index (0-15)
|
||||
URI string `json:"uri,omitempty"` // Connection URI for copy/paste
|
||||
Hosts []string `json:"hosts,omitempty"` // Multi-host addresses: host:port
|
||||
Topology string `json:"topology,omitempty"` // single | replica | cluster
|
||||
MySQLReplicaUser string `json:"mysqlReplicaUser,omitempty"` // MySQL replica auth user
|
||||
MySQLReplicaPassword string `json:"mysqlReplicaPassword,omitempty"` // MySQL replica auth password
|
||||
ReplicaSet string `json:"replicaSet,omitempty"` // MongoDB replica set name
|
||||
AuthSource string `json:"authSource,omitempty"` // MongoDB authSource
|
||||
ReadPreference string `json:"readPreference,omitempty"` // MongoDB readPreference
|
||||
MongoSRV bool `json:"mongoSrv,omitempty"` // MongoDB use mongodb+srv URI scheme
|
||||
MongoAuthMechanism string `json:"mongoAuthMechanism,omitempty"` // MongoDB authMechanism
|
||||
MongoReplicaUser string `json:"mongoReplicaUser,omitempty"` // MongoDB replica auth user
|
||||
MongoReplicaPassword string `json:"mongoReplicaPassword,omitempty"` // MongoDB replica auth password
|
||||
// HTTPTunnelConfig 存储 HTTP CONNECT 隧道配置。
|
||||
type HTTPTunnelConfig struct {
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
User string `json:"user,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
}
|
||||
|
||||
// QueryResult is the standard response format for Wails methods
|
||||
// ConnectionConfig 存储数据库连接的完整配置,包括 SSH、代理、SSL 等网络层设置。
|
||||
type ConnectionConfig struct {
|
||||
Type string `json:"type"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
User string `json:"user"`
|
||||
Password string `json:"password"`
|
||||
SavePassword bool `json:"savePassword,omitempty"` // Persist password in saved connection
|
||||
Database string `json:"database"`
|
||||
UseSSL bool `json:"useSSL,omitempty"` // MySQL-like SSL/TLS switch
|
||||
SSLMode string `json:"sslMode,omitempty"` // preferred | required | skip-verify | disable
|
||||
SSLCertPath string `json:"sslCertPath,omitempty"` // TLS client certificate path (e.g., Dameng)
|
||||
SSLKeyPath string `json:"sslKeyPath,omitempty"` // TLS client private key path (e.g., Dameng)
|
||||
UseSSH bool `json:"useSSH"`
|
||||
SSH SSHConfig `json:"ssh"`
|
||||
UseProxy bool `json:"useProxy,omitempty"`
|
||||
Proxy ProxyConfig `json:"proxy,omitempty"`
|
||||
UseHTTPTunnel bool `json:"useHttpTunnel,omitempty"`
|
||||
HTTPTunnel HTTPTunnelConfig `json:"httpTunnel,omitempty"`
|
||||
Driver string `json:"driver,omitempty"` // For custom connection
|
||||
DSN string `json:"dsn,omitempty"` // For custom connection
|
||||
Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds (default: 30)
|
||||
RedisDB int `json:"redisDB,omitempty"` // Redis database index (0-15)
|
||||
URI string `json:"uri,omitempty"` // Connection URI for copy/paste
|
||||
Hosts []string `json:"hosts,omitempty"` // Multi-host addresses: host:port
|
||||
Topology string `json:"topology,omitempty"` // single | replica | cluster
|
||||
MySQLReplicaUser string `json:"mysqlReplicaUser,omitempty"` // MySQL replica auth user
|
||||
MySQLReplicaPassword string `json:"mysqlReplicaPassword,omitempty"` // MySQL replica auth password
|
||||
ReplicaSet string `json:"replicaSet,omitempty"` // MongoDB replica set name
|
||||
AuthSource string `json:"authSource,omitempty"` // MongoDB authSource
|
||||
ReadPreference string `json:"readPreference,omitempty"` // MongoDB readPreference
|
||||
MongoSRV bool `json:"mongoSrv,omitempty"` // MongoDB use mongodb+srv URI scheme
|
||||
MongoAuthMechanism string `json:"mongoAuthMechanism,omitempty"` // MongoDB authMechanism
|
||||
MongoReplicaUser string `json:"mongoReplicaUser,omitempty"` // MongoDB replica auth user
|
||||
MongoReplicaPassword string `json:"mongoReplicaPassword,omitempty"` // MongoDB replica auth password
|
||||
}
|
||||
|
||||
// ResultSetData 表示一个查询结果集(行 + 列名),用于多结果集场景。
|
||||
type ResultSetData struct {
|
||||
Rows []map[string]interface{} `json:"rows"`
|
||||
Columns []string `json:"columns"`
|
||||
}
|
||||
|
||||
// QueryResult 是 Wails 绑定方法的统一响应格式,前端通过此结构体接收后端结果。
|
||||
type QueryResult struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
@@ -62,7 +78,7 @@ type QueryResult struct {
|
||||
QueryID string `json:"queryId,omitempty"` // Unique ID for query cancellation
|
||||
}
|
||||
|
||||
// ColumnDefinition represents a table column
|
||||
// ColumnDefinition 描述表的一个列定义。
|
||||
type ColumnDefinition struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
@@ -73,16 +89,17 @@ type ColumnDefinition struct {
|
||||
Comment string `json:"comment"`
|
||||
}
|
||||
|
||||
// IndexDefinition represents a table index
|
||||
// IndexDefinition 描述表的一个索引定义。
|
||||
type IndexDefinition struct {
|
||||
Name string `json:"name"`
|
||||
ColumnName string `json:"columnName"`
|
||||
NonUnique int `json:"nonUnique"`
|
||||
SeqInIndex int `json:"seqInIndex"`
|
||||
IndexType string `json:"indexType"`
|
||||
SubPart int `json:"subPart,omitempty"`
|
||||
}
|
||||
|
||||
// ForeignKeyDefinition represents a foreign key
|
||||
// ForeignKeyDefinition 描述表的一个外键定义。
|
||||
type ForeignKeyDefinition struct {
|
||||
Name string `json:"name"`
|
||||
ColumnName string `json:"columnName"`
|
||||
@@ -91,7 +108,7 @@ type ForeignKeyDefinition struct {
|
||||
ConstraintName string `json:"constraintName"`
|
||||
}
|
||||
|
||||
// TriggerDefinition represents a trigger
|
||||
// TriggerDefinition 描述表的一个触发器定义。
|
||||
type TriggerDefinition struct {
|
||||
Name string `json:"name"`
|
||||
Timing string `json:"timing"` // BEFORE/AFTER
|
||||
@@ -99,26 +116,27 @@ type TriggerDefinition struct {
|
||||
Statement string `json:"statement"`
|
||||
}
|
||||
|
||||
// ColumnDefinitionWithTable represents a column with its table name (for search/autocomplete)
|
||||
// ColumnDefinitionWithTable 带有表名标识的列定义,用于跨表搜索和 SQL 自动补全。
|
||||
type ColumnDefinitionWithTable struct {
|
||||
TableName string `json:"tableName"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// UpdateRow represents a row update with keys (WHERE) and values (SET)
|
||||
// UpdateRow 表示一行更新操作,Keys 为 WHERE 条件,Values 为 SET 值。
|
||||
type UpdateRow struct {
|
||||
Keys map[string]interface{} `json:"keys"`
|
||||
Values map[string]interface{} `json:"values"`
|
||||
}
|
||||
|
||||
// ChangeSet represents a batch of changes
|
||||
// ChangeSet 表示一组批量变更,包含新增、修改和删除操作。
|
||||
type ChangeSet struct {
|
||||
Inserts []map[string]interface{} `json:"inserts"`
|
||||
Updates []UpdateRow `json:"updates"`
|
||||
Deletes []map[string]interface{} `json:"deletes"`
|
||||
}
|
||||
|
||||
// MongoMemberInfo 描述 MongoDB 副本集成员的信息。
|
||||
type MongoMemberInfo struct {
|
||||
Host string `json:"host"`
|
||||
Role string `json:"role"`
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -107,7 +108,9 @@ func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig
|
||||
if readTimeout < minClickHouseReadTimeout {
|
||||
readTimeout = minClickHouseReadTimeout
|
||||
}
|
||||
protocol := detectClickHouseProtocol(config)
|
||||
opts := &clickhouse.Options{
|
||||
Protocol: protocol,
|
||||
Addr: []string{
|
||||
net.JoinHostPort(config.Host, strconv.Itoa(config.Port)),
|
||||
},
|
||||
@@ -125,6 +128,46 @@ func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig
|
||||
return opts
|
||||
}
|
||||
|
||||
func detectClickHouseProtocol(config connection.ConnectionConfig) clickhouse.Protocol {
|
||||
uriText := strings.ToLower(strings.TrimSpace(config.URI))
|
||||
if strings.HasPrefix(uriText, "http://") || strings.HasPrefix(uriText, "https://") {
|
||||
return clickhouse.HTTP
|
||||
}
|
||||
if config.Port == 8123 || config.Port == 8443 {
|
||||
return clickhouse.HTTP
|
||||
}
|
||||
return clickhouse.Native
|
||||
}
|
||||
|
||||
func isClickHouseProtocolMismatch(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
text := strings.ToLower(strings.TrimSpace(err.Error()))
|
||||
if text == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(text, "unexpected packet [72]") ||
|
||||
(strings.Contains(text, "unexpected packet") && strings.Contains(text, "handshake")) ||
|
||||
strings.Contains(text, "http response to https client") ||
|
||||
strings.Contains(text, "malformed http response")
|
||||
}
|
||||
|
||||
func withClickHouseProtocol(config connection.ConnectionConfig, protocol clickhouse.Protocol) connection.ConnectionConfig {
|
||||
next := config
|
||||
switch protocol {
|
||||
case clickhouse.HTTP:
|
||||
if next.Port == 0 {
|
||||
next.Port = 8123
|
||||
}
|
||||
default:
|
||||
if next.Port == 0 {
|
||||
next.Port = defaultClickHousePort
|
||||
}
|
||||
}
|
||||
return next
|
||||
}
|
||||
|
||||
func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error {
|
||||
if supported, reason := DriverRuntimeSupportStatus("clickhouse"); !supported {
|
||||
if strings.TrimSpace(reason) == "" {
|
||||
@@ -176,23 +219,41 @@ func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error {
|
||||
|
||||
var failures []string
|
||||
for idx, attempt := range attempts {
|
||||
c.conn = clickhouse.OpenDB(c.buildClickHouseOptions(attempt))
|
||||
if err := c.Ping(); err != nil {
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
|
||||
if c.conn != nil {
|
||||
_ = c.conn.Close()
|
||||
c.conn = nil
|
||||
primaryProtocol := detectClickHouseProtocol(attempt)
|
||||
protocols := []clickhouse.Protocol{primaryProtocol}
|
||||
if primaryProtocol == clickhouse.Native {
|
||||
protocols = append(protocols, clickhouse.HTTP)
|
||||
} else {
|
||||
protocols = append(protocols, clickhouse.Native)
|
||||
}
|
||||
|
||||
for pIdx, protocol := range protocols {
|
||||
protocolConfig := withClickHouseProtocol(attempt, protocol)
|
||||
c.conn = clickhouse.OpenDB(c.buildClickHouseOptions(protocolConfig))
|
||||
if err := c.Ping(); err != nil {
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接验证失败(protocol=%s): %v", idx+1, protocol.String(), err))
|
||||
if c.conn != nil {
|
||||
_ = c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
if pIdx == 0 && !isClickHouseProtocolMismatch(err) {
|
||||
// 首次连接不是协议误配特征,避免无谓重试次协议。
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
continue
|
||||
if idx > 0 {
|
||||
logger.Warnf("ClickHouse SSL 优先连接失败,已回退至明文连接")
|
||||
}
|
||||
if pIdx > 0 {
|
||||
logger.Warnf("ClickHouse 已自动切换连接协议为 %s(常见于 8123/8443 HTTP 端口)", protocol.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if idx > 0 {
|
||||
logger.Warnf("ClickHouse SSL 优先连接失败,已回退至明文连接")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
_ = c.Close()
|
||||
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";"))
|
||||
return fmt.Errorf("连接建立后验证失败(可检查 ClickHouse 端口与协议是否匹配:Native=9000/9440,HTTP=8123/8443):%s", strings.Join(failures, ";"))
|
||||
}
|
||||
|
||||
func (c *ClickHouseDB) Close() error {
|
||||
@@ -210,7 +271,7 @@ func (c *ClickHouseDB) Close() error {
|
||||
|
||||
func (c *ClickHouseDB) Ping() error {
|
||||
if c.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
timeout := c.pingTimeout
|
||||
if timeout <= 0 {
|
||||
@@ -223,7 +284,7 @@ func (c *ClickHouseDB) Ping() error {
|
||||
|
||||
func (c *ClickHouseDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if c.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
rows, err := c.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
@@ -235,7 +296,7 @@ func (c *ClickHouseDB) QueryContext(ctx context.Context, query string) ([]map[st
|
||||
|
||||
func (c *ClickHouseDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if c.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
rows, err := c.conn.Query(query)
|
||||
if err != nil {
|
||||
@@ -247,7 +308,7 @@ func (c *ClickHouseDB) Query(query string) ([]map[string]interface{}, []string,
|
||||
|
||||
func (c *ClickHouseDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if c.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := c.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
@@ -258,7 +319,7 @@ func (c *ClickHouseDB) ExecContext(ctx context.Context, query string) (int64, er
|
||||
|
||||
func (c *ClickHouseDB) Exec(query string) (int64, error) {
|
||||
if c.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := c.conn.Exec(query)
|
||||
if err != nil {
|
||||
@@ -343,7 +404,7 @@ func (c *ClickHouseDB) GetCreateStatement(dbName, tableName string) (string, err
|
||||
return "", err
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return "", fmt.Errorf("create statement not found")
|
||||
return "", fmt.Errorf("未找到建表语句")
|
||||
}
|
||||
row := data[0]
|
||||
if val, ok := getClickHouseValueFromRow(row, "statement", "create_statement", "sql", "query"); ok {
|
||||
@@ -366,7 +427,7 @@ func (c *ClickHouseDB) GetCreateStatement(dbName, tableName string) (string, err
|
||||
if longest != "" {
|
||||
return longest, nil
|
||||
}
|
||||
return "", fmt.Errorf("create statement not found")
|
||||
return "", fmt.Errorf("未找到建表语句")
|
||||
}
|
||||
|
||||
func (c *ClickHouseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
@@ -521,7 +582,7 @@ func (c *ClickHouseDB) GetTriggers(dbName, tableName string) ([]connection.Trigg
|
||||
func (c *ClickHouseDB) resolveDatabaseAndTable(dbName, tableName string) (string, string, error) {
|
||||
rawTable := strings.TrimSpace(tableName)
|
||||
if rawTable == "" {
|
||||
return "", "", fmt.Errorf("table name required")
|
||||
return "", "", fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
resolvedDB := strings.TrimSpace(dbName)
|
||||
@@ -542,7 +603,7 @@ func (c *ClickHouseDB) resolveDatabaseAndTable(dbName, tableName string) (string
|
||||
resolvedDB = defaultClickHouseDatabase
|
||||
}
|
||||
if resolvedTable == "" {
|
||||
return "", "", fmt.Errorf("table name required")
|
||||
return "", "", fmt.Errorf("表名不能为空")
|
||||
}
|
||||
return resolvedDB, resolvedTable, nil
|
||||
}
|
||||
@@ -618,3 +679,134 @@ func isClickHouseTruthy(value interface{}) bool {
|
||||
return normalized == "1" || normalized == "true" || normalized == "yes" || normalized == "y"
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClickHouseDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
if c.conn == nil {
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
database, table, err := c.resolveDatabaseAndTable(c.database, tableName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
qualifiedTable := fmt.Sprintf("%s.%s", quoteClickHouseIdentifier(database), quoteClickHouseIdentifier(table))
|
||||
|
||||
for _, pk := range changes.Deletes {
|
||||
whereExpr := buildClickHouseWhereClause(pk)
|
||||
if whereExpr == "" {
|
||||
continue
|
||||
}
|
||||
query := fmt.Sprintf("ALTER TABLE %s DELETE WHERE %s", qualifiedTable, whereExpr)
|
||||
if _, err := c.conn.Exec(query); err != nil {
|
||||
return fmt.Errorf("delete error: %v; sql=%s", err, query)
|
||||
}
|
||||
}
|
||||
|
||||
for _, update := range changes.Updates {
|
||||
setExpr := buildClickHouseAssignments(update.Values)
|
||||
whereExpr := buildClickHouseWhereClause(update.Keys)
|
||||
if setExpr == "" || whereExpr == "" {
|
||||
continue
|
||||
}
|
||||
query := fmt.Sprintf("ALTER TABLE %s UPDATE %s WHERE %s", qualifiedTable, setExpr, whereExpr)
|
||||
if _, err := c.conn.Exec(query); err != nil {
|
||||
return fmt.Errorf("update error: %v; sql=%s", err, query)
|
||||
}
|
||||
}
|
||||
|
||||
for _, row := range changes.Inserts {
|
||||
query, err := buildClickHouseInsertSQL(qualifiedTable, row)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if query == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := c.conn.Exec(query); err != nil {
|
||||
return fmt.Errorf("插入失败:%v; sql=%s", err, query)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildClickHouseInsertSQL(qualifiedTable string, row map[string]interface{}) (string, error) {
|
||||
if len(row) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
cols := make([]string, 0, len(row))
|
||||
for k := range row {
|
||||
if strings.TrimSpace(k) == "" {
|
||||
continue
|
||||
}
|
||||
cols = append(cols, k)
|
||||
}
|
||||
if len(cols) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
sort.Strings(cols)
|
||||
quotedCols := make([]string, 0, len(cols))
|
||||
values := make([]string, 0, len(cols))
|
||||
for _, col := range cols {
|
||||
quotedCols = append(quotedCols, quoteClickHouseIdentifier(col))
|
||||
values = append(values, clickHouseLiteral(row[col]))
|
||||
}
|
||||
return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(quotedCols, ", "), strings.Join(values, ", ")), nil
|
||||
}
|
||||
|
||||
func buildClickHouseAssignments(values map[string]interface{}) string {
|
||||
if len(values) == 0 {
|
||||
return ""
|
||||
}
|
||||
cols := make([]string, 0, len(values))
|
||||
for k := range values {
|
||||
if strings.TrimSpace(k) == "" {
|
||||
continue
|
||||
}
|
||||
cols = append(cols, k)
|
||||
}
|
||||
sort.Strings(cols)
|
||||
parts := make([]string, 0, len(cols))
|
||||
for _, col := range cols {
|
||||
parts = append(parts, fmt.Sprintf("%s = %s", quoteClickHouseIdentifier(col), clickHouseLiteral(values[col])))
|
||||
}
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
|
||||
func buildClickHouseWhereClause(keys map[string]interface{}) string {
|
||||
if len(keys) == 0 {
|
||||
return ""
|
||||
}
|
||||
cols := make([]string, 0, len(keys))
|
||||
for k := range keys {
|
||||
if strings.TrimSpace(k) == "" {
|
||||
continue
|
||||
}
|
||||
cols = append(cols, k)
|
||||
}
|
||||
sort.Strings(cols)
|
||||
parts := make([]string, 0, len(cols))
|
||||
for _, col := range cols {
|
||||
parts = append(parts, fmt.Sprintf("%s = %s", quoteClickHouseIdentifier(col), clickHouseLiteral(keys[col])))
|
||||
}
|
||||
return strings.Join(parts, " AND ")
|
||||
}
|
||||
|
||||
func clickHouseLiteral(value interface{}) string {
|
||||
switch val := value.(type) {
|
||||
case nil:
|
||||
return "NULL"
|
||||
case bool:
|
||||
if val {
|
||||
return "1"
|
||||
}
|
||||
return "0"
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
|
||||
return fmt.Sprintf("%v", val)
|
||||
case time.Time:
|
||||
return fmt.Sprintf("'%s'", val.Format("2006-01-02 15:04:05"))
|
||||
case []byte:
|
||||
return fmt.Sprintf("'%s'", strings.ReplaceAll(string(val), "'", "''"))
|
||||
default:
|
||||
return fmt.Sprintf("'%s'", strings.ReplaceAll(fmt.Sprintf("%v", val), "'", "''"))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ func (c *CustomDB) Close() error {
|
||||
|
||||
func (c *CustomDB) Ping() error {
|
||||
if c.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
timeout := c.pingTimeout
|
||||
if timeout <= 0 {
|
||||
@@ -60,7 +60,7 @@ func (c *CustomDB) Ping() error {
|
||||
|
||||
func (c *CustomDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if c.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := c.conn.QueryContext(ctx, query)
|
||||
@@ -74,7 +74,7 @@ func (c *CustomDB) QueryContext(ctx context.Context, query string) ([]map[string
|
||||
|
||||
func (c *CustomDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if c.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := c.conn.Query(query)
|
||||
@@ -87,7 +87,7 @@ func (c *CustomDB) Query(query string) ([]map[string]interface{}, []string, erro
|
||||
|
||||
func (c *CustomDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if c.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := c.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
@@ -98,7 +98,7 @@ func (c *CustomDB) ExecContext(ctx context.Context, query string) (int64, error)
|
||||
|
||||
func (c *CustomDB) Exec(query string) (int64, error) {
|
||||
if c.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := c.conn.Exec(query)
|
||||
if err != nil {
|
||||
@@ -249,7 +249,7 @@ func (c *CustomDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDe
|
||||
|
||||
func (c *CustomDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
if c.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
tx, err := c.conn.Begin()
|
||||
@@ -321,7 +321,7 @@ func (c *CustomDB) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
}
|
||||
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("delete error: %v", err)
|
||||
return fmt.Errorf("删除失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -349,12 +349,12 @@ func (c *CustomDB) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
}
|
||||
|
||||
if len(wheres) == 0 {
|
||||
return fmt.Errorf("update requires keys")
|
||||
return fmt.Errorf("更新操作需要主键条件")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("update error: %v", err)
|
||||
return fmt.Errorf("更新失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -378,7 +378,7 @@ func (c *CustomDB) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("insert error: %v", err)
|
||||
return fmt.Errorf("插入失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -143,7 +143,7 @@ func (d *DamengDB) Close() error {
|
||||
|
||||
func (d *DamengDB) Ping() error {
|
||||
if d.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
timeout := d.pingTimeout
|
||||
if timeout <= 0 {
|
||||
@@ -156,7 +156,7 @@ func (d *DamengDB) Ping() error {
|
||||
|
||||
func (d *DamengDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if d.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := d.conn.QueryContext(ctx, query)
|
||||
@@ -170,7 +170,7 @@ func (d *DamengDB) QueryContext(ctx context.Context, query string) ([]map[string
|
||||
|
||||
func (d *DamengDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if d.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := d.conn.Query(query)
|
||||
@@ -183,7 +183,7 @@ func (d *DamengDB) Query(query string) ([]map[string]interface{}, []string, erro
|
||||
|
||||
func (d *DamengDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if d.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := d.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
@@ -194,7 +194,7 @@ func (d *DamengDB) ExecContext(ctx context.Context, query string) (int64, error)
|
||||
|
||||
func (d *DamengDB) Exec(query string) (int64, error) {
|
||||
if d.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := d.conn.Exec(query)
|
||||
if err != nil {
|
||||
@@ -204,22 +204,9 @@ func (d *DamengDB) Exec(query string) (int64, error) {
|
||||
}
|
||||
|
||||
func (d *DamengDB) GetDatabases() ([]string, error) {
|
||||
// DM: List Users/Schemas
|
||||
data, _, err := d.Query("SELECT username FROM dba_users")
|
||||
if err != nil {
|
||||
// Fallback if dba_users not accessible
|
||||
data, _, err = d.Query("SELECT username FROM all_users")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
var dbs []string
|
||||
for _, row := range data {
|
||||
if val, ok := row["USERNAME"]; ok {
|
||||
dbs = append(dbs, fmt.Sprintf("%v", val))
|
||||
}
|
||||
}
|
||||
return dbs, nil
|
||||
// 达梦在本项目中将 schema/owner 作为“数据库”展示口径。
|
||||
// 先查当前 schema / 当前用户,再聚合可见用户与 owner,避免权限受限时返回空列表。
|
||||
return collectDamengDatabaseNames(d.Query)
|
||||
}
|
||||
|
||||
func (d *DamengDB) GetTables(dbName string) ([]string, error) {
|
||||
@@ -273,7 +260,7 @@ func (d *DamengDB) GetCreateStatement(dbName, tableName string) (string, error)
|
||||
return fmt.Sprintf("%v", val), nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("create statement not found")
|
||||
return "", fmt.Errorf("未找到建表语句")
|
||||
}
|
||||
|
||||
func (d *DamengDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
@@ -403,7 +390,7 @@ func (d *DamengDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDe
|
||||
|
||||
func (d *DamengDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
if d.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
tx, err := d.conn.Begin()
|
||||
@@ -451,7 +438,7 @@ func (d *DamengDB) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
}
|
||||
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("delete error: %v", err)
|
||||
return fmt.Errorf("删除失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -479,12 +466,12 @@ func (d *DamengDB) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
}
|
||||
|
||||
if len(wheres) == 0 {
|
||||
return fmt.Errorf("update requires keys")
|
||||
return fmt.Errorf("更新操作需要主键条件")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("update error: %v", err)
|
||||
return fmt.Errorf("更新失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -508,7 +495,7 @@ func (d *DamengDB) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("insert error: %v", err)
|
||||
return fmt.Errorf("插入失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
91
internal/db/dameng_metadata.go
Normal file
91
internal/db/dameng_metadata.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var damengDatabaseQueries = []string{
|
||||
"SELECT SYS_CONTEXT('USERENV', 'CURRENT_SCHEMA') AS DATABASE_NAME FROM DUAL",
|
||||
"SELECT SYS_CONTEXT('USERENV', 'CURRENT_USER') AS DATABASE_NAME FROM DUAL",
|
||||
"SELECT USERNAME AS DATABASE_NAME FROM USER_USERS",
|
||||
"SELECT USERNAME AS DATABASE_NAME FROM ALL_USERS ORDER BY USERNAME",
|
||||
"SELECT USERNAME AS DATABASE_NAME FROM DBA_USERS ORDER BY USERNAME",
|
||||
"SELECT USERNAME AS DATABASE_NAME FROM SYS.DBA_USERS ORDER BY USERNAME",
|
||||
"SELECT DISTINCT OWNER AS DATABASE_NAME FROM ALL_OBJECTS ORDER BY OWNER",
|
||||
"SELECT DISTINCT OWNER AS DATABASE_NAME FROM ALL_TABLES ORDER BY OWNER",
|
||||
}
|
||||
|
||||
type damengQueryFunc func(query string) ([]map[string]interface{}, []string, error)
|
||||
|
||||
func collectDamengDatabaseNames(query damengQueryFunc) ([]string, error) {
|
||||
seen := make(map[string]struct{})
|
||||
dbs := make([]string, 0, 64)
|
||||
var lastErr error
|
||||
|
||||
for _, q := range damengDatabaseQueries {
|
||||
data, _, err := query(q)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
for _, row := range data {
|
||||
name := getDamengRowString(row,
|
||||
"DATABASE_NAME",
|
||||
"USERNAME",
|
||||
"OWNER",
|
||||
"SCHEMA_NAME",
|
||||
"CURRENT_SCHEMA",
|
||||
"CURRENT_USER",
|
||||
)
|
||||
if name == "" {
|
||||
for _, v := range row {
|
||||
text := strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
if text == "" || strings.EqualFold(text, "<nil>") {
|
||||
continue
|
||||
}
|
||||
name = text
|
||||
break
|
||||
}
|
||||
}
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
key := strings.ToUpper(name)
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
dbs = append(dbs, name)
|
||||
}
|
||||
}
|
||||
|
||||
if len(dbs) == 0 && lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
sort.Slice(dbs, func(i, j int) bool {
|
||||
return strings.ToUpper(dbs[i]) < strings.ToUpper(dbs[j])
|
||||
})
|
||||
return dbs, nil
|
||||
}
|
||||
|
||||
func getDamengRowString(row map[string]interface{}, keys ...string) string {
|
||||
if len(row) == 0 {
|
||||
return ""
|
||||
}
|
||||
for _, key := range keys {
|
||||
for k, v := range row {
|
||||
if !strings.EqualFold(strings.TrimSpace(k), strings.TrimSpace(key)) {
|
||||
continue
|
||||
}
|
||||
text := strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
if text == "" || strings.EqualFold(text, "<nil>") {
|
||||
return ""
|
||||
}
|
||||
return text
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
73
internal/db/dameng_metadata_test.go
Normal file
73
internal/db/dameng_metadata_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCollectDamengDatabaseNames_UsesCurrentSchemaFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := collectDamengDatabaseNames(func(query string) ([]map[string]interface{}, []string, error) {
|
||||
switch query {
|
||||
case damengDatabaseQueries[0]:
|
||||
return []map[string]interface{}{{"DATABASE_NAME": "APP_SCHEMA"}}, nil, nil
|
||||
case damengDatabaseQueries[1]:
|
||||
return []map[string]interface{}{{"DATABASE_NAME": "app_schema"}}, nil, nil
|
||||
default:
|
||||
return nil, nil, errors.New("permission denied")
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("collectDamengDatabaseNames 返回错误: %v", err)
|
||||
}
|
||||
|
||||
want := []string{"APP_SCHEMA"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("unexpected database names, got=%v want=%v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectDamengDatabaseNames_CollectsOwnersWhenVisible(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := collectDamengDatabaseNames(func(query string) ([]map[string]interface{}, []string, error) {
|
||||
switch query {
|
||||
case damengDatabaseQueries[0], damengDatabaseQueries[1], damengDatabaseQueries[2], damengDatabaseQueries[3], damengDatabaseQueries[4], damengDatabaseQueries[5]:
|
||||
return []map[string]interface{}{}, nil, nil
|
||||
case damengDatabaseQueries[6]:
|
||||
return []map[string]interface{}{{"OWNER": "BIZ"}, {"OWNER": "audit"}}, nil, nil
|
||||
case damengDatabaseQueries[7]:
|
||||
return []map[string]interface{}{{"OWNER": "BIZ"}}, nil, nil
|
||||
default:
|
||||
return nil, nil, nil
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("collectDamengDatabaseNames 返回错误: %v", err)
|
||||
}
|
||||
|
||||
want := []string{"audit", "BIZ"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("unexpected database names, got=%v want=%v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectDamengDatabaseNames_ReturnsErrorWhenNoNameResolved(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expectErr := errors.New("last query failed")
|
||||
got, err := collectDamengDatabaseNames(func(query string) ([]map[string]interface{}, []string, error) {
|
||||
if query == damengDatabaseQueries[len(damengDatabaseQueries)-1] {
|
||||
return nil, nil, expectErr
|
||||
}
|
||||
return nil, nil, errors.New("permission denied")
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("期望返回错误,实际 got=%v", got)
|
||||
}
|
||||
if !errors.Is(err, expectErr) {
|
||||
t.Fatalf("错误不符合预期: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -2,27 +2,58 @@ package db
|
||||
|
||||
import (
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Database 定义了统一的数据源访问接口。
|
||||
// 所有数据库驱动(MySQL、PostgreSQL、Oracle 等)均需实现此接口。
|
||||
// 方法调用方可通过 NewDatabase 工厂函数获取对应驱动的实例。
|
||||
type Database interface {
|
||||
// Connect 根据连接配置建立数据库连接。
|
||||
Connect(config connection.ConnectionConfig) error
|
||||
// Close 关闭数据库连接并释放底层资源。
|
||||
Close() error
|
||||
// Ping 测试连接是否仍然可用。
|
||||
Ping() error
|
||||
// Query 执行查询语句,返回结果行(列名→值映射)和列名列表。
|
||||
Query(query string) ([]map[string]interface{}, []string, error)
|
||||
// Exec 执行非查询语句(INSERT/UPDATE/DELETE 等),返回受影响行数。
|
||||
Exec(query string) (int64, error)
|
||||
// GetDatabases 返回当前连接可访问的数据库列表。
|
||||
GetDatabases() ([]string, error)
|
||||
// GetTables 返回指定数据库下的表列表。
|
||||
GetTables(dbName string) ([]string, error)
|
||||
// GetCreateStatement 返回指定表的建表 DDL 语句。
|
||||
GetCreateStatement(dbName, tableName string) (string, error)
|
||||
// GetColumns 返回指定表的列定义列表。
|
||||
GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error)
|
||||
// GetAllColumns 返回指定数据库下所有表的列定义(含表名标识)。
|
||||
GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error)
|
||||
// GetIndexes 返回指定表的索引定义列表。
|
||||
GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error)
|
||||
// GetForeignKeys 返回指定表的外键定义列表。
|
||||
GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error)
|
||||
// GetTriggers 返回指定表的触发器定义列表。
|
||||
GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error)
|
||||
}
|
||||
|
||||
// MultiResultQuerier 是可选接口,支持多结果集的驱动实现此接口。
|
||||
// 执行可能包含多条 SQL 语句的查询,返回所有结果集。
|
||||
type MultiResultQuerier interface {
|
||||
QueryMulti(query string) ([]connection.ResultSetData, error)
|
||||
}
|
||||
|
||||
// MultiResultQuerierContext 是带 context 的多结果集查询接口。
|
||||
type MultiResultQuerierContext interface {
|
||||
QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error)
|
||||
}
|
||||
|
||||
// BatchApplier 定义了批量变更提交接口。
|
||||
// 支持批量编辑的驱动实现此接口,用于一次性提交前端 DataGrid 中的增删改操作。
|
||||
type BatchApplier interface {
|
||||
// ApplyChanges 将一组变更(新增、修改、删除)批量提交到指定表。
|
||||
ApplyChanges(tableName string, changes connection.ChangeSet) error
|
||||
}
|
||||
|
||||
@@ -72,7 +103,9 @@ func normalizeDatabaseType(dbType string) string {
|
||||
}
|
||||
}
|
||||
|
||||
// Factory
|
||||
// NewDatabase 根据数据库类型创建对应的 Database 实例。
|
||||
// dbType 为数据库类型标识(如 "mysql"、"postgres"、"oracle" 等),大小写不敏感。
|
||||
// 如果指定类型未注册,返回错误。
|
||||
func NewDatabase(dbType string) (Database, error) {
|
||||
normalized := normalizeDatabaseType(dbType)
|
||||
if normalized == "" {
|
||||
@@ -80,7 +113,7 @@ func NewDatabase(dbType string) (Database, error) {
|
||||
}
|
||||
factory, ok := databaseFactories[normalized]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported database type: %s", dbType)
|
||||
return nil, fmt.Errorf("不支持的数据库类型:%s", dbType)
|
||||
}
|
||||
return factory(), nil
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
|
||||
@@ -135,26 +134,26 @@ func collectDirosAddresses(config connection.ConnectionConfig) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
func (d *DirosDB) getDSN(config connection.ConnectionConfig) string {
|
||||
func (d *DirosDB) getDSN(config connection.ConnectionConfig) (string, error) {
|
||||
database := config.Database
|
||||
protocol := "tcp"
|
||||
address := normalizeMySQLAddress(config.Host, config.Port)
|
||||
|
||||
if config.UseSSH {
|
||||
netName, err := ssh.RegisterSSHNetwork(config.SSH)
|
||||
if err == nil {
|
||||
protocol = netName
|
||||
address = normalizeMySQLAddress(config.Host, config.Port)
|
||||
} else {
|
||||
logger.Warnf("注册 Doris SSH 网络失败,将尝试直连:地址=%s:%d 用户=%s,原因:%v", config.Host, config.Port, config.User, err)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||||
}
|
||||
protocol = netName
|
||||
}
|
||||
|
||||
timeout := getConnectTimeoutSeconds(config)
|
||||
tlsMode := resolveMySQLTLSMode(config)
|
||||
|
||||
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s",
|
||||
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode))
|
||||
return fmt.Sprintf(
|
||||
"%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s&multiStatements=true",
|
||||
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode),
|
||||
), nil
|
||||
}
|
||||
|
||||
func resolveDirosCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {
|
||||
@@ -192,7 +191,11 @@ func (d *DirosDB) Connect(config connection.ConnectionConfig) error {
|
||||
candidateConfig.Port = port
|
||||
candidateConfig.User, candidateConfig.Password = resolveDirosCredential(runConfig, index)
|
||||
|
||||
dsn := d.getDSN(candidateConfig)
|
||||
dsn, err := d.getDSN(candidateConfig)
|
||||
if err != nil {
|
||||
errorDetails = append(errorDetails, fmt.Sprintf("%s 生成连接串失败: %v", address, err))
|
||||
continue
|
||||
}
|
||||
db, err := sql.Open(dirosDriverName, dsn)
|
||||
if err != nil {
|
||||
errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败: %v", address, err))
|
||||
|
||||
74
internal/db/driver_agent_binary_check.go
Normal file
74
internal/db/driver_agent_binary_check.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"debug/pe"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
peMachineI386 uint16 = 0x014c
|
||||
peMachineAmd64 uint16 = 0x8664
|
||||
peMachineArm64 uint16 = 0xaa64
|
||||
)
|
||||
|
||||
func windowsMachineLabel(machine uint16) string {
|
||||
switch machine {
|
||||
case peMachineI386:
|
||||
return "windows-386"
|
||||
case peMachineAmd64:
|
||||
return "windows-amd64"
|
||||
case peMachineArm64:
|
||||
return "windows-arm64"
|
||||
default:
|
||||
return fmt.Sprintf("windows-unknown(0x%04x)", machine)
|
||||
}
|
||||
}
|
||||
|
||||
func expectedWindowsMachineForGoArch(goarch string) (uint16, string, bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(goarch)) {
|
||||
case "386":
|
||||
return peMachineI386, "windows-386", true
|
||||
case "amd64":
|
||||
return peMachineAmd64, "windows-amd64", true
|
||||
case "arm64":
|
||||
return peMachineArm64, "windows-arm64", true
|
||||
default:
|
||||
return 0, "", false
|
||||
}
|
||||
}
|
||||
|
||||
func validateWindowsExecutableMachine(pathText string) error {
|
||||
file, err := pe.Open(pathText)
|
||||
if err != nil {
|
||||
return fmt.Errorf("无法识别为有效的 Windows 可执行文件:%w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
expectedMachine, expectedLabel, ok := expectedWindowsMachineForGoArch(runtime.GOARCH)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
actualMachine := file.FileHeader.Machine
|
||||
if actualMachine != expectedMachine {
|
||||
return fmt.Errorf("可执行文件架构不兼容(文件=%s,当前进程=%s)", windowsMachineLabel(actualMachine), expectedLabel)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateOptionalDriverAgentExecutable 校验可选驱动代理二进制是否可在当前进程中执行。
|
||||
// 当前主要用于 Windows 下的 PE 架构兼容性校验,避免升级后复用到错误架构的旧代理。
|
||||
func ValidateOptionalDriverAgentExecutable(driverType string, executablePath string) error {
|
||||
pathText := strings.TrimSpace(executablePath)
|
||||
if pathText == "" {
|
||||
return fmt.Errorf("%s 驱动代理路径为空", driverDisplayName(driverType))
|
||||
}
|
||||
if runtime.GOOS != "windows" {
|
||||
return nil
|
||||
}
|
||||
if err := validateWindowsExecutableMachine(pathText); err != nil {
|
||||
return fmt.Errorf("%s 驱动代理不可用:%w", driverDisplayName(driverType), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// coreBuiltinDrivers 是始终内置可用的核心驱动,无需额外安装即可使用。
|
||||
var coreBuiltinDrivers = map[string]struct{}{
|
||||
"mysql": {},
|
||||
"redis": {},
|
||||
@@ -91,6 +92,8 @@ func driverDisplayName(driverType string) string {
|
||||
}
|
||||
}
|
||||
|
||||
// IsOptionalGoDriver 返回指定驱动类型是否为可选的纯 Go 驱动。
|
||||
// 可选驱动需要用户在驱动管理界面点击“安装启用”后才能使用。
|
||||
func IsOptionalGoDriver(driverType string) bool {
|
||||
_, ok := optionalGoDrivers[normalizeRuntimeDriverType(driverType)]
|
||||
return ok
|
||||
@@ -100,6 +103,7 @@ func IsOptionalGoDriverBuildIncluded(driverType string) bool {
|
||||
return optionalGoDriverBuildIncluded(normalizeRuntimeDriverType(driverType))
|
||||
}
|
||||
|
||||
// IsBuiltinDriver 返回指定驱动类型是否为核心内置驱动(始终可用,无需安装)。
|
||||
func IsBuiltinDriver(driverType string) bool {
|
||||
_, ok := coreBuiltinDrivers[normalizeRuntimeDriverType(driverType)]
|
||||
return ok
|
||||
@@ -146,6 +150,8 @@ func currentExternalDriverDownloadDirectory() string {
|
||||
return defaultExternalDriverDownloadDirectory()
|
||||
}
|
||||
|
||||
// SetExternalDriverDownloadDirectory 设置可选驱动的下载存储目录。
|
||||
// 如果路径解析失败,会回退到默认目录(~/.gonavi/drivers)。
|
||||
func SetExternalDriverDownloadDirectory(downloadDir string) {
|
||||
root, err := resolveExternalDriverRoot(downloadDir)
|
||||
if err != nil {
|
||||
@@ -194,6 +200,9 @@ func optionalGoDriverRuntimeReady(driverType string) (bool, string) {
|
||||
if statErr != nil || info.IsDir() {
|
||||
return false, fmt.Sprintf("%s 驱动代理缺失,请在驱动管理中重新安装启用", driverDisplayName(normalized))
|
||||
}
|
||||
if validateErr := ValidateOptionalDriverAgentExecutable(normalized, executablePath); validateErr != nil {
|
||||
return false, fmt.Sprintf("%s;请在驱动管理中重新安装启用", validateErr.Error())
|
||||
}
|
||||
return true, ""
|
||||
}
|
||||
|
||||
|
||||
@@ -65,11 +65,22 @@ func TestManagedDriverRequiresInstallMarker(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("解析 mariadb 代理路径失败: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(executablePath, []byte("placeholder"), 0o755); err != nil {
|
||||
t.Fatalf("写入 mariadb 代理占位文件失败: %v", err)
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
_ = os.Chmod(executablePath, 0o644)
|
||||
selfPath, selfErr := os.Executable()
|
||||
if selfErr != nil {
|
||||
t.Fatalf("获取测试进程路径失败: %v", selfErr)
|
||||
}
|
||||
content, readErr := os.ReadFile(selfPath)
|
||||
if readErr != nil {
|
||||
t.Fatalf("读取测试进程失败: %v", readErr)
|
||||
}
|
||||
if err := os.WriteFile(executablePath, content, 0o755); err != nil {
|
||||
t.Fatalf("写入 mariadb 代理占位可执行文件失败: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err := os.WriteFile(executablePath, []byte("placeholder"), 0o755); err != nil {
|
||||
t.Fatalf("写入 mariadb 代理占位文件失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
supported, reason := DriverRuntimeSupportStatus("mariadb")
|
||||
|
||||
@@ -55,7 +55,7 @@ func (d *DuckDB) Close() error {
|
||||
|
||||
func (d *DuckDB) Ping() error {
|
||||
if d.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
timeout := d.pingTimeout
|
||||
if timeout <= 0 {
|
||||
@@ -68,7 +68,7 @@ func (d *DuckDB) Ping() error {
|
||||
|
||||
func (d *DuckDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if d.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
rows, err := d.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
@@ -80,7 +80,7 @@ func (d *DuckDB) QueryContext(ctx context.Context, query string) ([]map[string]i
|
||||
|
||||
func (d *DuckDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if d.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
rows, err := d.conn.Query(query)
|
||||
if err != nil {
|
||||
@@ -92,7 +92,7 @@ func (d *DuckDB) Query(query string) ([]map[string]interface{}, []string, error)
|
||||
|
||||
func (d *DuckDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if d.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := d.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
@@ -103,7 +103,7 @@ func (d *DuckDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
|
||||
func (d *DuckDB) Exec(query string) (int64, error) {
|
||||
if d.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := d.conn.Exec(query)
|
||||
if err != nil {
|
||||
@@ -174,7 +174,7 @@ ORDER BY table_schema, table_name`
|
||||
func (d *DuckDB) GetCreateStatement(dbName, tableName string) (string, error) {
|
||||
schema, pureTable := normalizeDuckDBSchemaAndTable(dbName, tableName)
|
||||
if pureTable == "" {
|
||||
return "", fmt.Errorf("table name required")
|
||||
return "", fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
escapedTable := escapeDuckDBLiteral(pureTable)
|
||||
@@ -204,13 +204,13 @@ func (d *DuckDB) GetCreateStatement(dbName, tableName string) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("create statement not found")
|
||||
return "", fmt.Errorf("未找到建表语句")
|
||||
}
|
||||
|
||||
func (d *DuckDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
schema, pureTable := normalizeDuckDBSchemaAndTable(dbName, tableName)
|
||||
if pureTable == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
@@ -303,7 +303,7 @@ func (d *DuckDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefi
|
||||
|
||||
func (d *DuckDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
if d.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
tx, err := d.conn.Begin()
|
||||
@@ -346,7 +346,7 @@ func (d *DuckDB) ApplyChanges(tableName string, changes connection.ChangeSet) er
|
||||
}
|
||||
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("delete error: %v", err)
|
||||
return fmt.Errorf("删除失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -367,12 +367,12 @@ func (d *DuckDB) ApplyChanges(tableName string, changes connection.ChangeSet) er
|
||||
args = append(args, v)
|
||||
}
|
||||
if len(wheres) == 0 {
|
||||
return fmt.Errorf("update requires keys")
|
||||
return fmt.Errorf("更新操作需要主键条件")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("update error: %v", err)
|
||||
return fmt.Errorf("更新失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -392,7 +392,7 @@ func (d *DuckDB) ApplyChanges(tableName string, changes connection.ChangeSet) er
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("insert error: %v", err)
|
||||
return fmt.Errorf("插入失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -125,7 +125,7 @@ func (h *HighGoDB) Close() error {
|
||||
|
||||
func (h *HighGoDB) Ping() error {
|
||||
if h.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
timeout := h.pingTimeout
|
||||
if timeout <= 0 {
|
||||
@@ -138,7 +138,7 @@ func (h *HighGoDB) Ping() error {
|
||||
|
||||
func (h *HighGoDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if h.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := h.conn.QueryContext(ctx, query)
|
||||
@@ -152,7 +152,7 @@ func (h *HighGoDB) QueryContext(ctx context.Context, query string) ([]map[string
|
||||
|
||||
func (h *HighGoDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if h.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := h.conn.Query(query)
|
||||
@@ -165,7 +165,7 @@ func (h *HighGoDB) Query(query string) ([]map[string]interface{}, []string, erro
|
||||
|
||||
func (h *HighGoDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if h.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := h.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
@@ -176,7 +176,7 @@ func (h *HighGoDB) ExecContext(ctx context.Context, query string) (int64, error)
|
||||
|
||||
func (h *HighGoDB) Exec(query string) (int64, error) {
|
||||
if h.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := h.conn.Exec(query)
|
||||
if err != nil {
|
||||
@@ -232,7 +232,7 @@ func (h *HighGoDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefi
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
@@ -302,7 +302,7 @@ func (h *HighGoDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefin
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
@@ -407,7 +407,7 @@ func (h *HighGoDB) GetForeignKeys(dbName, tableName string) ([]connection.Foreig
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
@@ -467,7 +467,7 @@ func (h *HighGoDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDe
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
@@ -531,7 +531,7 @@ ORDER BY table_schema, table_name, ordinal_position`
|
||||
|
||||
func (h *HighGoDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
if h.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
tx, err := h.conn.Begin()
|
||||
@@ -579,7 +579,7 @@ func (h *HighGoDB) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
}
|
||||
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("delete error: %v", err)
|
||||
return fmt.Errorf("删除失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -607,12 +607,12 @@ func (h *HighGoDB) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
}
|
||||
|
||||
if len(wheres) == 0 {
|
||||
return fmt.Errorf("update requires keys")
|
||||
return fmt.Errorf("更新操作需要主键条件")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("update error: %v", err)
|
||||
return fmt.Errorf("更新失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -636,7 +636,7 @@ func (h *HighGoDB) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("insert error: %v", err)
|
||||
return fmt.Errorf("插入失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
206
internal/db/kingbase_identifier_utils.go
Normal file
206
internal/db/kingbase_identifier_utils.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package db
|
||||
|
||||
import "strings"
|
||||
|
||||
func normalizeKingbaseIdentCommon(raw string) string {
|
||||
value := strings.TrimSpace(raw)
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 兼容被多次 JSON 序列化后的转义引号:
|
||||
// \\\"schema\\\" -> \"schema\" -> "schema"
|
||||
for i := 0; i < 8; i++ {
|
||||
next := strings.TrimSpace(value)
|
||||
next = strings.ReplaceAll(next, `\\\"`, `\"`)
|
||||
next = strings.ReplaceAll(next, `\"`, `"`)
|
||||
if next == value {
|
||||
break
|
||||
}
|
||||
value = next
|
||||
}
|
||||
value = strings.TrimSpace(value)
|
||||
|
||||
stripWrapperOnce := func(text string) string {
|
||||
t := strings.TrimSpace(text)
|
||||
if strings.HasPrefix(t, `\`) && len(t) > 1 {
|
||||
t = strings.TrimSpace(strings.TrimPrefix(t, `\`))
|
||||
}
|
||||
if strings.HasSuffix(t, `\`) && len(t) > 1 {
|
||||
t = strings.TrimSpace(strings.TrimSuffix(t, `\`))
|
||||
}
|
||||
if len(t) >= 4 && strings.HasPrefix(t, `\"`) && strings.HasSuffix(t, `\"`) {
|
||||
return strings.TrimSpace(t[2 : len(t)-2])
|
||||
}
|
||||
if len(t) >= 2 && strings.HasPrefix(t, `"`) && strings.HasSuffix(t, `"`) {
|
||||
return strings.TrimSpace(t[1 : len(t)-1])
|
||||
}
|
||||
if len(t) >= 2 && strings.HasPrefix(t, "`") && strings.HasSuffix(t, "`") {
|
||||
return strings.TrimSpace(t[1 : len(t)-1])
|
||||
}
|
||||
if len(t) >= 2 && strings.HasPrefix(t, "[") && strings.HasSuffix(t, "]") {
|
||||
return strings.TrimSpace(t[1 : len(t)-1])
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
for i := 0; i < 8; i++ {
|
||||
next := stripWrapperOnce(value)
|
||||
if next == value {
|
||||
break
|
||||
}
|
||||
value = next
|
||||
}
|
||||
value = strings.TrimSpace(value)
|
||||
|
||||
// 兼容错误的二次引用与残留反斜杠。
|
||||
value = strings.ReplaceAll(value, `\"`, `"`)
|
||||
value = strings.ReplaceAll(value, `""`, "")
|
||||
value = strings.TrimSpace(value)
|
||||
|
||||
for i := 0; i < 8; i++ {
|
||||
next := strings.TrimSpace(value)
|
||||
changed := false
|
||||
if strings.HasPrefix(next, `\`) && len(next) > 1 {
|
||||
next = strings.TrimSpace(strings.TrimPrefix(next, `\`))
|
||||
changed = true
|
||||
}
|
||||
if strings.HasSuffix(next, `\`) && len(next) > 1 {
|
||||
next = strings.TrimSpace(strings.TrimSuffix(next, `\`))
|
||||
changed = true
|
||||
}
|
||||
if !changed || next == value {
|
||||
break
|
||||
}
|
||||
value = next
|
||||
}
|
||||
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
|
||||
func splitKingbaseQualifiedNameCommon(raw string) (schema string, table string) {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
sep := findKingbaseQualifiedSeparator(text)
|
||||
if sep < 0 {
|
||||
return "", normalizeKingbaseIdentCommon(text)
|
||||
}
|
||||
|
||||
schemaPart := normalizeKingbaseIdentCommon(text[:sep])
|
||||
tablePart := normalizeKingbaseIdentCommon(text[sep+1:])
|
||||
|
||||
if tablePart == "" {
|
||||
if schemaPart == "" {
|
||||
return "", normalizeKingbaseIdentCommon(text)
|
||||
}
|
||||
return "", schemaPart
|
||||
}
|
||||
if schemaPart == "" {
|
||||
return "", tablePart
|
||||
}
|
||||
return schemaPart, tablePart
|
||||
}
|
||||
|
||||
func findKingbaseQualifiedSeparator(raw string) int {
|
||||
inDouble := false
|
||||
inBacktick := false
|
||||
inBracket := false
|
||||
escaped := false
|
||||
|
||||
for i := 0; i < len(raw); i++ {
|
||||
ch := raw[i]
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
|
||||
if inDouble {
|
||||
if ch == '"' {
|
||||
// SQL 双引号转义:"" 代表字面量 "
|
||||
if i+1 < len(raw) && raw[i+1] == '"' {
|
||||
i++
|
||||
continue
|
||||
}
|
||||
inDouble = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if inBacktick {
|
||||
if ch == '`' {
|
||||
inBacktick = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if inBracket {
|
||||
if ch == ']' {
|
||||
inBracket = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch ch {
|
||||
case '"':
|
||||
inDouble = true
|
||||
case '`':
|
||||
inBacktick = true
|
||||
case '[':
|
||||
inBracket = true
|
||||
case '.':
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
// buildKingbaseSearchPathCommon 统一构建 Kingbase search_path。
|
||||
// 返回 search_path SQL 片段和规范化后的 schema 列表(用于调试/扩展)。
|
||||
func buildKingbaseSearchPathCommon(rawSchemas []string) (string, []string) {
|
||||
if len(rawSchemas) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{}, len(rawSchemas)+1)
|
||||
quotedParts := make([]string, 0, len(rawSchemas)+1)
|
||||
normalizedSchemas := make([]string, 0, len(rawSchemas)+1)
|
||||
|
||||
appendSchema := func(raw string) {
|
||||
cleaned := normalizeKingbaseIdentCommon(raw)
|
||||
if cleaned == "" {
|
||||
return
|
||||
}
|
||||
if strings.EqualFold(cleaned, "public") {
|
||||
cleaned = "public"
|
||||
}
|
||||
key := strings.ToLower(cleaned)
|
||||
if _, ok := seen[key]; ok {
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
normalizedSchemas = append(normalizedSchemas, cleaned)
|
||||
escaped := strings.ReplaceAll(cleaned, `"`, `""`)
|
||||
quotedParts = append(quotedParts, `"`+escaped+`"`)
|
||||
}
|
||||
|
||||
for _, raw := range rawSchemas {
|
||||
appendSchema(raw)
|
||||
}
|
||||
if _, ok := seen["public"]; !ok {
|
||||
appendSchema("public")
|
||||
}
|
||||
|
||||
if len(quotedParts) == 0 {
|
||||
return "", normalizedSchemas
|
||||
}
|
||||
return strings.Join(quotedParts, ", "), normalizedSchemas
|
||||
}
|
||||
92
internal/db/kingbase_identifier_utils_test.go
Normal file
92
internal/db/kingbase_identifier_utils_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package db
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeKingbaseIdentCommon(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{name: "plain", in: "ldf_server", want: "ldf_server"},
|
||||
{name: "quoted", in: `"ldf_server"`, want: "ldf_server"},
|
||||
{name: "escaped quoted", in: `\"ldf_server\"`, want: "ldf_server"},
|
||||
{name: "double escaped quoted", in: `\\\"ldf_server\\\"`, want: "ldf_server"},
|
||||
{name: "double quoted", in: `""ldf_server""`, want: "ldf_server"},
|
||||
{name: "backtick quoted", in: "`ldf_server`", want: "ldf_server"},
|
||||
{name: "bracket quoted", in: "[ldf_server]", want: "ldf_server"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := normalizeKingbaseIdentCommon(tt.in); got != tt.want {
|
||||
t.Fatalf("normalizeKingbaseIdentCommon(%q)=%q,want=%q", tt.in, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitKingbaseQualifiedNameCommon(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
wantSchema string
|
||||
wantTable string
|
||||
}{
|
||||
{name: "plain", in: "ldf_server.andon_events", wantSchema: "ldf_server", wantTable: "andon_events"},
|
||||
{name: "quoted", in: `"ldf_server"."andon_events"`, wantSchema: "ldf_server", wantTable: "andon_events"},
|
||||
{name: "escaped quoted", in: `\"ldf_server\".\"andon_events\"`, wantSchema: "ldf_server", wantTable: "andon_events"},
|
||||
{name: "double escaped quoted", in: `\\\"ldf_server\\\".\\\"andon_events\\\"`, wantSchema: "ldf_server", wantTable: "andon_events"},
|
||||
{name: "space around dot", in: ` "ldf_server" . "andon_events" `, wantSchema: "ldf_server", wantTable: "andon_events"},
|
||||
{name: "table only", in: "andon_events", wantSchema: "", wantTable: "andon_events"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotSchema, gotTable := splitKingbaseQualifiedNameCommon(tt.in)
|
||||
if gotSchema != tt.wantSchema || gotTable != tt.wantTable {
|
||||
t.Fatalf("splitKingbaseQualifiedNameCommon(%q)=(%q,%q),want=(%q,%q)", tt.in, gotSchema, gotTable, tt.wantSchema, tt.wantTable)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildKingbaseSearchPathCommon(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in []string
|
||||
want string
|
||||
wantLen int
|
||||
}{
|
||||
{
|
||||
name: "normal schemas",
|
||||
in: []string{"ldf_server", "public"},
|
||||
want: `"ldf_server", "public"`,
|
||||
wantLen: 2,
|
||||
},
|
||||
{
|
||||
name: "quoted and escaped schemas should not be double quoted",
|
||||
in: []string{`"ldf_server"`, `""bcs_barcode""`, `\"public\"`},
|
||||
want: `"ldf_server", "bcs_barcode", "public"`,
|
||||
wantLen: 3,
|
||||
},
|
||||
{
|
||||
name: "dedupe ignoring case and keep public fallback",
|
||||
in: []string{"LDF_SERVER", "ldf_server", "PUBLIC"},
|
||||
want: `"LDF_SERVER", "public"`,
|
||||
wantLen: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, parts := buildKingbaseSearchPathCommon(tt.in)
|
||||
if got != tt.want {
|
||||
t.Fatalf("buildKingbaseSearchPathCommon(%v)=%q,want=%q", tt.in, got, tt.want)
|
||||
}
|
||||
if len(parts) != tt.wantLen {
|
||||
t.Fatalf("buildKingbaseSearchPathCommon(%v) parts=%v, len=%d, wantLen=%d", tt.in, parts, len(parts), tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -136,11 +137,83 @@ func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
|
||||
if idx > 0 {
|
||||
logger.Warnf("人大金仓 SSL 优先连接失败,已回退至明文连接")
|
||||
}
|
||||
|
||||
// 获取 schema 列表以重构带有 search_path 的连接池
|
||||
searchPathStr := k.getSearchPathStr()
|
||||
if searchPathStr != "" {
|
||||
// 将 search_path 参数拼入 DSN
|
||||
finalDSN := dsn + " search_path=" + quoteConnValue(searchPathStr)
|
||||
if finalDB, err := sql.Open("kingbase", finalDSN); err == nil {
|
||||
k.pingTimeout = getConnectTimeout(attempt)
|
||||
finalDB.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
// 临时将 k.conn 指向 finalDB 来做 ping 测试
|
||||
oldConn := k.conn
|
||||
k.conn = finalDB
|
||||
if err := k.Ping(); err == nil {
|
||||
// 成功使用带 search_path 的连接池
|
||||
_ = oldConn.Close()
|
||||
logger.Infof("人大金仓已配置连接级 search_path:%s", searchPathStr)
|
||||
} else {
|
||||
_ = finalDB.Close()
|
||||
k.conn = oldConn
|
||||
}
|
||||
}
|
||||
}
|
||||
if searchPathStr != "" {
|
||||
timeout := k.pingTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||||
defer cancel()
|
||||
if _, err := k.conn.ExecContext(ctx, fmt.Sprintf("SET search_path TO %s", searchPathStr)); err != nil {
|
||||
logger.Warnf("人大金仓显式设置 search_path 失败:%v", err)
|
||||
} else {
|
||||
logger.Infof("人大金仓已设置默认 search_path:%s", searchPathStr)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";"))
|
||||
}
|
||||
|
||||
// getSearchPathStr 查询当前数据库中所有用户 schema,配置 DSN 的 search_path。
|
||||
// KingBase 默认 search_path 为 "$user", public,对于自定义 schema 下的表不可见。
|
||||
func (k *KingbaseDB) getSearchPathStr() string {
|
||||
if k.conn == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
query := `SELECT nspname FROM pg_namespace
|
||||
WHERE nspname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND nspname NOT LIKE 'pg_%'
|
||||
ORDER BY nspname`
|
||||
|
||||
rows, err := k.conn.Query(query)
|
||||
if err != nil {
|
||||
logger.Warnf("人大金仓查询用户 schema 失败,跳过 search_path 设置:%v", err)
|
||||
return ""
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var rawSchemas []string
|
||||
for rows.Next() {
|
||||
var name string
|
||||
if err := rows.Scan(&name); err != nil {
|
||||
continue
|
||||
}
|
||||
name = strings.TrimSpace(name)
|
||||
if name != "" {
|
||||
rawSchemas = append(rawSchemas, name)
|
||||
}
|
||||
}
|
||||
|
||||
searchPath, _ := buildKingbaseSearchPathCommon(rawSchemas)
|
||||
return searchPath
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) Close() error {
|
||||
// Close SSH forwarder first if exists
|
||||
if k.forwarder != nil {
|
||||
@@ -159,7 +232,7 @@ func (k *KingbaseDB) Close() error {
|
||||
|
||||
func (k *KingbaseDB) Ping() error {
|
||||
if k.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
timeout := k.pingTimeout
|
||||
if timeout <= 0 {
|
||||
@@ -172,7 +245,7 @@ func (k *KingbaseDB) Ping() error {
|
||||
|
||||
func (k *KingbaseDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if k.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := k.conn.QueryContext(ctx, query)
|
||||
@@ -186,7 +259,7 @@ func (k *KingbaseDB) QueryContext(ctx context.Context, query string) ([]map[stri
|
||||
|
||||
func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if k.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := k.conn.Query(query)
|
||||
@@ -199,7 +272,7 @@ func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, er
|
||||
|
||||
func (k *KingbaseDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if k.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := k.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
@@ -210,7 +283,7 @@ func (k *KingbaseDB) ExecContext(ctx context.Context, query string) (int64, erro
|
||||
|
||||
func (k *KingbaseDB) Exec(query string) (int64, error) {
|
||||
if k.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := k.conn.Exec(query)
|
||||
if err != nil {
|
||||
@@ -294,7 +367,7 @@ func (k *KingbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDe
|
||||
}
|
||||
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
// 转义函数:处理单引号,移除双引号
|
||||
@@ -305,10 +378,30 @@ func (k *KingbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDe
|
||||
return strings.ReplaceAll(s, "'", "''")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = '%s' AND table_name = '%s'
|
||||
ORDER BY ordinal_position`, esc(schema), esc(table))
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
a.attname AS column_name,
|
||||
pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
|
||||
CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
|
||||
pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
|
||||
col_description(a.attrelid, a.attnum) AS comment,
|
||||
CASE WHEN pk.attname IS NOT NULL THEN 'PRI' ELSE '' END AS column_key
|
||||
FROM pg_class c
|
||||
JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||
JOIN pg_attribute a ON a.attrelid = c.oid
|
||||
LEFT JOIN pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum
|
||||
LEFT JOIN (
|
||||
SELECT i.indrelid, a3.attname
|
||||
FROM pg_index i
|
||||
JOIN pg_attribute a3 ON a3.attrelid = i.indrelid AND a3.attnum = ANY(i.indkey)
|
||||
WHERE i.indisprimary
|
||||
) pk ON pk.indrelid = c.oid AND pk.attname = a.attname
|
||||
WHERE c.relkind IN ('r', 'p')
|
||||
AND n.nspname = '%s'
|
||||
AND c.relname = '%s'
|
||||
AND a.attnum > 0
|
||||
AND NOT a.attisdropped
|
||||
ORDER BY a.attnum`, esc(schema), esc(table))
|
||||
|
||||
data, _, err := k.Query(query)
|
||||
if err != nil {
|
||||
@@ -321,11 +414,21 @@ func (k *KingbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDe
|
||||
Name: fmt.Sprintf("%v", row["column_name"]),
|
||||
Type: fmt.Sprintf("%v", row["data_type"]),
|
||||
Nullable: fmt.Sprintf("%v", row["is_nullable"]),
|
||||
Key: fmt.Sprintf("%v", row["column_key"]),
|
||||
Extra: "",
|
||||
Comment: "",
|
||||
}
|
||||
|
||||
if row["column_default"] != nil {
|
||||
def := fmt.Sprintf("%v", row["column_default"])
|
||||
col.Default = &def
|
||||
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(def)), "nextval(") {
|
||||
col.Extra = "auto_increment"
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := row["comment"]; ok && v != nil {
|
||||
col.Comment = fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
columns = append(columns, col)
|
||||
@@ -337,7 +440,7 @@ func (k *KingbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDe
|
||||
func (k *KingbaseDB) getColumnsWithCurrentSchema(tableName string) ([]connection.ColumnDefinition, error) {
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
// 转义函数
|
||||
@@ -347,10 +450,30 @@ func (k *KingbaseDB) getColumnsWithCurrentSchema(tableName string) ([]connection
|
||||
}
|
||||
|
||||
// 使用 current_schema() 获取当前schema
|
||||
query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = current_schema() AND table_name = '%s'
|
||||
ORDER BY ordinal_position`, esc(table))
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
a.attname AS column_name,
|
||||
pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
|
||||
CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
|
||||
pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
|
||||
col_description(a.attrelid, a.attnum) AS comment,
|
||||
CASE WHEN pk.attname IS NOT NULL THEN 'PRI' ELSE '' END AS column_key
|
||||
FROM pg_class c
|
||||
JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||
JOIN pg_attribute a ON a.attrelid = c.oid
|
||||
LEFT JOIN pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum
|
||||
LEFT JOIN (
|
||||
SELECT i.indrelid, a3.attname
|
||||
FROM pg_index i
|
||||
JOIN pg_attribute a3 ON a3.attrelid = i.indrelid AND a3.attnum = ANY(i.indkey)
|
||||
WHERE i.indisprimary
|
||||
) pk ON pk.indrelid = c.oid AND pk.attname = a.attname
|
||||
WHERE c.relkind IN ('r', 'p')
|
||||
AND n.nspname = current_schema()
|
||||
AND c.relname = '%s'
|
||||
AND a.attnum > 0
|
||||
AND NOT a.attisdropped
|
||||
ORDER BY a.attnum`, esc(table))
|
||||
|
||||
data, _, err := k.Query(query)
|
||||
if err != nil {
|
||||
@@ -363,11 +486,21 @@ func (k *KingbaseDB) getColumnsWithCurrentSchema(tableName string) ([]connection
|
||||
Name: fmt.Sprintf("%v", row["column_name"]),
|
||||
Type: fmt.Sprintf("%v", row["data_type"]),
|
||||
Nullable: fmt.Sprintf("%v", row["is_nullable"]),
|
||||
Key: fmt.Sprintf("%v", row["column_key"]),
|
||||
Extra: "",
|
||||
Comment: "",
|
||||
}
|
||||
|
||||
if row["column_default"] != nil {
|
||||
def := fmt.Sprintf("%v", row["column_default"])
|
||||
col.Default = &def
|
||||
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(def)), "nextval(") {
|
||||
col.Extra = "auto_increment"
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := row["comment"]; ok && v != nil {
|
||||
col.Comment = fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
columns = append(columns, col)
|
||||
@@ -391,7 +524,7 @@ func (k *KingbaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDef
|
||||
}
|
||||
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
// 转义函数:处理单引号,移除双引号
|
||||
@@ -489,7 +622,7 @@ func (k *KingbaseDB) GetForeignKeys(dbName, tableName string) ([]connection.Fore
|
||||
}
|
||||
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
// 转义函数:处理单引号,移除双引号
|
||||
@@ -571,7 +704,7 @@ func (k *KingbaseDB) GetTriggers(dbName, tableName string) ([]connection.Trigger
|
||||
}
|
||||
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
// 转义函数:处理单引号,移除双引号
|
||||
@@ -614,7 +747,7 @@ func (k *KingbaseDB) GetTriggers(dbName, tableName string) ([]connection.Trigger
|
||||
|
||||
func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
if k.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
tx, err := k.conn.Begin()
|
||||
@@ -623,28 +756,16 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
quoteIdent := func(name string) string {
|
||||
n := strings.TrimSpace(name)
|
||||
n = strings.Trim(n, "\"")
|
||||
n = strings.ReplaceAll(n, "\"", "\"\"")
|
||||
if n == "" {
|
||||
return "\"\""
|
||||
}
|
||||
return `"` + n + `"`
|
||||
}
|
||||
|
||||
schema := ""
|
||||
table := strings.TrimSpace(tableName)
|
||||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||||
schema = strings.TrimSpace(parts[0])
|
||||
table = strings.TrimSpace(parts[1])
|
||||
schema, table := splitKingbaseQualifiedTable(tableName)
|
||||
if table == "" {
|
||||
return fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
qualifiedTable := ""
|
||||
if schema != "" {
|
||||
qualifiedTable = fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table))
|
||||
qualifiedTable = fmt.Sprintf("%s.%s", quoteKingbaseIdent(schema), quoteKingbaseIdent(table))
|
||||
} else {
|
||||
qualifiedTable = quoteIdent(table)
|
||||
qualifiedTable = quoteKingbaseIdent(table)
|
||||
}
|
||||
|
||||
// 1. Deletes
|
||||
@@ -654,7 +775,7 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
idx := 0
|
||||
for k, v := range pk {
|
||||
idx++
|
||||
wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteIdent(k), idx))
|
||||
wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx))
|
||||
args = append(args, v)
|
||||
}
|
||||
if len(wheres) == 0 {
|
||||
@@ -662,7 +783,7 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
}
|
||||
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("delete error: %v", err)
|
||||
return fmt.Errorf("delete error: %v; sql=%s", err, query)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -674,7 +795,7 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
|
||||
for k, v := range update.Values {
|
||||
idx++
|
||||
sets = append(sets, fmt.Sprintf("%s = $%d", quoteIdent(k), idx))
|
||||
sets = append(sets, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx))
|
||||
args = append(args, v)
|
||||
}
|
||||
|
||||
@@ -685,17 +806,17 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
var wheres []string
|
||||
for k, v := range update.Keys {
|
||||
idx++
|
||||
wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteIdent(k), idx))
|
||||
wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx))
|
||||
args = append(args, v)
|
||||
}
|
||||
|
||||
if len(wheres) == 0 {
|
||||
return fmt.Errorf("update requires keys")
|
||||
return fmt.Errorf("更新操作需要主键条件")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("update error: %v", err)
|
||||
return fmt.Errorf("update error: %v; sql=%s", err, query)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -708,7 +829,7 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
|
||||
for k, v := range row {
|
||||
idx++
|
||||
cols = append(cols, quoteIdent(k))
|
||||
cols = append(cols, quoteKingbaseIdent(k))
|
||||
placeholders = append(placeholders, fmt.Sprintf("$%d", idx))
|
||||
args = append(args, v)
|
||||
}
|
||||
@@ -719,13 +840,73 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("insert error: %v", err)
|
||||
return fmt.Errorf("插入失败:%v; sql=%s", err, query)
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func normalizeKingbaseIdentifier(raw string) string {
|
||||
return normalizeKingbaseIdentCommon(raw)
|
||||
}
|
||||
|
||||
// kingbaseIdentNeedsQuote 判断标识符是否需要双引号包裹。
|
||||
// 与前端 sql.ts 中 needsQuote 逻辑保持一致。
|
||||
func kingbaseIdentNeedsQuote(ident string) bool {
|
||||
if ident == "" {
|
||||
return false
|
||||
}
|
||||
// 不是合法裸标识符格式(必须以字母或下划线开头,仅含字母、数字、下划线)
|
||||
if matched, _ := regexp.MatchString(`^[a-zA-Z_][a-zA-Z0-9_]*$`, ident); !matched {
|
||||
return true
|
||||
}
|
||||
// 包含大写字母时需要引号保护(KingbaseES/PostgreSQL 默认将未加引号的标识符折叠为小写)
|
||||
for _, r := range ident {
|
||||
if r >= 'A' && r <= 'Z' {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// 是 SQL 保留字
|
||||
return isKingbaseReservedWord(ident)
|
||||
}
|
||||
|
||||
// isKingbaseReservedWord 检查是否为常见 SQL 保留字(简化版,与前端保持一致)。
|
||||
func isKingbaseReservedWord(ident string) bool {
|
||||
switch strings.ToLower(ident) {
|
||||
case "select", "from", "where", "table", "index", "user", "order", "group", "by",
|
||||
"limit", "offset", "and", "or", "not", "null", "true", "false", "key",
|
||||
"primary", "foreign", "references", "default", "constraint",
|
||||
"create", "drop", "alter", "insert", "update", "delete", "set", "values", "into",
|
||||
"join", "left", "right", "inner", "outer", "on", "as", "is", "in", "like",
|
||||
"between", "case", "when", "then", "else", "end", "having", "distinct",
|
||||
"all", "any", "exists", "union", "except", "intersect",
|
||||
"column", "check", "unique", "with", "grant", "revoke", "trigger",
|
||||
"begin", "commit", "rollback", "schema", "database", "view", "function",
|
||||
"procedure", "sequence", "type", "domain", "role", "session", "current",
|
||||
"authorization", "cross", "full", "natural", "some", "cast", "fetch",
|
||||
"for", "to", "do", "if", "return", "returns", "declare", "cursor", "server", "owner":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func quoteKingbaseIdent(name string) string {
|
||||
n := normalizeKingbaseIdentifier(name)
|
||||
if n == "" {
|
||||
return "\"\""
|
||||
}
|
||||
if !kingbaseIdentNeedsQuote(n) {
|
||||
return n
|
||||
}
|
||||
n = strings.ReplaceAll(n, `"`, `""`)
|
||||
return `"` + n + `"`
|
||||
}
|
||||
|
||||
func splitKingbaseQualifiedTable(tableName string) (schema string, table string) {
|
||||
return splitKingbaseQualifiedNameCommon(tableName)
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||||
// dbName 在本项目语义里是“数据库”,schema 由 table_schema 决定;这里返回全部用户 schema 的列用于查询提示。
|
||||
query := `
|
||||
|
||||
117
internal/db/kingbase_impl_test.go
Normal file
117
internal/db/kingbase_impl_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
//go:build gonavi_full_drivers || gonavi_kingbase_driver
|
||||
|
||||
package db
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeKingbaseIdentifier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{name: "plain", in: "ldf_server", want: "ldf_server"},
|
||||
{name: "quoted", in: `"ldf_server"`, want: "ldf_server"},
|
||||
{name: "double quoted", in: `""ldf_server""`, want: "ldf_server"},
|
||||
{name: "quad quoted", in: `""""ldf_server""""`, want: "ldf_server"},
|
||||
{name: "escaped quoted", in: `\"ldf_server\"`, want: "ldf_server"},
|
||||
{name: "double escaped quoted", in: `\\\"ldf_server\\\"`, want: "ldf_server"},
|
||||
{name: "backtick quoted", in: "`ldf_server`", want: "ldf_server"},
|
||||
{name: "bracket quoted", in: "[ldf_server]", want: "ldf_server"},
|
||||
{name: "embedded double quotes", in: `ldf""server`, want: "ldfserver"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := normalizeKingbaseIdentifier(tt.in); got != tt.want {
|
||||
t.Fatalf("normalizeKingbaseIdentifier(%q) = %q, want %q", tt.in, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuoteKingbaseIdent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
// 纯小写+下划线:不加引号
|
||||
{name: "plain lowercase", in: "ldf_server", want: "ldf_server"},
|
||||
{name: "plain lowercase 2", in: "bcs_barcode", want: "bcs_barcode"},
|
||||
{name: "double quoted input", in: `""ldf_server""`, want: "ldf_server"},
|
||||
{name: "escaped quoted input", in: `\"ldf_server\"`, want: "ldf_server"},
|
||||
// 含大写字母:加引号
|
||||
{name: "uppercase", in: "LDF_Server", want: `"LDF_Server"`},
|
||||
{name: "mixed case", in: "myTable", want: `"myTable"`},
|
||||
// SQL 保留字:加引号
|
||||
{name: "reserved word order", in: "order", want: `"order"`},
|
||||
{name: "reserved word user", in: "user", want: `"user"`},
|
||||
{name: "reserved word table", in: "table", want: `"table"`},
|
||||
{name: "reserved word select", in: "select", want: `"select"`},
|
||||
// 含特殊字符:加引号
|
||||
{name: "with hyphen", in: "my-table", want: `"my-table"`},
|
||||
{name: "with space", in: "my table", want: `"my table"`},
|
||||
{name: "with embedded quote", in: `ab"cd`, want: `"ab""cd"`},
|
||||
// 空值
|
||||
{name: "empty", in: "", want: `""`},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := quoteKingbaseIdent(tt.in); got != tt.want {
|
||||
t.Fatalf("quoteKingbaseIdent(%q) = %q, want %q", tt.in, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKingbaseIdentNeedsQuote(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want bool
|
||||
}{
|
||||
{name: "plain lowercase", in: "ldf_server", want: false},
|
||||
{name: "starts with underscore", in: "_col", want: false},
|
||||
{name: "with digits", in: "col123", want: false},
|
||||
{name: "uppercase", in: "MyTable", want: true},
|
||||
{name: "reserved word", in: "order", want: true},
|
||||
{name: "with hyphen", in: "my-col", want: true},
|
||||
{name: "starts with digit", in: "123col", want: true},
|
||||
{name: "empty", in: "", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := kingbaseIdentNeedsQuote(tt.in); got != tt.want {
|
||||
t.Fatalf("kingbaseIdentNeedsQuote(%q) = %v, want %v", tt.in, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitKingbaseQualifiedTable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
wantSchema string
|
||||
wantTable string
|
||||
}{
|
||||
{name: "plain qualified", in: "ldf_server.t_user", wantSchema: "ldf_server", wantTable: "t_user"},
|
||||
{name: "double quoted qualified", in: `""ldf_server"".""t_user""`, wantSchema: "ldf_server", wantTable: "t_user"},
|
||||
{name: "escaped qualified", in: `\"ldf_server\".\"t_user\"`, wantSchema: "ldf_server", wantTable: "t_user"},
|
||||
{name: "double escaped qualified", in: `\\\"ldf_server\\\".\\\"t_user\\\"`, wantSchema: "ldf_server", wantTable: "t_user"},
|
||||
{name: "bracket qualified", in: "[ldf_server].[t_user]", wantSchema: "ldf_server", wantTable: "t_user"},
|
||||
{name: "table only", in: `""t_user""`, wantSchema: "", wantTable: "t_user"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotSchema, gotTable := splitKingbaseQualifiedTable(tt.in)
|
||||
if gotSchema != tt.wantSchema || gotTable != tt.wantTable {
|
||||
t.Fatalf("splitKingbaseQualifiedTable(%q) = (%q, %q), want (%q, %q)", tt.in, gotSchema, gotTable, tt.wantSchema, tt.wantTable)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
|
||||
@@ -25,30 +24,33 @@ type MariaDB struct {
|
||||
pingTimeout time.Duration
|
||||
}
|
||||
|
||||
func (m *MariaDB) getDSN(config connection.ConnectionConfig) string {
|
||||
func (m *MariaDB) getDSN(config connection.ConnectionConfig) (string, error) {
|
||||
database := config.Database
|
||||
protocol := "tcp"
|
||||
address := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
|
||||
if config.UseSSH {
|
||||
netName, err := ssh.RegisterSSHNetwork(config.SSH)
|
||||
if err == nil {
|
||||
protocol = netName
|
||||
address = fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
} else {
|
||||
logger.Warnf("注册 SSH 网络失败,将尝试直连:地址=%s:%d 用户=%s,原因:%v", config.Host, config.Port, config.User, err)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||||
}
|
||||
protocol = netName
|
||||
}
|
||||
|
||||
timeout := getConnectTimeoutSeconds(config)
|
||||
tlsMode := resolveMySQLTLSMode(config)
|
||||
|
||||
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s",
|
||||
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode))
|
||||
return fmt.Sprintf(
|
||||
"%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s&multiStatements=true",
|
||||
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode),
|
||||
), nil
|
||||
}
|
||||
|
||||
func (m *MariaDB) Connect(config connection.ConnectionConfig) error {
|
||||
dsn := m.getDSN(config)
|
||||
dsn, err := m.getDSN(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
@@ -71,7 +73,7 @@ func (m *MariaDB) Close() error {
|
||||
|
||||
func (m *MariaDB) Ping() error {
|
||||
if m.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
timeout := m.pingTimeout
|
||||
if timeout <= 0 {
|
||||
@@ -82,9 +84,33 @@ func (m *MariaDB) Ping() error {
|
||||
return m.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (m *MariaDB) QueryMulti(query string) ([]connection.ResultSetData, error) {
|
||||
if m.conn == nil {
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
rows, err := m.conn.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanMultiRows(rows)
|
||||
}
|
||||
|
||||
func (m *MariaDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
|
||||
if m.conn == nil {
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
rows, err := m.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanMultiRows(rows)
|
||||
}
|
||||
|
||||
func (m *MariaDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if m.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := m.conn.QueryContext(ctx, query)
|
||||
@@ -98,7 +124,7 @@ func (m *MariaDB) QueryContext(ctx context.Context, query string) ([]map[string]
|
||||
|
||||
func (m *MariaDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if m.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := m.conn.Query(query)
|
||||
@@ -111,7 +137,7 @@ func (m *MariaDB) Query(query string) ([]map[string]interface{}, []string, error
|
||||
|
||||
func (m *MariaDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if m.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := m.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
@@ -122,7 +148,7 @@ func (m *MariaDB) ExecContext(ctx context.Context, query string) (int64, error)
|
||||
|
||||
func (m *MariaDB) Exec(query string) (int64, error) {
|
||||
if m.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := m.conn.Exec(query)
|
||||
if err != nil {
|
||||
@@ -184,7 +210,7 @@ func (m *MariaDB) GetCreateStatement(dbName, tableName string) (string, error) {
|
||||
return fmt.Sprintf("%v", val), nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("create statement not found")
|
||||
return "", fmt.Errorf("未找到建表语句")
|
||||
}
|
||||
|
||||
func (m *MariaDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
@@ -250,12 +276,22 @@ func (m *MariaDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefini
|
||||
}
|
||||
}
|
||||
|
||||
subPart := 0
|
||||
if val, ok := row["Sub_part"]; ok && val != nil {
|
||||
if f, ok := val.(float64); ok {
|
||||
subPart = int(f)
|
||||
} else if i, ok := val.(int64); ok {
|
||||
subPart = int(i)
|
||||
}
|
||||
}
|
||||
|
||||
idx := connection.IndexDefinition{
|
||||
Name: fmt.Sprintf("%v", row["Key_name"]),
|
||||
ColumnName: fmt.Sprintf("%v", row["Column_name"]),
|
||||
NonUnique: nonUnique,
|
||||
SeqInIndex: seq,
|
||||
IndexType: fmt.Sprintf("%v", row["Index_type"]),
|
||||
SubPart: subPart,
|
||||
}
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
@@ -308,7 +344,7 @@ func (m *MariaDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDef
|
||||
|
||||
func (m *MariaDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
if m.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
tx, err := m.conn.Begin()
|
||||
@@ -323,14 +359,14 @@ func (m *MariaDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
|
||||
var args []interface{}
|
||||
for k, v := range pk {
|
||||
wheres = append(wheres, fmt.Sprintf("`%s` = ?", k))
|
||||
args = append(args, normalizeMySQLDateTimeValue(v))
|
||||
args = append(args, normalizeMySQLComplexValue(normalizeMySQLDateTimeValue(v)))
|
||||
}
|
||||
if len(wheres) == 0 {
|
||||
continue
|
||||
}
|
||||
query := fmt.Sprintf("DELETE FROM `%s` WHERE %s", tableName, strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("delete error: %v", err)
|
||||
return fmt.Errorf("删除失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -341,7 +377,7 @@ func (m *MariaDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
|
||||
|
||||
for k, v := range update.Values {
|
||||
sets = append(sets, fmt.Sprintf("`%s` = ?", k))
|
||||
args = append(args, normalizeMySQLDateTimeValue(v))
|
||||
args = append(args, normalizeMySQLComplexValue(normalizeMySQLDateTimeValue(v)))
|
||||
}
|
||||
|
||||
if len(sets) == 0 {
|
||||
@@ -351,16 +387,16 @@ func (m *MariaDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
|
||||
var wheres []string
|
||||
for k, v := range update.Keys {
|
||||
wheres = append(wheres, fmt.Sprintf("`%s` = ?", k))
|
||||
args = append(args, normalizeMySQLDateTimeValue(v))
|
||||
args = append(args, normalizeMySQLComplexValue(normalizeMySQLDateTimeValue(v)))
|
||||
}
|
||||
|
||||
if len(wheres) == 0 {
|
||||
return fmt.Errorf("update requires keys")
|
||||
return fmt.Errorf("更新操作需要主键条件")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("UPDATE `%s` SET %s WHERE %s", tableName, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("update error: %v", err)
|
||||
return fmt.Errorf("更新失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -373,7 +409,7 @@ func (m *MariaDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
|
||||
for k, v := range row {
|
||||
cols = append(cols, fmt.Sprintf("`%s`", k))
|
||||
placeholders = append(placeholders, "?")
|
||||
args = append(args, normalizeMySQLDateTimeValue(v))
|
||||
args = append(args, normalizeMySQLComplexValue(normalizeMySQLDateTimeValue(v)))
|
||||
}
|
||||
|
||||
if len(cols) == 0 {
|
||||
@@ -382,7 +418,7 @@ func (m *MariaDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO `%s` (%s) VALUES (%s)", tableName, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("insert error: %v", err)
|
||||
return fmt.Errorf("插入失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -392,7 +428,7 @@ func (m *MariaDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
|
||||
func (m *MariaDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||||
query := fmt.Sprintf("SELECT TABLE_NAME, COLUMN_NAME, COLUMN_TYPE FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = '%s'", dbName)
|
||||
if dbName == "" {
|
||||
return nil, fmt.Errorf("database name required for GetAllColumns")
|
||||
return nil, fmt.Errorf("获取全部列信息需要指定数据库名称")
|
||||
}
|
||||
|
||||
data, _, err := m.Query(query)
|
||||
|
||||
@@ -151,10 +151,14 @@ func applyMongoURI(config connection.ConnectionConfig) connection.ConnectionConf
|
||||
}
|
||||
}
|
||||
|
||||
if len(config.Hosts) == 0 && len(hostsFromURI) > 0 {
|
||||
explicitHost := strings.TrimSpace(config.Host) != ""
|
||||
explicitHosts := len(config.Hosts) > 0
|
||||
|
||||
// 显式填写的 host/hosts 优先级高于 URI,避免表单 host 被 URI 中的 localhost 覆盖。
|
||||
if !explicitHost && !explicitHosts && len(hostsFromURI) > 0 {
|
||||
config.Hosts = hostsFromURI
|
||||
}
|
||||
if strings.TrimSpace(config.Host) == "" && len(hostsFromURI) > 0 {
|
||||
if !explicitHost && !explicitHosts && len(hostsFromURI) > 0 {
|
||||
host, port, ok := parseHostPortWithDefault(hostsFromURI[0], defaultPort)
|
||||
if ok {
|
||||
config.Host = host
|
||||
@@ -233,9 +237,6 @@ func (m *MongoDB) getURI(config connection.ConnectionConfig) string {
|
||||
params.Set("serverSelectionTimeoutMS", strconv.Itoa(timeout*1000))
|
||||
|
||||
authSource := strings.TrimSpace(config.AuthSource)
|
||||
if authSource == "" && strings.TrimSpace(config.Database) != "" {
|
||||
authSource = strings.TrimSpace(config.Database)
|
||||
}
|
||||
if authSource == "" {
|
||||
authSource = "admin"
|
||||
}
|
||||
@@ -251,6 +252,11 @@ func (m *MongoDB) getURI(config connection.ConnectionConfig) string {
|
||||
params.Set("authMechanism", authMechanism)
|
||||
}
|
||||
|
||||
// 单机模式且未指定副本集名称时,启用 directConnection 避免驱动自动跟随副本集成员发现
|
||||
if strings.TrimSpace(config.Topology) != "replica" && strings.TrimSpace(config.ReplicaSet) == "" && !config.MongoSRV {
|
||||
params.Set("directConnection", "true")
|
||||
}
|
||||
|
||||
if encoded := params.Encode(); encoded != "" {
|
||||
uri += "?" + encoded
|
||||
}
|
||||
@@ -276,9 +282,44 @@ func buildMongoAuthAttempts(config connection.ConnectionConfig) []connection.Con
|
||||
return attempts
|
||||
}
|
||||
|
||||
func mongoURIForcesTLS(uriText string) bool {
|
||||
trimmed := strings.TrimSpace(uriText)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
parsed, err := url.Parse(trimmed)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
query := parsed.Query()
|
||||
for _, key := range []string{"tls", "ssl"} {
|
||||
value := strings.ToLower(strings.TrimSpace(query.Get(key)))
|
||||
switch value {
|
||||
case "1", "true", "t", "yes", "y", "required":
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func mongoAttemptSSLLabel(config connection.ConnectionConfig, fallbackToPlain bool) string {
|
||||
if fallbackToPlain {
|
||||
return "明文回退"
|
||||
}
|
||||
if mongoURIForcesTLS(config.URI) {
|
||||
return "SSL"
|
||||
}
|
||||
enabled, _ := resolveMongoTLSSettings(config)
|
||||
if enabled {
|
||||
return "SSL"
|
||||
}
|
||||
return "明文"
|
||||
}
|
||||
|
||||
func (m *MongoDB) Connect(config connection.ConnectionConfig) error {
|
||||
runConfig := applyMongoURI(config)
|
||||
connectConfig := runConfig
|
||||
sshRouteHint := ""
|
||||
|
||||
if runConfig.UseSSH && runConfig.MongoSRV {
|
||||
return fmt.Errorf("MongoDB SRV 记录模式暂不支持 SSH 隧道")
|
||||
@@ -319,6 +360,7 @@ func (m *MongoDB) Connect(config connection.ConnectionConfig) error {
|
||||
localConfig.URI = ""
|
||||
localConfig.Hosts = []string{normalizeMongoAddress(host, port)}
|
||||
connectConfig = localConfig
|
||||
sshRouteHint = fmt.Sprintf("SSH隧道 %s -> %s:%d", forwarder.LocalAddr, targetHost, targetPort)
|
||||
logger.Infof("MongoDB 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, targetHost, targetPort)
|
||||
}
|
||||
|
||||
@@ -332,20 +374,32 @@ func (m *MongoDB) Connect(config connection.ConnectionConfig) error {
|
||||
if shouldTrySSLPreferredFallback(connectConfig) {
|
||||
sslAttempts = append(sslAttempts, withSSLDisabled(connectConfig))
|
||||
}
|
||||
totalAttempts := 0
|
||||
for _, attemptConfig := range sslAttempts {
|
||||
totalAttempts += len(buildMongoAuthAttempts(attemptConfig))
|
||||
}
|
||||
attemptNo := 0
|
||||
|
||||
var errorDetails []string
|
||||
for sslIndex, sslConfig := range sslAttempts {
|
||||
sslLabel := "SSL"
|
||||
if sslIndex > 0 {
|
||||
sslLabel = "明文回退"
|
||||
}
|
||||
sslLabel := mongoAttemptSSLLabel(sslConfig, sslIndex > 0)
|
||||
|
||||
attemptConfigs := buildMongoAuthAttempts(sslConfig)
|
||||
for index, attemptConfig := range attemptConfigs {
|
||||
attemptNo++
|
||||
authLabel := "主库凭据"
|
||||
if index > 0 {
|
||||
authLabel = "从库凭据"
|
||||
}
|
||||
targets := collectMongoSeeds(attemptConfig)
|
||||
if len(targets) == 0 {
|
||||
targets = append(targets, normalizeMongoAddress(attemptConfig.Host, attemptConfig.Port))
|
||||
}
|
||||
attemptStarted := time.Now()
|
||||
logger.Infof(
|
||||
"MongoDB 连接尝试:%d/%d 模式=%s 凭据=%s 目标=%s 代理=%t",
|
||||
attemptNo, totalAttempts, sslLabel, authLabel, strings.Join(targets, ","), attemptConfig.UseProxy,
|
||||
)
|
||||
|
||||
if sslIndex > 0 {
|
||||
attemptConfig.URI = ""
|
||||
@@ -364,7 +418,13 @@ func (m *MongoDB) Connect(config connection.ConnectionConfig) error {
|
||||
}
|
||||
client, err := mongo.Connect(clientOpts)
|
||||
if err != nil {
|
||||
errorDetails = append(errorDetails, fmt.Sprintf("%s %s连接失败: %v", sslLabel, authLabel, err))
|
||||
logger.Warnf("MongoDB 连接尝试失败:%d/%d 模式=%s 凭据=%s 耗时=%s 错误=%v",
|
||||
attemptNo, totalAttempts, sslLabel, authLabel, time.Since(attemptStarted).Round(time.Millisecond), err)
|
||||
detail := fmt.Sprintf("%s %s连接失败: %v", sslLabel, authLabel, err)
|
||||
if sshRouteHint != "" {
|
||||
detail = fmt.Sprintf("%s(%s)", detail, sshRouteHint)
|
||||
}
|
||||
errorDetails = append(errorDetails, detail)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -374,9 +434,17 @@ func (m *MongoDB) Connect(config connection.ConnectionConfig) error {
|
||||
_ = client.Disconnect(ctx)
|
||||
cancel()
|
||||
m.client = nil
|
||||
errorDetails = append(errorDetails, fmt.Sprintf("%s %s验证失败: %v", sslLabel, authLabel, err))
|
||||
logger.Warnf("MongoDB 连接尝试验证失败:%d/%d 模式=%s 凭据=%s 耗时=%s 错误=%v",
|
||||
attemptNo, totalAttempts, sslLabel, authLabel, time.Since(attemptStarted).Round(time.Millisecond), err)
|
||||
detail := fmt.Sprintf("%s %s验证失败: %v", sslLabel, authLabel, err)
|
||||
if sshRouteHint != "" {
|
||||
detail = fmt.Sprintf("%s(%s)", detail, sshRouteHint)
|
||||
}
|
||||
errorDetails = append(errorDetails, detail)
|
||||
continue
|
||||
}
|
||||
logger.Infof("MongoDB 连接尝试成功:%d/%d 模式=%s 凭据=%s 耗时=%s",
|
||||
attemptNo, totalAttempts, sslLabel, authLabel, time.Since(attemptStarted).Round(time.Millisecond))
|
||||
if sslIndex > 0 {
|
||||
logger.Warnf("MongoDB SSL 优先连接失败,已回退至明文连接")
|
||||
}
|
||||
@@ -409,7 +477,7 @@ func (m *MongoDB) Close() error {
|
||||
|
||||
func (m *MongoDB) Ping() error {
|
||||
if m.client == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
timeout := m.pingTimeout
|
||||
if timeout <= 0 {
|
||||
@@ -613,7 +681,7 @@ func buildMembersFromHello(raw bson.M) []connection.MongoMemberInfo {
|
||||
|
||||
func (m *MongoDB) DiscoverMembers() (string, []connection.MongoMemberInfo, error) {
|
||||
if m.client == nil {
|
||||
return "", nil, fmt.Errorf("connection not open")
|
||||
return "", nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
timeout := m.pingTimeout
|
||||
@@ -764,7 +832,7 @@ func extractCollectionFromSQL(sql string) string {
|
||||
|
||||
func (m *MongoDB) queryWithContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if m.client == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
query = strings.TrimSpace(query)
|
||||
@@ -1008,7 +1076,7 @@ func (m *MongoDB) ExecContext(ctx context.Context, query string) (int64, error)
|
||||
|
||||
func (m *MongoDB) GetDatabases() ([]string, error) {
|
||||
if m.client == nil {
|
||||
return nil, fmt.Errorf("connection not open")
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -1023,7 +1091,7 @@ func (m *MongoDB) GetDatabases() ([]string, error) {
|
||||
|
||||
func (m *MongoDB) GetTables(dbName string) ([]string, error) {
|
||||
if m.client == nil {
|
||||
return nil, fmt.Errorf("connection not open")
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
targetDB := dbName
|
||||
@@ -1059,7 +1127,7 @@ func (m *MongoDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWit
|
||||
// GetIndexes returns indexes for a MongoDB collection
|
||||
func (m *MongoDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||||
if m.client == nil {
|
||||
return nil, fmt.Errorf("connection not open")
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
targetDB := dbName
|
||||
@@ -1126,7 +1194,7 @@ func (m *MongoDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDef
|
||||
// ApplyChanges implements batch changes for MongoDB
|
||||
func (m *MongoDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
if m.client == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
@@ -1142,7 +1210,7 @@ func (m *MongoDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
|
||||
}
|
||||
if len(filter) > 0 {
|
||||
if _, err := collection.DeleteOne(ctx, filter); err != nil {
|
||||
return fmt.Errorf("delete error: %v", err)
|
||||
return fmt.Errorf("删除失败:%v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1154,7 +1222,7 @@ func (m *MongoDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
|
||||
filter[k] = v
|
||||
}
|
||||
if len(filter) == 0 {
|
||||
return fmt.Errorf("update requires keys")
|
||||
return fmt.Errorf("更新操作需要主键条件")
|
||||
}
|
||||
|
||||
updateDoc := bson.M{"$set": bson.M{}}
|
||||
@@ -1163,7 +1231,7 @@ func (m *MongoDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
|
||||
}
|
||||
|
||||
if _, err := collection.UpdateOne(ctx, filter, updateDoc); err != nil {
|
||||
return fmt.Errorf("update error: %v", err)
|
||||
return fmt.Errorf("更新失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1175,7 +1243,7 @@ func (m *MongoDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
|
||||
}
|
||||
if len(doc) > 0 {
|
||||
if _, err := collection.InsertOne(ctx, doc); err != nil {
|
||||
return fmt.Errorf("insert error: %v", err)
|
||||
return fmt.Errorf("插入失败:%v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
39
internal/db/mongodb_impl_uri_test.go
Normal file
39
internal/db/mongodb_impl_uri_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
//go:build gonavi_full_drivers || gonavi_mongodb_driver
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
func TestApplyMongoURI_ExplicitHostDoesNotAdoptURIHosts(t *testing.T) {
|
||||
config := connection.ConnectionConfig{
|
||||
Host: "10.10.10.10",
|
||||
Port: 27017,
|
||||
URI: "mongodb://localhost:27017/admin",
|
||||
}
|
||||
|
||||
got := applyMongoURI(config)
|
||||
if got.Host != "10.10.10.10" {
|
||||
t.Fatalf("expected host to remain explicit, got %q", got.Host)
|
||||
}
|
||||
if len(got.Hosts) != 0 {
|
||||
t.Fatalf("expected hosts to remain empty when explicit host exists, got %v", got.Hosts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyMongoURI_ExplicitHostsDoesNotAdoptURIHosts(t *testing.T) {
|
||||
config := connection.ConnectionConfig{
|
||||
Host: "10.10.10.10",
|
||||
Port: 27017,
|
||||
Hosts: []string{"10.10.10.10:27017", "10.10.10.11:27017"},
|
||||
URI: "mongodb://localhost:27017,localhost:27018/admin?replicaSet=rs0",
|
||||
}
|
||||
|
||||
got := applyMongoURI(config)
|
||||
if len(got.Hosts) != 2 || got.Hosts[0] != "10.10.10.10:27017" {
|
||||
t.Fatalf("expected explicit hosts to stay untouched, got %v", got.Hosts)
|
||||
}
|
||||
}
|
||||
@@ -152,10 +152,14 @@ func applyMongoURI(config connection.ConnectionConfig) connection.ConnectionConf
|
||||
}
|
||||
}
|
||||
|
||||
if len(config.Hosts) == 0 && len(hostsFromURI) > 0 {
|
||||
explicitHost := strings.TrimSpace(config.Host) != ""
|
||||
explicitHosts := len(config.Hosts) > 0
|
||||
|
||||
// 显式填写的 host/hosts 优先级高于 URI,避免表单 host 被 URI 中的 localhost 覆盖。
|
||||
if !explicitHost && !explicitHosts && len(hostsFromURI) > 0 {
|
||||
config.Hosts = hostsFromURI
|
||||
}
|
||||
if strings.TrimSpace(config.Host) == "" && len(hostsFromURI) > 0 {
|
||||
if !explicitHost && !explicitHosts && len(hostsFromURI) > 0 {
|
||||
host, port, ok := parseHostPortWithDefault(hostsFromURI[0], defaultPort)
|
||||
if ok {
|
||||
config.Host = host
|
||||
@@ -234,9 +238,6 @@ func (m *MongoDBV1) getURI(config connection.ConnectionConfig) string {
|
||||
params.Set("serverSelectionTimeoutMS", strconv.Itoa(timeout*1000))
|
||||
|
||||
authSource := strings.TrimSpace(config.AuthSource)
|
||||
if authSource == "" && strings.TrimSpace(config.Database) != "" {
|
||||
authSource = strings.TrimSpace(config.Database)
|
||||
}
|
||||
if authSource == "" {
|
||||
authSource = "admin"
|
||||
}
|
||||
@@ -252,6 +253,11 @@ func (m *MongoDBV1) getURI(config connection.ConnectionConfig) string {
|
||||
params.Set("authMechanism", authMechanism)
|
||||
}
|
||||
|
||||
// 单机模式且未指定副本集名称时,启用 directConnection 避免驱动自动跟随副本集成员发现
|
||||
if strings.TrimSpace(config.Topology) != "replica" && strings.TrimSpace(config.ReplicaSet) == "" && !config.MongoSRV {
|
||||
params.Set("directConnection", "true")
|
||||
}
|
||||
|
||||
if encoded := params.Encode(); encoded != "" {
|
||||
uri += "?" + encoded
|
||||
}
|
||||
@@ -277,9 +283,44 @@ func buildMongoAuthAttempts(config connection.ConnectionConfig) []connection.Con
|
||||
return attempts
|
||||
}
|
||||
|
||||
func mongoURIForcesTLS(uriText string) bool {
|
||||
trimmed := strings.TrimSpace(uriText)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
parsed, err := url.Parse(trimmed)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
query := parsed.Query()
|
||||
for _, key := range []string{"tls", "ssl"} {
|
||||
value := strings.ToLower(strings.TrimSpace(query.Get(key)))
|
||||
switch value {
|
||||
case "1", "true", "t", "yes", "y", "required":
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func mongoAttemptSSLLabel(config connection.ConnectionConfig, fallbackToPlain bool) string {
|
||||
if fallbackToPlain {
|
||||
return "明文回退"
|
||||
}
|
||||
if mongoURIForcesTLS(config.URI) {
|
||||
return "SSL"
|
||||
}
|
||||
enabled, _ := resolveMongoTLSSettings(config)
|
||||
if enabled {
|
||||
return "SSL"
|
||||
}
|
||||
return "明文"
|
||||
}
|
||||
|
||||
func (m *MongoDBV1) Connect(config connection.ConnectionConfig) error {
|
||||
runConfig := applyMongoURI(config)
|
||||
connectConfig := runConfig
|
||||
sshRouteHint := ""
|
||||
|
||||
if runConfig.UseSSH && runConfig.MongoSRV {
|
||||
return fmt.Errorf("MongoDB SRV 记录模式暂不支持 SSH 隧道")
|
||||
@@ -320,6 +361,7 @@ func (m *MongoDBV1) Connect(config connection.ConnectionConfig) error {
|
||||
localConfig.URI = ""
|
||||
localConfig.Hosts = []string{normalizeMongoAddress(host, port)}
|
||||
connectConfig = localConfig
|
||||
sshRouteHint = fmt.Sprintf("SSH隧道 %s -> %s:%d", forwarder.LocalAddr, targetHost, targetPort)
|
||||
logger.Infof("MongoDB 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, targetHost, targetPort)
|
||||
}
|
||||
|
||||
@@ -333,20 +375,32 @@ func (m *MongoDBV1) Connect(config connection.ConnectionConfig) error {
|
||||
if shouldTrySSLPreferredFallback(connectConfig) {
|
||||
sslAttempts = append(sslAttempts, withSSLDisabled(connectConfig))
|
||||
}
|
||||
totalAttempts := 0
|
||||
for _, attemptConfig := range sslAttempts {
|
||||
totalAttempts += len(buildMongoAuthAttempts(attemptConfig))
|
||||
}
|
||||
attemptNo := 0
|
||||
|
||||
var errorDetails []string
|
||||
for sslIndex, sslConfig := range sslAttempts {
|
||||
sslLabel := "SSL"
|
||||
if sslIndex > 0 {
|
||||
sslLabel = "明文回退"
|
||||
}
|
||||
sslLabel := mongoAttemptSSLLabel(sslConfig, sslIndex > 0)
|
||||
|
||||
attemptConfigs := buildMongoAuthAttempts(sslConfig)
|
||||
for index, attemptConfig := range attemptConfigs {
|
||||
attemptNo++
|
||||
authLabel := "主库凭据"
|
||||
if index > 0 {
|
||||
authLabel = "从库凭据"
|
||||
}
|
||||
targets := collectMongoSeeds(attemptConfig)
|
||||
if len(targets) == 0 {
|
||||
targets = append(targets, normalizeMongoAddress(attemptConfig.Host, attemptConfig.Port))
|
||||
}
|
||||
attemptStarted := time.Now()
|
||||
logger.Infof(
|
||||
"MongoDB(v1) 连接尝试:%d/%d 模式=%s 凭据=%s 目标=%s 代理=%t",
|
||||
attemptNo, totalAttempts, sslLabel, authLabel, strings.Join(targets, ","), attemptConfig.UseProxy,
|
||||
)
|
||||
|
||||
if sslIndex > 0 {
|
||||
attemptConfig.URI = ""
|
||||
@@ -367,7 +421,13 @@ func (m *MongoDBV1) Connect(config connection.ConnectionConfig) error {
|
||||
client, err := mongo.Connect(connectCtx, clientOpts)
|
||||
connectCancel()
|
||||
if err != nil {
|
||||
errorDetails = append(errorDetails, fmt.Sprintf("%s %s连接失败: %v", sslLabel, authLabel, err))
|
||||
logger.Warnf("MongoDB(v1) 连接尝试失败:%d/%d 模式=%s 凭据=%s 耗时=%s 错误=%v",
|
||||
attemptNo, totalAttempts, sslLabel, authLabel, time.Since(attemptStarted).Round(time.Millisecond), err)
|
||||
detail := fmt.Sprintf("%s %s连接失败: %v", sslLabel, authLabel, err)
|
||||
if sshRouteHint != "" {
|
||||
detail = fmt.Sprintf("%s(%s)", detail, sshRouteHint)
|
||||
}
|
||||
errorDetails = append(errorDetails, detail)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -377,9 +437,17 @@ func (m *MongoDBV1) Connect(config connection.ConnectionConfig) error {
|
||||
_ = client.Disconnect(ctx)
|
||||
cancel()
|
||||
m.client = nil
|
||||
errorDetails = append(errorDetails, fmt.Sprintf("%s %s验证失败: %v", sslLabel, authLabel, err))
|
||||
logger.Warnf("MongoDB(v1) 连接尝试验证失败:%d/%d 模式=%s 凭据=%s 耗时=%s 错误=%v",
|
||||
attemptNo, totalAttempts, sslLabel, authLabel, time.Since(attemptStarted).Round(time.Millisecond), err)
|
||||
detail := fmt.Sprintf("%s %s验证失败: %v", sslLabel, authLabel, err)
|
||||
if sshRouteHint != "" {
|
||||
detail = fmt.Sprintf("%s(%s)", detail, sshRouteHint)
|
||||
}
|
||||
errorDetails = append(errorDetails, detail)
|
||||
continue
|
||||
}
|
||||
logger.Infof("MongoDB(v1) 连接尝试成功:%d/%d 模式=%s 凭据=%s 耗时=%s",
|
||||
attemptNo, totalAttempts, sslLabel, authLabel, time.Since(attemptStarted).Round(time.Millisecond))
|
||||
if sslIndex > 0 {
|
||||
logger.Warnf("MongoDB(v1) SSL 优先连接失败,已回退至明文连接")
|
||||
}
|
||||
@@ -412,7 +480,7 @@ func (m *MongoDBV1) Close() error {
|
||||
|
||||
func (m *MongoDBV1) Ping() error {
|
||||
if m.client == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
timeout := m.pingTimeout
|
||||
if timeout <= 0 {
|
||||
@@ -616,7 +684,7 @@ func buildMembersFromHello(raw bson.M) []connection.MongoMemberInfo {
|
||||
|
||||
func (m *MongoDBV1) DiscoverMembers() (string, []connection.MongoMemberInfo, error) {
|
||||
if m.client == nil {
|
||||
return "", nil, fmt.Errorf("connection not open")
|
||||
return "", nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
timeout := m.pingTimeout
|
||||
@@ -767,7 +835,7 @@ func extractCollectionFromSQL(sql string) string {
|
||||
|
||||
func (m *MongoDBV1) queryWithContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if m.client == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
query = strings.TrimSpace(query)
|
||||
@@ -1011,7 +1079,7 @@ func (m *MongoDBV1) ExecContext(ctx context.Context, query string) (int64, error
|
||||
|
||||
func (m *MongoDBV1) GetDatabases() ([]string, error) {
|
||||
if m.client == nil {
|
||||
return nil, fmt.Errorf("connection not open")
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -1026,7 +1094,7 @@ func (m *MongoDBV1) GetDatabases() ([]string, error) {
|
||||
|
||||
func (m *MongoDBV1) GetTables(dbName string) ([]string, error) {
|
||||
if m.client == nil {
|
||||
return nil, fmt.Errorf("connection not open")
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
targetDB := dbName
|
||||
@@ -1062,7 +1130,7 @@ func (m *MongoDBV1) GetAllColumns(dbName string) ([]connection.ColumnDefinitionW
|
||||
// GetIndexes returns indexes for a MongoDB collection
|
||||
func (m *MongoDBV1) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||||
if m.client == nil {
|
||||
return nil, fmt.Errorf("connection not open")
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
targetDB := dbName
|
||||
@@ -1129,7 +1197,7 @@ func (m *MongoDBV1) GetTriggers(dbName, tableName string) ([]connection.TriggerD
|
||||
// ApplyChanges implements batch changes for MongoDB
|
||||
func (m *MongoDBV1) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
if m.client == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
@@ -1145,7 +1213,7 @@ func (m *MongoDBV1) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
}
|
||||
if len(filter) > 0 {
|
||||
if _, err := collection.DeleteOne(ctx, filter); err != nil {
|
||||
return fmt.Errorf("delete error: %v", err)
|
||||
return fmt.Errorf("删除失败:%v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1157,7 +1225,7 @@ func (m *MongoDBV1) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
filter[k] = v
|
||||
}
|
||||
if len(filter) == 0 {
|
||||
return fmt.Errorf("update requires keys")
|
||||
return fmt.Errorf("更新操作需要主键条件")
|
||||
}
|
||||
|
||||
updateDoc := bson.M{"$set": bson.M{}}
|
||||
@@ -1166,7 +1234,7 @@ func (m *MongoDBV1) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
}
|
||||
|
||||
if _, err := collection.UpdateOne(ctx, filter, updateDoc); err != nil {
|
||||
return fmt.Errorf("update error: %v", err)
|
||||
return fmt.Errorf("更新失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1178,7 +1246,7 @@ func (m *MongoDBV1) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
}
|
||||
if len(doc) > 0 {
|
||||
if _, err := collection.InsertOne(ctx, doc); err != nil {
|
||||
return fmt.Errorf("insert error: %v", err)
|
||||
return fmt.Errorf("插入失败:%v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
25
internal/db/mongodb_impl_v1_uri_test.go
Normal file
25
internal/db/mongodb_impl_v1_uri_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
//go:build gonavi_mongodb_driver_v1
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
func TestApplyMongoURIV1_ExplicitHostDoesNotAdoptURIHosts(t *testing.T) {
|
||||
config := connection.ConnectionConfig{
|
||||
Host: "10.10.10.10",
|
||||
Port: 27017,
|
||||
URI: "mongodb://localhost:27017/admin",
|
||||
}
|
||||
|
||||
got := applyMongoURI(config)
|
||||
if got.Host != "10.10.10.10" {
|
||||
t.Fatalf("expected host to remain explicit, got %q", got.Host)
|
||||
}
|
||||
if len(got.Hosts) != 0 {
|
||||
t.Fatalf("expected hosts to remain empty when explicit host exists, got %v", got.Hosts)
|
||||
}
|
||||
}
|
||||
@@ -429,7 +429,7 @@ func (m *MySQLAgentDB) ApplyChanges(tableName string, changes connection.ChangeS
|
||||
|
||||
func (m *MySQLAgentDB) requireClient() (*mysqlAgentClient, error) {
|
||||
if m.client == nil {
|
||||
return nil, fmt.Errorf("connection not open")
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
return m.client, nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package db
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
@@ -168,26 +169,26 @@ func collectMySQLAddresses(config connection.ConnectionConfig) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *MySQLDB) getDSN(config connection.ConnectionConfig) string {
|
||||
func (m *MySQLDB) getDSN(config connection.ConnectionConfig) (string, error) {
|
||||
database := config.Database
|
||||
protocol := "tcp"
|
||||
address := normalizeMySQLAddress(config.Host, config.Port)
|
||||
|
||||
if config.UseSSH {
|
||||
netName, err := ssh.RegisterSSHNetwork(config.SSH)
|
||||
if err == nil {
|
||||
protocol = netName
|
||||
address = normalizeMySQLAddress(config.Host, config.Port)
|
||||
} else {
|
||||
logger.Warnf("注册 SSH 网络失败,将尝试直连:地址=%s:%d 用户=%s,原因:%v", config.Host, config.Port, config.User, err)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||||
}
|
||||
protocol = netName
|
||||
}
|
||||
|
||||
timeout := getConnectTimeoutSeconds(config)
|
||||
tlsMode := resolveMySQLTLSMode(config)
|
||||
|
||||
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s",
|
||||
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode))
|
||||
return fmt.Sprintf(
|
||||
"%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s&multiStatements=true",
|
||||
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode),
|
||||
), nil
|
||||
}
|
||||
|
||||
func resolveMySQLCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {
|
||||
@@ -225,7 +226,11 @@ func (m *MySQLDB) Connect(config connection.ConnectionConfig) error {
|
||||
candidateConfig.Port = port
|
||||
candidateConfig.User, candidateConfig.Password = resolveMySQLCredential(runConfig, index)
|
||||
|
||||
dsn := m.getDSN(candidateConfig)
|
||||
dsn, err := m.getDSN(candidateConfig)
|
||||
if err != nil {
|
||||
errorDetails = append(errorDetails, fmt.Sprintf("%s 生成连接串失败: %v", address, err))
|
||||
continue
|
||||
}
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败: %v", address, err))
|
||||
@@ -262,7 +267,7 @@ func (m *MySQLDB) Close() error {
|
||||
|
||||
func (m *MySQLDB) Ping() error {
|
||||
if m.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
timeout := m.pingTimeout
|
||||
if timeout <= 0 {
|
||||
@@ -273,9 +278,33 @@ func (m *MySQLDB) Ping() error {
|
||||
return m.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (m *MySQLDB) QueryMulti(query string) ([]connection.ResultSetData, error) {
|
||||
if m.conn == nil {
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
rows, err := m.conn.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanMultiRows(rows)
|
||||
}
|
||||
|
||||
func (m *MySQLDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
|
||||
if m.conn == nil {
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
rows, err := m.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanMultiRows(rows)
|
||||
}
|
||||
|
||||
func (m *MySQLDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if m.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := m.conn.QueryContext(ctx, query)
|
||||
@@ -289,7 +318,7 @@ func (m *MySQLDB) QueryContext(ctx context.Context, query string) ([]map[string]
|
||||
|
||||
func (m *MySQLDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if m.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := m.conn.Query(query)
|
||||
@@ -302,7 +331,7 @@ func (m *MySQLDB) Query(query string) ([]map[string]interface{}, []string, error
|
||||
|
||||
func (m *MySQLDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if m.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := m.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
@@ -313,7 +342,7 @@ func (m *MySQLDB) ExecContext(ctx context.Context, query string) (int64, error)
|
||||
|
||||
func (m *MySQLDB) Exec(query string) (int64, error) {
|
||||
if m.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := m.conn.Exec(query)
|
||||
if err != nil {
|
||||
@@ -375,7 +404,7 @@ func (m *MySQLDB) GetCreateStatement(dbName, tableName string) (string, error) {
|
||||
return fmt.Sprintf("%v", val), nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("create statement not found")
|
||||
return "", fmt.Errorf("未找到建表语句")
|
||||
}
|
||||
|
||||
func (m *MySQLDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
@@ -441,12 +470,22 @@ func (m *MySQLDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefini
|
||||
}
|
||||
}
|
||||
|
||||
subPart := 0
|
||||
if val, ok := row["Sub_part"]; ok && val != nil {
|
||||
if f, ok := val.(float64); ok {
|
||||
subPart = int(f)
|
||||
} else if i, ok := val.(int64); ok {
|
||||
subPart = int(i)
|
||||
}
|
||||
}
|
||||
|
||||
idx := connection.IndexDefinition{
|
||||
Name: fmt.Sprintf("%v", row["Key_name"]),
|
||||
ColumnName: fmt.Sprintf("%v", row["Column_name"]),
|
||||
NonUnique: nonUnique,
|
||||
SeqInIndex: seq,
|
||||
IndexType: fmt.Sprintf("%v", row["Index_type"]),
|
||||
SubPart: subPart,
|
||||
}
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
@@ -499,7 +538,7 @@ func (m *MySQLDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDef
|
||||
|
||||
func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
if m.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
columnTypeMap := m.loadColumnTypeMap(tableName)
|
||||
@@ -524,7 +563,7 @@ func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
|
||||
query := fmt.Sprintf("DELETE FROM `%s` WHERE %s", tableName, strings.Join(wheres, " AND "))
|
||||
res, err := tx.Exec(query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete error: %v", err)
|
||||
return fmt.Errorf("删除失败:%v", err)
|
||||
}
|
||||
if affected, err := res.RowsAffected(); err == nil && affected == 0 {
|
||||
return fmt.Errorf("删除未生效:未匹配到任何行")
|
||||
@@ -552,13 +591,13 @@ func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
|
||||
}
|
||||
|
||||
if len(wheres) == 0 {
|
||||
return fmt.Errorf("update requires keys")
|
||||
return fmt.Errorf("更新操作需要主键条件")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("UPDATE `%s` SET %s WHERE %s", tableName, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
|
||||
res, err := tx.Exec(query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update error: %v", err)
|
||||
return fmt.Errorf("更新失败:%v", err)
|
||||
}
|
||||
if affected, err := res.RowsAffected(); err == nil && affected == 0 {
|
||||
return fmt.Errorf("更新未生效:未匹配到任何行")
|
||||
@@ -585,7 +624,7 @@ func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
|
||||
query := fmt.Sprintf("INSERT INTO `%s` () VALUES ()", tableName)
|
||||
res, err := tx.Exec(query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert error: %v", err)
|
||||
return fmt.Errorf("插入失败:%v", err)
|
||||
}
|
||||
if affected, err := res.RowsAffected(); err == nil && affected == 0 {
|
||||
return fmt.Errorf("插入未生效:未影响任何行")
|
||||
@@ -596,7 +635,7 @@ func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
|
||||
query := fmt.Sprintf("INSERT INTO `%s` (%s) VALUES (%s)", tableName, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
|
||||
res, err := tx.Exec(query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert error: %v", err)
|
||||
return fmt.Errorf("插入失败:%v", err)
|
||||
}
|
||||
if affected, err := res.RowsAffected(); err == nil && affected == 0 {
|
||||
return fmt.Errorf("插入未生效:未影响任何行")
|
||||
@@ -606,6 +645,18 @@ func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func normalizeMySQLComplexValue(value interface{}) interface{} {
|
||||
switch v := value.(type) {
|
||||
case map[string]interface{}, []interface{}:
|
||||
if data, err := json.Marshal(v); err == nil {
|
||||
return string(data)
|
||||
}
|
||||
return fmt.Sprintf("%v", value)
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeMySQLDateTimeValue(value interface{}) interface{} {
|
||||
text, ok := value.(string)
|
||||
if !ok {
|
||||
@@ -670,7 +721,7 @@ func (m *MySQLDB) loadColumnTypeMap(tableName string) map[string]string {
|
||||
func normalizeMySQLValueForInsert(columnName string, value interface{}, columnTypeMap map[string]string) (interface{}, bool) {
|
||||
columnType := strings.ToLower(strings.TrimSpace(columnTypeMap[strings.ToLower(strings.TrimSpace(columnName))]))
|
||||
if !isMySQLTemporalColumnType(columnType) {
|
||||
return value, false
|
||||
return normalizeMySQLComplexValue(value), false
|
||||
}
|
||||
text, ok := value.(string)
|
||||
if ok && strings.TrimSpace(text) == "" {
|
||||
@@ -747,7 +798,7 @@ func formatMySQLDateTime(t time.Time) string {
|
||||
func (m *MySQLDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||||
query := fmt.Sprintf("SELECT TABLE_NAME, COLUMN_NAME, COLUMN_TYPE FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = '%s'", dbName)
|
||||
if dbName == "" {
|
||||
return nil, fmt.Errorf("database name required for GetAllColumns")
|
||||
return nil, fmt.Errorf("获取全部列信息需要指定数据库名称")
|
||||
}
|
||||
|
||||
data, _, err := m.Query(query)
|
||||
|
||||
26
internal/db/mysql_ssh_test.go
Normal file
26
internal/db/mysql_ssh_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
func TestMySQLDSN_UseSSH_ShouldFailWhenSSHInvalid(t *testing.T) {
|
||||
m := &MySQLDB{}
|
||||
_, err := m.getDSN(connection.ConnectionConfig{
|
||||
Host: "127.0.0.1",
|
||||
Port: 3306,
|
||||
User: "root",
|
||||
UseSSH: true,
|
||||
SSH: connection.SSHConfig{
|
||||
Host: "127.0.0.1",
|
||||
Port: 0, // invalid port, should fail immediately
|
||||
User: "bad",
|
||||
Password: "bad",
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("expected error when UseSSH=true and SSH config invalid")
|
||||
}
|
||||
}
|
||||
@@ -9,8 +9,11 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
@@ -94,6 +97,9 @@ func newOptionalDriverAgentClient(driverType string, executablePath string) (*op
|
||||
return nil, fmt.Errorf("创建 %s 驱动代理 stderr 失败:%w", driverDisplayName(driverType), err)
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
if isWindowsExecutableMachineMismatch(err) {
|
||||
return nil, fmt.Errorf("启动 %s 驱动代理失败:%w(检测到驱动代理与当前系统架构不兼容,请在驱动管理中重新安装启用)", driverDisplayName(driverType), err)
|
||||
}
|
||||
return nil, fmt.Errorf("启动 %s 驱动代理失败:%w", driverDisplayName(driverType), err)
|
||||
}
|
||||
|
||||
@@ -107,6 +113,30 @@ func newOptionalDriverAgentClient(driverType string, executablePath string) (*op
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func isWindowsExecutableMachineMismatch(err error) bool {
|
||||
if err == nil || runtime.GOOS != "windows" {
|
||||
return false
|
||||
}
|
||||
var errno syscall.Errno
|
||||
if errors.As(err, &errno) && errno == syscall.Errno(216) {
|
||||
return true
|
||||
}
|
||||
text := strings.ToLower(strings.TrimSpace(err.Error()))
|
||||
if text == "" {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(text, "not compatible with the version of windows") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(text, "win32") && strings.Contains(text, "compatible") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(text, "不是有效的win32应用程序") || strings.Contains(text, "无法在win32模式下运行") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *optionalDriverAgentClient) captureStderr(stderr io.Reader) {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
buffer := make([]byte, 0, 8<<10)
|
||||
@@ -116,6 +146,7 @@ func (c *optionalDriverAgentClient) captureStderr(stderr io.Reader) {
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
logger.Warnf("%s 驱动代理 stderr: %s", driverDisplayName(c.driver), line)
|
||||
c.stderrMu.Lock()
|
||||
if c.stderr.Len() > 0 {
|
||||
c.stderr.WriteString(" | ")
|
||||
@@ -239,6 +270,7 @@ func (d *OptionalDriverAgentDB) Connect(config connection.ConnectionConfig) erro
|
||||
return err
|
||||
}
|
||||
d.client = client
|
||||
d.ensureKingbaseSearchPath(config)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -459,6 +491,16 @@ func (d *OptionalDriverAgentDB) ApplyChanges(tableName string, changes connectio
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if strings.EqualFold(d.driverType, "kingbase") {
|
||||
if normalized := normalizeKingbaseAgentTableName(tableName); normalized != "" {
|
||||
tableName = normalized
|
||||
}
|
||||
if normalized, normErr := d.normalizeKingbaseAgentChangeSet(tableName, changes); normErr == nil {
|
||||
changes = normalized
|
||||
} else {
|
||||
logger.Warnf("Kingbase ApplyChanges 字段名规范化失败:%v", normErr)
|
||||
}
|
||||
}
|
||||
return client.call(optionalAgentRequest{
|
||||
Method: optionalAgentMethodApplyChanges,
|
||||
TableName: tableName,
|
||||
@@ -468,11 +510,255 @@ func (d *OptionalDriverAgentDB) ApplyChanges(tableName string, changes connectio
|
||||
|
||||
func (d *OptionalDriverAgentDB) requireClient() (*optionalDriverAgentClient, error) {
|
||||
if d.client == nil {
|
||||
return nil, fmt.Errorf("connection not open")
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
return d.client, nil
|
||||
}
|
||||
|
||||
func (d *OptionalDriverAgentDB) ensureKingbaseSearchPath(config connection.ConnectionConfig) {
|
||||
if !strings.EqualFold(d.driverType, "kingbase") {
|
||||
return
|
||||
}
|
||||
client, err := d.requireClient()
|
||||
if err != nil || client == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
schemas, err := d.listKingbaseSchemas(ctx)
|
||||
if err != nil || len(schemas) == 0 {
|
||||
if err != nil {
|
||||
logger.Warnf("人大金仓驱动代理探测 schema 失败:%v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
searchPath := buildKingbaseSearchPathFromSchemas(schemas)
|
||||
if strings.TrimSpace(searchPath) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := d.ExecContext(ctx, fmt.Sprintf("SET search_path TO %s", searchPath)); err != nil {
|
||||
logger.Warnf("人大金仓驱动代理设置 search_path 失败:%v", err)
|
||||
return
|
||||
}
|
||||
logger.Infof("人大金仓驱动代理已设置默认 search_path:%s", searchPath)
|
||||
}
|
||||
|
||||
func (d *OptionalDriverAgentDB) listKingbaseSchemas(ctx context.Context) ([]string, error) {
|
||||
query := `SELECT nspname FROM pg_namespace
|
||||
WHERE nspname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND nspname NOT LIKE 'pg_%'
|
||||
ORDER BY nspname`
|
||||
rows, _, err := d.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
schemas := make([]string, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
for key, val := range row {
|
||||
if strings.EqualFold(key, "nspname") || strings.EqualFold(key, "schema") {
|
||||
name := strings.TrimSpace(fmt.Sprintf("%v", val))
|
||||
if name != "" {
|
||||
schemas = append(schemas, name)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(row) == 1 {
|
||||
for _, val := range row {
|
||||
name := strings.TrimSpace(fmt.Sprintf("%v", val))
|
||||
if name != "" {
|
||||
schemas = append(schemas, name)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return schemas, nil
|
||||
}
|
||||
|
||||
func buildKingbaseSearchPathFromSchemas(schemas []string) string {
|
||||
searchPath, _ := buildKingbaseSearchPathCommon(schemas)
|
||||
return searchPath
|
||||
}
|
||||
|
||||
func quoteKingbaseAgentIdent(name string) string {
|
||||
n := normalizeKingbaseAgentIdent(name)
|
||||
if n == "" {
|
||||
return "\"\""
|
||||
}
|
||||
n = strings.ReplaceAll(n, `"`, `""`)
|
||||
return `"` + n + `"`
|
||||
}
|
||||
|
||||
func normalizeKingbaseAgentTableName(raw string) string {
|
||||
schema, table := splitKingbaseQualifiedNameCommon(raw)
|
||||
if table == "" {
|
||||
return ""
|
||||
}
|
||||
if schema == "" {
|
||||
return table
|
||||
}
|
||||
return schema + "." + table
|
||||
}
|
||||
|
||||
func normalizeKingbaseAgentIdent(raw string) string {
|
||||
return normalizeKingbaseIdentCommon(raw)
|
||||
}
|
||||
|
||||
type kingbaseAgentColumnIndex struct {
|
||||
exact map[string]string
|
||||
compact map[string]string
|
||||
}
|
||||
|
||||
func buildKingbaseAgentColumnIndex(columns []string) kingbaseAgentColumnIndex {
|
||||
exact := make(map[string]string, len(columns))
|
||||
compact := make(map[string]string, len(columns))
|
||||
compactSeen := make(map[string]string, len(columns))
|
||||
compactDup := make(map[string]struct{}, len(columns))
|
||||
|
||||
for _, col := range columns {
|
||||
name := normalizeKingbaseAgentIdent(col)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
lower := strings.ToLower(name)
|
||||
if _, ok := exact[lower]; !ok {
|
||||
exact[lower] = name
|
||||
}
|
||||
key := normalizeKingbaseAgentCompactKey(name)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if prev, ok := compactSeen[key]; ok && !strings.EqualFold(prev, name) {
|
||||
compactDup[key] = struct{}{}
|
||||
continue
|
||||
}
|
||||
compactSeen[key] = name
|
||||
}
|
||||
|
||||
if len(compactDup) > 0 {
|
||||
for key := range compactDup {
|
||||
delete(compactSeen, key)
|
||||
}
|
||||
}
|
||||
for key, value := range compactSeen {
|
||||
compact[key] = value
|
||||
}
|
||||
return kingbaseAgentColumnIndex{exact: exact, compact: compact}
|
||||
}
|
||||
|
||||
func normalizeKingbaseAgentCompactKey(raw string) string {
|
||||
name := normalizeKingbaseAgentIdent(raw)
|
||||
if name == "" {
|
||||
return ""
|
||||
}
|
||||
name = strings.ToLower(strings.TrimSpace(name))
|
||||
name = strings.Join(strings.Fields(name), "")
|
||||
name = strings.ReplaceAll(name, "_", "")
|
||||
return name
|
||||
}
|
||||
|
||||
func resolveKingbaseAgentColumnName(name string, index kingbaseAgentColumnIndex) string {
|
||||
cleaned := normalizeKingbaseAgentIdent(name)
|
||||
if cleaned == "" {
|
||||
return name
|
||||
}
|
||||
lower := strings.ToLower(cleaned)
|
||||
if actual, ok := index.exact[lower]; ok {
|
||||
return actual
|
||||
}
|
||||
compact := normalizeKingbaseAgentCompactKey(cleaned)
|
||||
if actual, ok := index.compact[compact]; ok {
|
||||
return actual
|
||||
}
|
||||
return cleaned
|
||||
}
|
||||
|
||||
func normalizeKingbaseAgentChangeSetByColumns(changes connection.ChangeSet, columns []string) (connection.ChangeSet, error) {
|
||||
index := buildKingbaseAgentColumnIndex(columns)
|
||||
if len(index.exact) == 0 && len(index.compact) == 0 {
|
||||
return changes, nil
|
||||
}
|
||||
|
||||
mapRow := func(row map[string]interface{}) (map[string]interface{}, error) {
|
||||
if row == nil {
|
||||
return row, nil
|
||||
}
|
||||
out := make(map[string]interface{}, len(row))
|
||||
for key, value := range row {
|
||||
nextKey := resolveKingbaseAgentColumnName(key, index)
|
||||
if existing, ok := out[nextKey]; ok && !reflect.DeepEqual(existing, value) {
|
||||
return nil, fmt.Errorf("duplicate mapped column %q", nextKey)
|
||||
}
|
||||
out[nextKey] = value
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
next := connection.ChangeSet{
|
||||
Inserts: make([]map[string]interface{}, 0, len(changes.Inserts)),
|
||||
Updates: make([]connection.UpdateRow, 0, len(changes.Updates)),
|
||||
Deletes: make([]map[string]interface{}, 0, len(changes.Deletes)),
|
||||
}
|
||||
|
||||
for _, row := range changes.Inserts {
|
||||
mapped, err := mapRow(row)
|
||||
if err != nil {
|
||||
return changes, err
|
||||
}
|
||||
next.Inserts = append(next.Inserts, mapped)
|
||||
}
|
||||
|
||||
for _, upd := range changes.Updates {
|
||||
keys, err := mapRow(upd.Keys)
|
||||
if err != nil {
|
||||
return changes, err
|
||||
}
|
||||
values, err := mapRow(upd.Values)
|
||||
if err != nil {
|
||||
return changes, err
|
||||
}
|
||||
next.Updates = append(next.Updates, connection.UpdateRow{
|
||||
Keys: keys,
|
||||
Values: values,
|
||||
})
|
||||
}
|
||||
|
||||
for _, row := range changes.Deletes {
|
||||
mapped, err := mapRow(row)
|
||||
if err != nil {
|
||||
return changes, err
|
||||
}
|
||||
next.Deletes = append(next.Deletes, mapped)
|
||||
}
|
||||
|
||||
return next, nil
|
||||
}
|
||||
|
||||
func (d *OptionalDriverAgentDB) normalizeKingbaseAgentChangeSet(tableName string, changes connection.ChangeSet) (connection.ChangeSet, error) {
|
||||
columns, err := d.GetColumns("", tableName)
|
||||
if err != nil {
|
||||
return changes, err
|
||||
}
|
||||
if len(columns) == 0 {
|
||||
return changes, nil
|
||||
}
|
||||
names := make([]string, 0, len(columns))
|
||||
for _, col := range columns {
|
||||
name := strings.TrimSpace(col.Name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
return normalizeKingbaseAgentChangeSetByColumns(changes, names)
|
||||
}
|
||||
|
||||
func timeoutMsFromContext(ctx context.Context) int64 {
|
||||
deadline, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
|
||||
@@ -1,32 +1,67 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
func TestTimeoutMsFromContext_NoDeadline(t *testing.T) {
|
||||
if got := timeoutMsFromContext(context.Background()); got != 0 {
|
||||
t.Fatalf("无 deadline 时应返回 0,got=%d", got)
|
||||
func TestNormalizeKingbaseAgentTableName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{name: "plain", in: "ldf_server.andon_events", want: "ldf_server.andon_events"},
|
||||
{name: "quoted", in: `"ldf_server"."andon_events"`, want: "ldf_server.andon_events"},
|
||||
{name: "double quoted", in: `""ldf_server"".""andon_events""`, want: "ldf_server.andon_events"},
|
||||
{name: "escaped", in: `\"ldf_server\".\"andon_events\"`, want: "ldf_server.andon_events"},
|
||||
{name: "double escaped", in: `\\\"ldf_server\\\".\\\"andon_events\\\"`, want: "ldf_server.andon_events"},
|
||||
{name: "space around dot", in: ` "ldf_server" . "andon_events" `, want: "ldf_server.andon_events"},
|
||||
{name: "table only", in: `bcs_barcode`, want: "bcs_barcode"},
|
||||
{name: "table only quoted", in: `"bcs_barcode"`, want: "bcs_barcode"},
|
||||
{name: "table only double quoted", in: `""bcs_barcode""`, want: "bcs_barcode"},
|
||||
{name: "table only double escaped", in: `\\\"bcs_barcode\\\"`, want: "bcs_barcode"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := normalizeKingbaseAgentTableName(tt.in); got != tt.want {
|
||||
t.Fatalf("normalizeKingbaseAgentTableName(%q) = %q, want %q", tt.in, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeoutMsFromContext_WithDeadline(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
func TestNormalizeKingbaseAgentChangeSetByColumns(t *testing.T) {
|
||||
columns := []string{"andon_events_id", "event_name", "event_code"}
|
||||
input := connection.ChangeSet{
|
||||
Inserts: []map[string]interface{}{
|
||||
{"event name": "物料1", "event_code": "EV-0001", "andon_events_id": 1},
|
||||
},
|
||||
Updates: []connection.UpdateRow{
|
||||
{Keys: map[string]interface{}{"andon_events_id": 1}, Values: map[string]interface{}{"event name": "物料2"}},
|
||||
},
|
||||
Deletes: []map[string]interface{}{
|
||||
{"andon_events_id": 1},
|
||||
},
|
||||
}
|
||||
|
||||
got := timeoutMsFromContext(ctx)
|
||||
if got <= 0 {
|
||||
t.Fatalf("有 deadline 时应返回正值,got=%d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeoutMsFromContext_ExpiredDeadline(t *testing.T) {
|
||||
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second))
|
||||
defer cancel()
|
||||
|
||||
if got := timeoutMsFromContext(ctx); got != 1 {
|
||||
t.Fatalf("过期 deadline 应返回 1,got=%d", got)
|
||||
out, err := normalizeKingbaseAgentChangeSetByColumns(input, columns)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKingbaseAgentChangeSetByColumns error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := out.Inserts[0]["event_name"]; !ok {
|
||||
t.Fatalf("expected insert to map \"event name\" -> \"event_name\"")
|
||||
}
|
||||
if _, ok := out.Inserts[0]["event name"]; ok {
|
||||
t.Fatalf("unexpected insert key \"event name\" after normalization")
|
||||
}
|
||||
if _, ok := out.Updates[0].Values["event_name"]; !ok {
|
||||
t.Fatalf("expected update values to map \"event name\" -> \"event_name\"")
|
||||
}
|
||||
if _, ok := out.Updates[0].Values["event name"]; ok {
|
||||
t.Fatalf("unexpected update value key \"event name\" after normalization")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,7 +135,7 @@ func (o *OracleDB) Close() error {
|
||||
|
||||
func (o *OracleDB) Ping() error {
|
||||
if o.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
timeout := o.pingTimeout
|
||||
if timeout <= 0 {
|
||||
@@ -148,7 +148,7 @@ func (o *OracleDB) Ping() error {
|
||||
|
||||
func (o *OracleDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if o.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := o.conn.QueryContext(ctx, query)
|
||||
@@ -162,7 +162,7 @@ func (o *OracleDB) QueryContext(ctx context.Context, query string) ([]map[string
|
||||
|
||||
func (o *OracleDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if o.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := o.conn.Query(query)
|
||||
@@ -175,7 +175,7 @@ func (o *OracleDB) Query(query string) ([]map[string]interface{}, []string, erro
|
||||
|
||||
func (o *OracleDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if o.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := o.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
@@ -186,7 +186,7 @@ func (o *OracleDB) ExecContext(ctx context.Context, query string) (int64, error)
|
||||
|
||||
func (o *OracleDB) Exec(query string) (int64, error) {
|
||||
if o.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := o.conn.Exec(query)
|
||||
if err != nil {
|
||||
@@ -259,7 +259,7 @@ func (o *OracleDB) GetCreateStatement(dbName, tableName string) (string, error)
|
||||
return fmt.Sprintf("%v", val), nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("create statement not found")
|
||||
return "", fmt.Errorf("未找到建表语句")
|
||||
}
|
||||
|
||||
func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
@@ -391,7 +391,7 @@ func (o *OracleDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDe
|
||||
|
||||
func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
if o.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
tx, err := o.conn.Begin()
|
||||
@@ -439,7 +439,7 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
}
|
||||
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("delete error: %v", err)
|
||||
return fmt.Errorf("删除失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -467,12 +467,12 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
}
|
||||
|
||||
if len(wheres) == 0 {
|
||||
return fmt.Errorf("update requires keys")
|
||||
return fmt.Errorf("更新操作需要主键条件")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("update error: %v", err)
|
||||
return fmt.Errorf("更新失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -496,7 +496,7 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("insert error: %v", err)
|
||||
return fmt.Errorf("插入失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -195,7 +195,7 @@ func (p *PostgresDB) Close() error {
|
||||
|
||||
func (p *PostgresDB) Ping() error {
|
||||
if p.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
timeout := p.pingTimeout
|
||||
if timeout <= 0 {
|
||||
@@ -208,7 +208,7 @@ func (p *PostgresDB) Ping() error {
|
||||
|
||||
func (p *PostgresDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if p.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := p.conn.QueryContext(ctx, query)
|
||||
@@ -222,7 +222,7 @@ func (p *PostgresDB) QueryContext(ctx context.Context, query string) ([]map[stri
|
||||
|
||||
func (p *PostgresDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if p.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := p.conn.Query(query)
|
||||
@@ -235,7 +235,7 @@ func (p *PostgresDB) Query(query string) ([]map[string]interface{}, []string, er
|
||||
|
||||
func (p *PostgresDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if p.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := p.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
@@ -246,7 +246,7 @@ func (p *PostgresDB) ExecContext(ctx context.Context, query string) (int64, erro
|
||||
|
||||
func (p *PostgresDB) Exec(query string) (int64, error) {
|
||||
if p.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := p.conn.Exec(query)
|
||||
if err != nil {
|
||||
@@ -302,7 +302,7 @@ func (p *PostgresDB) GetColumns(dbName, tableName string) ([]connection.ColumnDe
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
@@ -372,7 +372,7 @@ func (p *PostgresDB) GetIndexes(dbName, tableName string) ([]connection.IndexDef
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
@@ -478,7 +478,7 @@ func (p *PostgresDB) GetForeignKeys(dbName, tableName string) ([]connection.Fore
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
@@ -538,7 +538,7 @@ func (p *PostgresDB) GetTriggers(dbName, tableName string) ([]connection.Trigger
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
@@ -602,7 +602,7 @@ ORDER BY table_schema, table_name, ordinal_position`
|
||||
|
||||
func (p *PostgresDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
if p.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
tx, err := p.conn.Begin()
|
||||
@@ -650,7 +650,7 @@ func (p *PostgresDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
}
|
||||
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("delete error: %v", err)
|
||||
return fmt.Errorf("删除失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -678,12 +678,12 @@ func (p *PostgresDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
}
|
||||
|
||||
if len(wheres) == 0 {
|
||||
return fmt.Errorf("update requires keys")
|
||||
return fmt.Errorf("更新操作需要主键条件")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("update error: %v", err)
|
||||
return fmt.Errorf("更新失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -707,7 +707,7 @@ func (p *PostgresDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("insert error: %v", err)
|
||||
return fmt.Errorf("插入失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
@@ -30,12 +31,44 @@ func normalizeQueryValue(v interface{}) interface{} {
|
||||
}
|
||||
|
||||
func normalizeQueryValueWithDBType(v interface{}, databaseTypeName string) interface{} {
|
||||
if tm, ok := v.(time.Time); ok {
|
||||
return normalizeTemporalValueForDisplay(tm, databaseTypeName)
|
||||
}
|
||||
if b, ok := v.([]byte); ok {
|
||||
return bytesToDisplayValue(b, databaseTypeName)
|
||||
}
|
||||
return normalizeCompositeQueryValue(v)
|
||||
}
|
||||
|
||||
func normalizeTemporalValueForDisplay(value time.Time, databaseTypeName string) interface{} {
|
||||
if value.IsZero() {
|
||||
if zeroValue, ok := zeroTemporalDisplayValue(databaseTypeName); ok {
|
||||
return zeroValue
|
||||
}
|
||||
}
|
||||
return value.Format(time.RFC3339Nano)
|
||||
}
|
||||
|
||||
func zeroTemporalDisplayValue(databaseTypeName string) (string, bool) {
|
||||
typeName := strings.ToUpper(strings.TrimSpace(databaseTypeName))
|
||||
if typeName == "" {
|
||||
return "0000-00-00 00:00:00", true
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.Contains(typeName, "TIMESTAMP") || strings.Contains(typeName, "DATETIME"):
|
||||
return "0000-00-00 00:00:00", true
|
||||
case typeName == "DATE" || typeName == "NEWDATE":
|
||||
return "0000-00-00", true
|
||||
case strings.Contains(typeName, "TIME"):
|
||||
return "00:00:00", true
|
||||
case strings.Contains(typeName, "YEAR"):
|
||||
return "0000", true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeCompositeQueryValue(v interface{}) interface{} {
|
||||
if v == nil {
|
||||
return nil
|
||||
@@ -86,6 +119,16 @@ func normalizeCompositeQueryValue(v interface{}) interface{} {
|
||||
items[i] = normalizeQueryValue(rv.Index(i).Interface())
|
||||
}
|
||||
return items
|
||||
case reflect.Struct:
|
||||
// 部分驱动(如 Kingbase)会返回复杂结构体值,直接透传会导致前端渲染和比较开销激增。
|
||||
// 统一降级为可读字符串,避免对象深层序列化触发 UI 卡顿。
|
||||
if tm, ok := v.(time.Time); ok {
|
||||
return normalizeTemporalValueForDisplay(tm, "")
|
||||
}
|
||||
if stringer, ok := v.(fmt.Stringer); ok {
|
||||
return stringer.String()
|
||||
}
|
||||
return fmt.Sprintf("%v", v)
|
||||
default:
|
||||
return normalizeUnsafeIntegerForJS(rv, v)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,9 @@ package db
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type duckMapLike map[any]any
|
||||
@@ -165,3 +167,61 @@ func TestNormalizeQueryValueWithDBType_JSONNumber(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type customStructValue struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
func (v customStructValue) String() string {
|
||||
return fmt.Sprintf("%s-%d", v.Name, v.Age)
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBType_StructToString(t *testing.T) {
|
||||
got := normalizeQueryValueWithDBType(customStructValue{Name: "alice", Age: 18}, "")
|
||||
if got != "alice-18" {
|
||||
t.Fatalf("结构体应降级为可读字符串,实际=%v(%T)", got, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBType_TimeStructToRFC3339(t *testing.T) {
|
||||
input := time.Date(2026, 3, 5, 18, 30, 15, 123456789, time.UTC)
|
||||
got := normalizeQueryValueWithDBType(input, "")
|
||||
text, ok := got.(string)
|
||||
if !ok {
|
||||
t.Fatalf("time.Time 应转为字符串,实际=%v(%T)", got, got)
|
||||
}
|
||||
if text != "2026-03-05T18:30:15.123456789Z" {
|
||||
t.Fatalf("time.Time 规整值异常,实际=%s", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBType_ZeroTemporalValues(t *testing.T) {
|
||||
zero := time.Time{}
|
||||
cases := []struct {
|
||||
name string
|
||||
dbType string
|
||||
wantText string
|
||||
}{
|
||||
{name: "date", dbType: "DATE", wantText: "0000-00-00"},
|
||||
{name: "newdate", dbType: "NEWDATE", wantText: "0000-00-00"},
|
||||
{name: "datetime", dbType: "DATETIME", wantText: "0000-00-00 00:00:00"},
|
||||
{name: "timestamp", dbType: "TIMESTAMP", wantText: "0000-00-00 00:00:00"},
|
||||
{name: "time", dbType: "TIME", wantText: "00:00:00"},
|
||||
{name: "year", dbType: "YEAR", wantText: "0000"},
|
||||
{name: "unknown", dbType: "", wantText: "0000-00-00 00:00:00"},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := normalizeQueryValueWithDBType(zero, tc.dbType)
|
||||
text, ok := got.(string)
|
||||
if !ok {
|
||||
t.Fatalf("期望 string,实际=%v(%T)", got, got)
|
||||
}
|
||||
if text != tc.wantText {
|
||||
t.Fatalf("dbType=%s 期望=%s,实际=%s", tc.dbType, tc.wantText, text)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) {
|
||||
@@ -44,3 +46,38 @@ func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) {
|
||||
}
|
||||
return resultData, columns, nil
|
||||
}
|
||||
|
||||
// scanMultiRows 遍历 sql.Rows 中的所有结果集,将每个结果集作为 ResultSetData 返回。
|
||||
// 利用 rows.NextResultSet() 支持一次 query 返回多个结果集的场景。
|
||||
func scanMultiRows(rows *sql.Rows) ([]connection.ResultSetData, error) {
|
||||
var results []connection.ResultSetData
|
||||
for {
|
||||
data, cols, err := scanRows(rows)
|
||||
if err != nil {
|
||||
return results, err
|
||||
}
|
||||
if data == nil {
|
||||
data = make([]map[string]interface{}, 0)
|
||||
}
|
||||
if cols == nil {
|
||||
cols = []string{}
|
||||
}
|
||||
results = append(results, connection.ResultSetData{
|
||||
Rows: data,
|
||||
Columns: cols,
|
||||
})
|
||||
if !rows.NextResultSet() {
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(results) == 0 {
|
||||
results = []connection.ResultSetData{{
|
||||
Rows: make([]map[string]interface{}, 0),
|
||||
Columns: []string{},
|
||||
}}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return results, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
@@ -184,7 +184,7 @@ func (s *SQLiteDB) Close() error {
|
||||
|
||||
func (s *SQLiteDB) Ping() error {
|
||||
if s.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
timeout := s.pingTimeout
|
||||
if timeout <= 0 {
|
||||
@@ -197,7 +197,7 @@ func (s *SQLiteDB) Ping() error {
|
||||
|
||||
func (s *SQLiteDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if s.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := s.conn.QueryContext(ctx, query)
|
||||
@@ -211,7 +211,7 @@ func (s *SQLiteDB) QueryContext(ctx context.Context, query string) ([]map[string
|
||||
|
||||
func (s *SQLiteDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if s.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := s.conn.Query(query)
|
||||
@@ -224,7 +224,7 @@ func (s *SQLiteDB) Query(query string) ([]map[string]interface{}, []string, erro
|
||||
|
||||
func (s *SQLiteDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if s.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := s.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
@@ -235,7 +235,7 @@ func (s *SQLiteDB) ExecContext(ctx context.Context, query string) (int64, error)
|
||||
|
||||
func (s *SQLiteDB) Exec(query string) (int64, error) {
|
||||
if s.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := s.conn.Exec(query)
|
||||
if err != nil {
|
||||
@@ -275,13 +275,13 @@ func (s *SQLiteDB) GetCreateStatement(dbName, tableName string) (string, error)
|
||||
return fmt.Sprintf("%v", val), nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("create statement not found")
|
||||
return "", fmt.Errorf("未找到建表语句")
|
||||
}
|
||||
|
||||
func (s *SQLiteDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") }
|
||||
@@ -372,7 +372,7 @@ func (s *SQLiteDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefi
|
||||
func (s *SQLiteDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") }
|
||||
@@ -463,7 +463,7 @@ func (s *SQLiteDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefin
|
||||
func (s *SQLiteDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") }
|
||||
@@ -537,7 +537,7 @@ func (s *SQLiteDB) GetForeignKeys(dbName, tableName string) ([]connection.Foreig
|
||||
func (s *SQLiteDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") }
|
||||
@@ -588,7 +588,7 @@ func (s *SQLiteDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDe
|
||||
|
||||
func (s *SQLiteDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
if s.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
tx, err := s.conn.Begin()
|
||||
@@ -634,7 +634,7 @@ func (s *SQLiteDB) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
}
|
||||
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("delete error: %v", err)
|
||||
return fmt.Errorf("删除失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -659,12 +659,12 @@ func (s *SQLiteDB) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
}
|
||||
|
||||
if len(wheres) == 0 {
|
||||
return fmt.Errorf("update requires keys")
|
||||
return fmt.Errorf("更新操作需要主键条件")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("update error: %v", err)
|
||||
return fmt.Errorf("更新失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -686,7 +686,7 @@ func (s *SQLiteDB) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("insert error: %v", err)
|
||||
return fmt.Errorf("插入失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -117,7 +117,7 @@ func (s *SqlServerDB) Close() error {
|
||||
|
||||
func (s *SqlServerDB) Ping() error {
|
||||
if s.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
timeout := s.pingTimeout
|
||||
if timeout <= 0 {
|
||||
@@ -128,9 +128,33 @@ func (s *SqlServerDB) Ping() error {
|
||||
return s.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) QueryMulti(query string) ([]connection.ResultSetData, error) {
|
||||
if s.conn == nil {
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
rows, err := s.conn.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanMultiRows(rows)
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
|
||||
if s.conn == nil {
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
rows, err := s.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanMultiRows(rows)
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if s.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := s.conn.QueryContext(ctx, query)
|
||||
@@ -144,7 +168,7 @@ func (s *SqlServerDB) QueryContext(ctx context.Context, query string) ([]map[str
|
||||
|
||||
func (s *SqlServerDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if s.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := s.conn.Query(query)
|
||||
@@ -157,7 +181,7 @@ func (s *SqlServerDB) Query(query string) ([]map[string]interface{}, []string, e
|
||||
|
||||
func (s *SqlServerDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if s.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := s.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
@@ -168,7 +192,7 @@ func (s *SqlServerDB) ExecContext(ctx context.Context, query string) (int64, err
|
||||
|
||||
func (s *SqlServerDB) Exec(query string) (int64, error) {
|
||||
if s.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := s.conn.Exec(query)
|
||||
if err != nil {
|
||||
@@ -236,7 +260,7 @@ func (s *SqlServerDB) GetColumns(dbName, tableName string) ([]connection.ColumnD
|
||||
}
|
||||
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
@@ -344,7 +368,7 @@ func (s *SqlServerDB) GetIndexes(dbName, tableName string) ([]connection.IndexDe
|
||||
}
|
||||
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
@@ -425,7 +449,7 @@ func (s *SqlServerDB) GetForeignKeys(dbName, tableName string) ([]connection.For
|
||||
}
|
||||
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
@@ -483,7 +507,7 @@ func (s *SqlServerDB) GetTriggers(dbName, tableName string) ([]connection.Trigge
|
||||
}
|
||||
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
@@ -530,7 +554,7 @@ ORDER BY tr.name`,
|
||||
|
||||
func (s *SqlServerDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
if s.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
tx, err := s.conn.Begin()
|
||||
@@ -573,7 +597,7 @@ func (s *SqlServerDB) ApplyChanges(tableName string, changes connection.ChangeSe
|
||||
}
|
||||
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("delete error: %v", err)
|
||||
return fmt.Errorf("删除失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -601,12 +625,12 @@ func (s *SqlServerDB) ApplyChanges(tableName string, changes connection.ChangeSe
|
||||
}
|
||||
|
||||
if len(wheres) == 0 {
|
||||
return fmt.Errorf("update requires keys")
|
||||
return fmt.Errorf("更新操作需要主键条件")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("update error: %v", err)
|
||||
return fmt.Errorf("更新失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -630,7 +654,7 @@ func (s *SqlServerDB) ApplyChanges(tableName string, changes connection.ChangeSe
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("insert error: %v", err)
|
||||
return fmt.Errorf("插入失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
168
internal/db/tdengine_applychanges_test.go
Normal file
168
internal/db/tdengine_applychanges_test.go
Normal file
@@ -0,0 +1,168 @@
|
||||
//go:build gonavi_full_drivers || gonavi_tdengine_driver
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
const tdengineRecordingDriverName = "gonavi_tdengine_recording"
|
||||
|
||||
var (
|
||||
registerTDengineRecordingDriverOnce sync.Once
|
||||
tdengineRecordingDriverMu sync.Mutex
|
||||
tdengineRecordingDriverSeq int
|
||||
tdengineRecordingDriverStates = map[string]*tdengineRecordingState{}
|
||||
)
|
||||
|
||||
type tdengineRecordingState struct {
|
||||
mu sync.Mutex
|
||||
queries []string
|
||||
execErr error
|
||||
}
|
||||
|
||||
func (s *tdengineRecordingState) snapshotQueries() []string {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
queries := make([]string, len(s.queries))
|
||||
copy(queries, s.queries)
|
||||
return queries
|
||||
}
|
||||
|
||||
type tdengineRecordingDriver struct{}
|
||||
|
||||
func (tdengineRecordingDriver) Open(name string) (driver.Conn, error) {
|
||||
tdengineRecordingDriverMu.Lock()
|
||||
state := tdengineRecordingDriverStates[name]
|
||||
tdengineRecordingDriverMu.Unlock()
|
||||
if state == nil {
|
||||
return nil, fmt.Errorf("recording state not found: %s", name)
|
||||
}
|
||||
return &tdengineRecordingConn{state: state}, nil
|
||||
}
|
||||
|
||||
type tdengineRecordingConn struct {
|
||||
state *tdengineRecordingState
|
||||
}
|
||||
|
||||
func (c *tdengineRecordingConn) Prepare(query string) (driver.Stmt, error) {
|
||||
return nil, fmt.Errorf("prepare not supported in tdengine recording driver: %s", query)
|
||||
}
|
||||
|
||||
func (c *tdengineRecordingConn) Close() error { return nil }
|
||||
|
||||
func (c *tdengineRecordingConn) Begin() (driver.Tx, error) {
|
||||
return nil, fmt.Errorf("transactions not supported in tdengine recording driver")
|
||||
}
|
||||
|
||||
func (c *tdengineRecordingConn) ExecContext(_ context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
if len(args) > 0 {
|
||||
return nil, fmt.Errorf("unexpected exec args: %d", len(args))
|
||||
}
|
||||
c.state.mu.Lock()
|
||||
defer c.state.mu.Unlock()
|
||||
if c.state.execErr != nil {
|
||||
return nil, c.state.execErr
|
||||
}
|
||||
c.state.queries = append(c.state.queries, query)
|
||||
return driver.RowsAffected(1), nil
|
||||
}
|
||||
|
||||
var _ driver.ExecerContext = (*tdengineRecordingConn)(nil)
|
||||
|
||||
func openTDengineRecordingDB(t *testing.T) (*sql.DB, *tdengineRecordingState) {
|
||||
t.Helper()
|
||||
registerTDengineRecordingDriverOnce.Do(func() {
|
||||
sql.Register(tdengineRecordingDriverName, tdengineRecordingDriver{})
|
||||
})
|
||||
|
||||
tdengineRecordingDriverMu.Lock()
|
||||
tdengineRecordingDriverSeq++
|
||||
dsn := fmt.Sprintf("tdengine-recording-%d", tdengineRecordingDriverSeq)
|
||||
state := &tdengineRecordingState{}
|
||||
tdengineRecordingDriverStates[dsn] = state
|
||||
tdengineRecordingDriverMu.Unlock()
|
||||
|
||||
dbConn, err := sql.Open(tdengineRecordingDriverName, dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("打开 recording db 失败: %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = dbConn.Close()
|
||||
tdengineRecordingDriverMu.Lock()
|
||||
delete(tdengineRecordingDriverStates, dsn)
|
||||
tdengineRecordingDriverMu.Unlock()
|
||||
})
|
||||
|
||||
return dbConn, state
|
||||
}
|
||||
|
||||
func TestTDengineApplyChanges_InsertsIntoQualifiedTable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbConn, state := openTDengineRecordingDB(t)
|
||||
td := &TDengineDB{conn: dbConn}
|
||||
|
||||
changes := connection.ChangeSet{
|
||||
Inserts: []map[string]interface{}{
|
||||
{
|
||||
"ts": "2026-03-09 10:00:00",
|
||||
"value": 12.5,
|
||||
"device": "sensor-a",
|
||||
"enabled": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := td.ApplyChanges("analytics.metrics", changes); err != nil {
|
||||
t.Fatalf("ApplyChanges 返回错误: %v", err)
|
||||
}
|
||||
|
||||
queries := state.snapshotQueries()
|
||||
if len(queries) != 1 {
|
||||
t.Fatalf("期望执行 1 条 SQL,实际 %d 条: %#v", len(queries), queries)
|
||||
}
|
||||
|
||||
want := "INSERT INTO `analytics`.`metrics` (`device`, `enabled`, `ts`, `value`) VALUES ('sensor-a', 1, '2026-03-09 10:00:00', 12.5)"
|
||||
if queries[0] != want {
|
||||
t.Fatalf("插入 SQL 不符合预期\nwant: %s\n got: %s", want, queries[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestTDengineApplyChanges_RejectsMixedUpdatesWithoutPartialWrite(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbConn, state := openTDengineRecordingDB(t)
|
||||
td := &TDengineDB{conn: dbConn}
|
||||
|
||||
changes := connection.ChangeSet{
|
||||
Inserts: []map[string]interface{}{{
|
||||
"ts": "2026-03-09 10:00:00",
|
||||
"value": 12.5,
|
||||
}},
|
||||
Updates: []connection.UpdateRow{{
|
||||
Keys: map[string]interface{}{"ts": "2026-03-09 10:00:00"},
|
||||
Values: map[string]interface{}{"value": 18.8},
|
||||
}},
|
||||
}
|
||||
|
||||
err := td.ApplyChanges("metrics", changes)
|
||||
if err == nil {
|
||||
t.Fatalf("期望 mixed changes 被拒绝")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "UPDATE/DELETE") {
|
||||
t.Fatalf("错误信息未说明限制边界: %v", err)
|
||||
}
|
||||
if queries := state.snapshotQueries(); len(queries) != 0 {
|
||||
t.Fatalf("期望拒绝 mixed changes 时不执行任何 SQL,实际=%#v", queries)
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -119,7 +120,7 @@ func (t *TDengineDB) Close() error {
|
||||
|
||||
func (t *TDengineDB) Ping() error {
|
||||
if t.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
timeout := t.pingTimeout
|
||||
if timeout <= 0 {
|
||||
@@ -132,7 +133,7 @@ func (t *TDengineDB) Ping() error {
|
||||
|
||||
func (t *TDengineDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if t.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := t.conn.QueryContext(ctx, query)
|
||||
@@ -146,7 +147,7 @@ func (t *TDengineDB) QueryContext(ctx context.Context, query string) ([]map[stri
|
||||
|
||||
func (t *TDengineDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if t.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := t.conn.Query(query)
|
||||
@@ -160,7 +161,7 @@ func (t *TDengineDB) Query(query string) ([]map[string]interface{}, []string, er
|
||||
|
||||
func (t *TDengineDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if t.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := t.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
@@ -171,7 +172,7 @@ func (t *TDengineDB) ExecContext(ctx context.Context, query string) (int64, erro
|
||||
|
||||
func (t *TDengineDB) Exec(query string) (int64, error) {
|
||||
if t.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := t.conn.Exec(query)
|
||||
if err != nil {
|
||||
@@ -273,7 +274,7 @@ func (t *TDengineDB) GetCreateStatement(dbName, tableName string) (string, error
|
||||
if lastErr != nil {
|
||||
return "", lastErr
|
||||
}
|
||||
return "", fmt.Errorf("create statement not found")
|
||||
return "", fmt.Errorf("未找到建表语句")
|
||||
}
|
||||
|
||||
func (t *TDengineDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
@@ -324,7 +325,7 @@ func (t *TDengineDB) GetColumns(dbName, tableName string) ([]connection.ColumnDe
|
||||
|
||||
func (t *TDengineDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||||
if strings.TrimSpace(dbName) == "" {
|
||||
return nil, fmt.Errorf("database name required for GetAllColumns")
|
||||
return nil, fmt.Errorf("获取全部列信息需要指定数据库名称")
|
||||
}
|
||||
|
||||
tables, err := t.GetTables(dbName)
|
||||
@@ -362,6 +363,83 @@ func (t *TDengineDB) GetTriggers(dbName, tableName string) ([]connection.Trigger
|
||||
return []connection.TriggerDefinition{}, nil
|
||||
}
|
||||
|
||||
func (t *TDengineDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
if t.conn == nil {
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
if strings.TrimSpace(tableName) == "" {
|
||||
return fmt.Errorf("表名不能为空")
|
||||
}
|
||||
if len(changes.Updates) > 0 || len(changes.Deletes) > 0 {
|
||||
return fmt.Errorf("TDengine 目标端当前仅支持 INSERT 写入,暂不支持 UPDATE/DELETE 差异同步,请改用仅插入或全量覆盖模式")
|
||||
}
|
||||
|
||||
qualifiedTable := quoteTDengineTable("", tableName)
|
||||
for _, row := range changes.Inserts {
|
||||
query, err := buildTDengineInsertSQL(qualifiedTable, row)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if query == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := t.conn.Exec(query); err != nil {
|
||||
return fmt.Errorf("插入失败:%v; sql=%s", err, query)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildTDengineInsertSQL(qualifiedTable string, row map[string]interface{}) (string, error) {
|
||||
if strings.TrimSpace(qualifiedTable) == "" {
|
||||
return "", fmt.Errorf("需要指定完整的表名")
|
||||
}
|
||||
if len(row) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
cols := make([]string, 0, len(row))
|
||||
for key := range row {
|
||||
if strings.TrimSpace(key) == "" {
|
||||
continue
|
||||
}
|
||||
cols = append(cols, key)
|
||||
}
|
||||
if len(cols) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
sort.Strings(cols)
|
||||
|
||||
quotedCols := make([]string, 0, len(cols))
|
||||
values := make([]string, 0, len(cols))
|
||||
for _, col := range cols {
|
||||
quotedCols = append(quotedCols, fmt.Sprintf("`%s`", escapeBacktickIdent(col)))
|
||||
values = append(values, tdengineLiteral(row[col]))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(quotedCols, ", "), strings.Join(values, ", ")), nil
|
||||
}
|
||||
|
||||
func tdengineLiteral(value interface{}) string {
|
||||
switch val := value.(type) {
|
||||
case nil:
|
||||
return "NULL"
|
||||
case bool:
|
||||
if val {
|
||||
return "1"
|
||||
}
|
||||
return "0"
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
|
||||
return fmt.Sprintf("%v", val)
|
||||
case time.Time:
|
||||
return fmt.Sprintf("'%s'", val.Format("2006-01-02 15:04:05"))
|
||||
case []byte:
|
||||
return fmt.Sprintf("'%s'", strings.ReplaceAll(string(val), "'", "''"))
|
||||
default:
|
||||
return fmt.Sprintf("'%s'", strings.ReplaceAll(fmt.Sprintf("%v", val), "'", "''"))
|
||||
}
|
||||
}
|
||||
|
||||
func getValueFromRow(row map[string]interface{}, keys ...string) (interface{}, bool) {
|
||||
if len(row) == 0 {
|
||||
return nil, false
|
||||
|
||||
@@ -124,7 +124,7 @@ func (v *VastbaseDB) Close() error {
|
||||
|
||||
func (v *VastbaseDB) Ping() error {
|
||||
if v.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
timeout := v.pingTimeout
|
||||
if timeout <= 0 {
|
||||
@@ -137,7 +137,7 @@ func (v *VastbaseDB) Ping() error {
|
||||
|
||||
func (v *VastbaseDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if v.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := v.conn.QueryContext(ctx, query)
|
||||
@@ -151,7 +151,7 @@ func (v *VastbaseDB) QueryContext(ctx context.Context, query string) ([]map[stri
|
||||
|
||||
func (v *VastbaseDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if v.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := v.conn.Query(query)
|
||||
@@ -164,7 +164,7 @@ func (v *VastbaseDB) Query(query string) ([]map[string]interface{}, []string, er
|
||||
|
||||
func (v *VastbaseDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if v.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := v.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
@@ -175,7 +175,7 @@ func (v *VastbaseDB) ExecContext(ctx context.Context, query string) (int64, erro
|
||||
|
||||
func (v *VastbaseDB) Exec(query string) (int64, error) {
|
||||
if v.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := v.conn.Exec(query)
|
||||
if err != nil {
|
||||
@@ -231,7 +231,7 @@ func (v *VastbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDe
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
@@ -301,7 +301,7 @@ func (v *VastbaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDef
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
@@ -406,7 +406,7 @@ func (v *VastbaseDB) GetForeignKeys(dbName, tableName string) ([]connection.Fore
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
@@ -466,7 +466,7 @@ func (v *VastbaseDB) GetTriggers(dbName, tableName string) ([]connection.Trigger
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
return nil, fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
@@ -530,7 +530,7 @@ ORDER BY table_schema, table_name, ordinal_position`
|
||||
|
||||
func (v *VastbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
if v.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
tx, err := v.conn.Begin()
|
||||
@@ -578,7 +578,7 @@ func (v *VastbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
}
|
||||
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("delete error: %v", err)
|
||||
return fmt.Errorf("删除失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -606,12 +606,12 @@ func (v *VastbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
}
|
||||
|
||||
if len(wheres) == 0 {
|
||||
return fmt.Errorf("update requires keys")
|
||||
return fmt.Errorf("更新操作需要主键条件")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("update error: %v", err)
|
||||
return fmt.Errorf("更新失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -635,7 +635,7 @@ func (v *VastbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("insert error: %v", err)
|
||||
return fmt.Errorf("插入失败:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,8 +14,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
envLogDir = "GONAVI_LOG_DIR"
|
||||
appDirName = "GoNavi"
|
||||
envLogDir = "GONAVI_LOG_DIR"
|
||||
appHiddenDir = ".GoNavi"
|
||||
appLogDirName = "Logs"
|
||||
|
||||
logFileName = "gonavi.log"
|
||||
logRotateMaxBytes = 10 * 1024 * 1024 // 10MB
|
||||
@@ -37,7 +38,7 @@ func Init() {
|
||||
defer logMu.Unlock()
|
||||
logPath = path
|
||||
logInst = log.New(out, "", log.Ldate|log.Ltime|log.Lmicroseconds)
|
||||
logInst.Printf("[信息] 日志初始化完成,日志文件:%s", logPath)
|
||||
logInst.Printf("[INFO] 日志初始化完成,日志文件:%s", logPath)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -62,15 +63,15 @@ func Close() {
|
||||
}
|
||||
|
||||
func Infof(format string, args ...any) {
|
||||
printf("信息", format, args...)
|
||||
printf("INFO", format, args...)
|
||||
}
|
||||
|
||||
func Warnf(format string, args ...any) {
|
||||
printf("警告", format, args...)
|
||||
printf("WARN", format, args...)
|
||||
}
|
||||
|
||||
func Errorf(format string, args ...any) {
|
||||
printf("错误", format, args...)
|
||||
printf("ERROR", format, args...)
|
||||
}
|
||||
|
||||
func Error(err error, format string, args ...any) {
|
||||
@@ -115,37 +116,58 @@ func ErrorChain(err error) string {
|
||||
func printf(level string, format string, args ...any) {
|
||||
Init()
|
||||
logMu.Lock()
|
||||
defer logMu.Unlock()
|
||||
inst := logInst
|
||||
logMu.Unlock()
|
||||
if inst == nil {
|
||||
return
|
||||
}
|
||||
inst.Printf("[%s] %s", level, fmt.Sprintf(format, args...))
|
||||
if logFile != nil {
|
||||
_ = logFile.Sync()
|
||||
}
|
||||
}
|
||||
|
||||
func initOutput() (string, io.Writer) {
|
||||
dir := strings.TrimSpace(os.Getenv(envLogDir))
|
||||
if dir == "" {
|
||||
base, err := os.UserConfigDir()
|
||||
if err != nil || strings.TrimSpace(base) == "" {
|
||||
base = os.TempDir()
|
||||
}
|
||||
dir = filepath.Join(base, appDirName, "logs")
|
||||
dir = defaultLogDir()
|
||||
}
|
||||
|
||||
if path, writer, ok := openLogFile(dir); ok {
|
||||
return path, writer
|
||||
}
|
||||
|
||||
fallbackDir := filepath.Join(os.TempDir(), appHiddenDir, appLogDirName)
|
||||
if path, writer, ok := openLogFile(fallbackDir); ok {
|
||||
return path, writer
|
||||
}
|
||||
|
||||
return "", os.Stderr
|
||||
}
|
||||
|
||||
func defaultLogDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil || strings.TrimSpace(home) == "" {
|
||||
return filepath.Join(os.TempDir(), appHiddenDir, appLogDirName)
|
||||
}
|
||||
return filepath.Join(home, appHiddenDir, appLogDirName)
|
||||
}
|
||||
|
||||
func openLogFile(dir string) (string, io.Writer, bool) {
|
||||
if strings.TrimSpace(dir) == "" {
|
||||
return "", nil, false
|
||||
}
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return filepath.Join(dir, logFileName), os.Stderr
|
||||
return "", nil, false
|
||||
}
|
||||
|
||||
path := filepath.Join(dir, logFileName)
|
||||
rotateIfNeeded(path, dir)
|
||||
|
||||
f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
return path, os.Stderr
|
||||
return "", nil, false
|
||||
}
|
||||
logFile = f
|
||||
return path, f
|
||||
return path, f, true
|
||||
}
|
||||
|
||||
func rotateIfNeeded(path, dir string) {
|
||||
|
||||
65
internal/logger/logger_test.go
Normal file
65
internal/logger/logger_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestErrorChain_NilError(t *testing.T) {
|
||||
if got := ErrorChain(nil); got != "" {
|
||||
t.Errorf("ErrorChain(nil) = %q; want empty string", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorChain_SingleError(t *testing.T) {
|
||||
err := errors.New("single error")
|
||||
got := ErrorChain(err)
|
||||
if got != "single error" {
|
||||
t.Errorf("ErrorChain(single) = %q; want %q", got, "single error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorChain_WrappedErrors(t *testing.T) {
|
||||
inner := errors.New("root cause")
|
||||
middle := fmt.Errorf("middle: %w", inner)
|
||||
outer := fmt.Errorf("outer: %w", middle)
|
||||
|
||||
got := ErrorChain(outer)
|
||||
// Should contain all three distinct messages
|
||||
if got == "" {
|
||||
t.Fatal("ErrorChain returned empty string for wrapped errors")
|
||||
}
|
||||
// The chain should start with the outermost error
|
||||
if len(got) < len("outer:") {
|
||||
t.Errorf("ErrorChain result too short: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorChain_DeduplicatesMessages(t *testing.T) {
|
||||
// Create a chain where wrapping doesn't add new text
|
||||
inner := errors.New("same message")
|
||||
outer := fmt.Errorf("%w", inner)
|
||||
|
||||
got := ErrorChain(outer)
|
||||
// Should not repeat "same message"
|
||||
if got != "same message" {
|
||||
t.Errorf("ErrorChain should deduplicate: got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorChain_TruncatesLongChain(t *testing.T) {
|
||||
// Build a chain of 25 errors (exceeds the 20-level limit)
|
||||
var err error = errors.New("base")
|
||||
for i := 0; i < 25; i++ {
|
||||
err = fmt.Errorf("level-%d: %w", i, err)
|
||||
}
|
||||
got := ErrorChain(err)
|
||||
if got == "" {
|
||||
t.Fatal("ErrorChain returned empty for long chain")
|
||||
}
|
||||
// Should contain truncation notice
|
||||
if len(got) == 0 {
|
||||
t.Error("expected non-empty result for long chain")
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user