mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-12 12:19:47 +08:00
Compare commits
41 Commits
feature/0.
...
release/0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
538e4a1506 | ||
|
|
934581c796 | ||
|
|
1486b98d27 | ||
|
|
6cda430f03 | ||
|
|
f56c3d5f6e | ||
|
|
74c9143c95 | ||
|
|
0e4a833ffa | ||
|
|
37ad9885b7 | ||
|
|
5cef9a4032 | ||
|
|
f49767c38b | ||
|
|
7e8699ba02 | ||
|
|
5f0ce5ed7a | ||
|
|
49c7620bdd | ||
|
|
80fa7a1acd | ||
|
|
68770a42e2 | ||
|
|
06aebf716e | ||
|
|
f551b19f40 | ||
|
|
6674ad69e1 | ||
|
|
37d35684f1 | ||
|
|
71e5de0cdc | ||
|
|
d8656c6c9c | ||
|
|
443b487a02 | ||
|
|
bac57ebdf0 | ||
|
|
213a33e4f3 | ||
|
|
a00f87582d | ||
|
|
f129623000 | ||
|
|
8dbc97e466 | ||
|
|
4a0db185c0 | ||
|
|
5793f63ac8 | ||
|
|
8aabc67634 | ||
|
|
34c494ce51 | ||
|
|
178de02783 | ||
|
|
94e5b8d2c6 | ||
|
|
89e2247c05 | ||
|
|
b2ede61b79 | ||
|
|
791425a5a8 | ||
|
|
d7acfd1af9 | ||
|
|
88952e87c1 | ||
|
|
c981a65834 | ||
|
|
b9d9ab5464 | ||
|
|
6b503480cf |
58
.github/ISSUE_TEMPLATE/01-bug_report.yml
vendored
Normal file
58
.github/ISSUE_TEMPLATE/01-bug_report.yml
vendored
Normal file
@@ -0,0 +1,58 @@
|
||||
name: 问题反馈
|
||||
description: 软件问题反馈
|
||||
title: "[Bug] "
|
||||
labels: ["bug"]
|
||||
|
||||
body:
|
||||
- type: checkboxes
|
||||
id: searched
|
||||
attributes:
|
||||
label: 已经搜索过 Issues,未发现重复问题*
|
||||
options:
|
||||
- label: 我已经搜索过 Issues,没有发现重复问题
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
id: system
|
||||
attributes:
|
||||
label: 操作系统及版本
|
||||
placeholder: Windows 10 22H2 / macOS Mojave / Linux
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
id: version
|
||||
attributes:
|
||||
label: 软件安装版本
|
||||
placeholder: v0.2.3
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: 问题简述及复现流程
|
||||
description: 请详细描述你遇到的问题,并提供复现步骤
|
||||
placeholder: |
|
||||
1. 打开软件
|
||||
2. 点击 xxx
|
||||
3. 预期结果是 ...
|
||||
4. 实际结果是 ...
|
||||
5. 截图 ...
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: extra
|
||||
attributes:
|
||||
label: 其他补充
|
||||
description: 如果你有额外信息,请在此填写
|
||||
placeholder: 可选
|
||||
|
||||
- type: checkboxes
|
||||
id: pr
|
||||
attributes:
|
||||
label: 是否愿意提交 PR 修复当前 Issue
|
||||
options:
|
||||
- label: 我愿意尝试提交 PR
|
||||
37
.github/ISSUE_TEMPLATE/02-feature_request.yml
vendored
Normal file
37
.github/ISSUE_TEMPLATE/02-feature_request.yml
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
name: 功能建议
|
||||
description: 添加全新功能或改进现有功能
|
||||
title: "[Enhancement] "
|
||||
labels: ["enhancement"]
|
||||
|
||||
body:
|
||||
- type: checkboxes
|
||||
id: searched
|
||||
attributes:
|
||||
label: 已经搜索过 Issues,未发现重复问题*
|
||||
options:
|
||||
- label: 我已经搜索过 Issues,没有发现重复问题
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: feature
|
||||
attributes:
|
||||
label: 功能描述
|
||||
description: 请详细描述你希望添加或改进的功能
|
||||
placeholder: 请描述你想要的功能
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: extra
|
||||
attributes:
|
||||
label: 其他补充
|
||||
description: 如果你有额外信息,请在此填写
|
||||
placeholder: 可选
|
||||
|
||||
- type: checkboxes
|
||||
id: pr
|
||||
attributes:
|
||||
label: 是否愿意提交 PR 实现当前 Issue
|
||||
options:
|
||||
- label: 我愿意尝试提交 PR
|
||||
30
.github/ISSUE_TEMPLATE/03-generic.yml
vendored
Normal file
30
.github/ISSUE_TEMPLATE/03-generic.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
name: 其他反馈
|
||||
description: 其他类型反馈、建议或讨论
|
||||
title: "[Question] "
|
||||
labels: ["question"]
|
||||
|
||||
body:
|
||||
- type: checkboxes
|
||||
id: searched
|
||||
attributes:
|
||||
label: 已经搜索过 Issues,未发现重复问题*
|
||||
options:
|
||||
- label: 我已经搜索过 Issues,没有发现重复问题
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: content
|
||||
attributes:
|
||||
label: 内容
|
||||
description: 请填写你的反馈、建议或讨论内容
|
||||
placeholder: 请描述你的问题或想法
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: extra
|
||||
attributes:
|
||||
label: 其他补充
|
||||
description: 如果你有额外信息,请在此填写
|
||||
placeholder: 可选
|
||||
1
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
1
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1 @@
|
||||
blank_issues_enabled: false
|
||||
22
.github/workflows/release-winget.yml
vendored
Normal file
22
.github/workflows/release-winget.yml
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
name: Publish to WinGet
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
release_tag:
|
||||
required: true
|
||||
description: 'Tag of release you want to publish'
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
publish:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- uses: vedantmgoyal9/winget-releaser@v2
|
||||
with:
|
||||
identifier: Syngnat.GoNavi
|
||||
installers-regex: 'GoNavi-windows-(amd64|arm64)\.exe$'
|
||||
release-tag: ${{ inputs.release_tag || github.ref_name }}
|
||||
token: ${{ secrets.WINGET_TOKEN }}
|
||||
120
.github/workflows/release.yml
vendored
120
.github/workflows/release.yml
vendored
@@ -29,6 +29,13 @@ jobs:
|
||||
platform: windows/amd64
|
||||
artifact_name: GoNavi-windows-amd64
|
||||
asset_ext: .exe
|
||||
- os: windows-latest
|
||||
platform: windows/arm64
|
||||
artifact_name: GoNavi-windows-arm64
|
||||
asset_ext: .exe
|
||||
- os: ubuntu-22.04
|
||||
platform: linux/amd64
|
||||
artifact_name: GoNavi-linux-amd64
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
@@ -45,6 +52,36 @@ jobs:
|
||||
with:
|
||||
node-version: '20'
|
||||
|
||||
# Linux Dependencies (GTK3, WebKit2GTK required by Wails)
|
||||
- name: Install Linux Dependencies
|
||||
if: contains(matrix.platform, 'linux')
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libgtk-3-dev libwebkit2gtk-4.0-dev libfuse2
|
||||
|
||||
# Download linuxdeploy tools for AppImage packaging
|
||||
LINUXDEPLOY_URL="https://github.com/linuxdeploy/linuxdeploy/releases/download/continuous/linuxdeploy-x86_64.AppImage"
|
||||
PLUGIN_URL="https://github.com/linuxdeploy/linuxdeploy-plugin-gtk/releases/download/continuous/linuxdeploy-plugin-gtk-x86_64.AppImage"
|
||||
|
||||
echo "📥 下载 linuxdeploy..."
|
||||
wget --retry-connrefused --waitretry=1 --read-timeout=20 --timeout=15 --tries=3 \
|
||||
-O /tmp/linuxdeploy "$LINUXDEPLOY_URL" || {
|
||||
echo "⚠️ linuxdeploy 下载失败,AppImage 打包将跳过"
|
||||
touch /tmp/skip-appimage
|
||||
}
|
||||
|
||||
echo "📥 下载 linuxdeploy-plugin-gtk..."
|
||||
wget --retry-connrefused --waitretry=1 --read-timeout=20 --timeout=15 --tries=3 \
|
||||
-O /tmp/linuxdeploy-plugin-gtk "$PLUGIN_URL" || {
|
||||
echo "⚠️ linuxdeploy-plugin-gtk 下载失败,AppImage 打包将跳过"
|
||||
touch /tmp/skip-appimage
|
||||
}
|
||||
|
||||
if [ ! -f /tmp/skip-appimage ]; then
|
||||
chmod +x /tmp/linuxdeploy /tmp/linuxdeploy-plugin-gtk
|
||||
echo "✅ AppImage 工具准备完成"
|
||||
fi
|
||||
|
||||
- name: Install Wails
|
||||
run: go install -v github.com/wailsapp/wails/v2/cmd/wails@latest
|
||||
|
||||
@@ -107,12 +144,93 @@ jobs:
|
||||
echo "📦 正在移动 $FINAL_EXE 到根目录..."
|
||||
mv "$FINAL_EXE" "../../$FINAL_EXE"
|
||||
|
||||
# Linux Packaging (tar.gz and AppImage)
|
||||
- name: Package Linux
|
||||
if: contains(matrix.platform, 'linux')
|
||||
run: |
|
||||
cd build/bin
|
||||
TARGET="${{ matrix.artifact_name }}"
|
||||
|
||||
if [ ! -f "$TARGET" ]; then
|
||||
echo "❌ 未找到构建产物 '$TARGET'!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
chmod +x "$TARGET"
|
||||
|
||||
# 1. Create tar.gz
|
||||
echo "📦 正在打包 $TARGET.tar.gz..."
|
||||
tar -czvf "$TARGET.tar.gz" "$TARGET"
|
||||
mv "$TARGET.tar.gz" ../../
|
||||
|
||||
# 2. Create AppImage (skip for ARM64 or if tools unavailable)
|
||||
if [ -f /tmp/skip-appimage ]; then
|
||||
echo "⚠️ 跳过 AppImage 打包"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "📦 正在生成 AppImage..."
|
||||
|
||||
# Create AppDir structure
|
||||
mkdir -p AppDir/usr/bin
|
||||
mkdir -p AppDir/usr/share/applications
|
||||
mkdir -p AppDir/usr/share/icons/hicolor/256x256/apps
|
||||
|
||||
cp "$TARGET" AppDir/usr/bin/gonavi
|
||||
|
||||
# Create desktop file
|
||||
printf '%s\n' \
|
||||
'[Desktop Entry]' \
|
||||
'Name=GoNavi' \
|
||||
'Exec=gonavi' \
|
||||
'Icon=gonavi' \
|
||||
'Type=Application' \
|
||||
'Categories=Development;Database;' \
|
||||
'Comment=Database Management Tool' \
|
||||
> AppDir/usr/share/applications/gonavi.desktop
|
||||
|
||||
cp AppDir/usr/share/applications/gonavi.desktop AppDir/gonavi.desktop
|
||||
|
||||
# Create a simple icon (or use existing if available)
|
||||
if [ -f "../../build/appicon.png" ]; then
|
||||
cp "../../build/appicon.png" AppDir/usr/share/icons/hicolor/256x256/apps/gonavi.png
|
||||
cp "../../build/appicon.png" AppDir/gonavi.png
|
||||
else
|
||||
# Create a placeholder icon
|
||||
convert -size 256x256 xc:#336791 -fill white -gravity center -pointsize 48 -annotate 0 "GoNavi" AppDir/gonavi.png || \
|
||||
wget -q "https://via.placeholder.com/256/336791/FFFFFF?text=GoNavi" -O AppDir/gonavi.png || \
|
||||
touch AppDir/gonavi.png
|
||||
cp AppDir/gonavi.png AppDir/usr/share/icons/hicolor/256x256/apps/gonavi.png
|
||||
fi
|
||||
|
||||
# Build AppImage
|
||||
export DEPLOY_GTK_VERSION=3
|
||||
/tmp/linuxdeploy --appdir AppDir --plugin gtk --output appimage || {
|
||||
echo "⚠️ AppImage 生成失败,但 tar.gz 已成功生成"
|
||||
exit 0
|
||||
}
|
||||
|
||||
# Rename output
|
||||
mv GoNavi*.AppImage "$TARGET.AppImage" 2>/dev/null || {
|
||||
echo "⚠️ AppImage 重命名失败"
|
||||
exit 0
|
||||
}
|
||||
|
||||
if [ -f "$TARGET.AppImage" ]; then
|
||||
mv "$TARGET.AppImage" ../../
|
||||
echo "✅ AppImage 生成成功"
|
||||
fi
|
||||
|
||||
# Upload to Actions Artifacts (Temporary Storage)
|
||||
- name: Upload Artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: build-artifacts-${{ strategy.job-index }} # Unique name per job
|
||||
path: GoNavi-*${{ matrix.asset_ext }}
|
||||
path: |
|
||||
GoNavi-*.dmg
|
||||
GoNavi-*.exe
|
||||
GoNavi-*.tar.gz
|
||||
GoNavi-*.AppImage
|
||||
retention-days: 1
|
||||
|
||||
# Phase 2: Collect all artifacts and Publish Release (Single Job)
|
||||
|
||||
113
build-release.sh
113
build-release.sh
@@ -136,14 +136,121 @@ if command -v x86_64-w64-mingw32-gcc &> /dev/null; then
|
||||
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}.exe" "$DIST_DIR/${APP_NAME}-${VERSION}-windows-amd64.exe"
|
||||
echo " ✅ 已生成 ${APP_NAME}-${VERSION}-windows-amd64.exe"
|
||||
else
|
||||
echo -e "${RED} ❌ Windows 构建失败。${NC}"
|
||||
echo -e "${RED} ❌ Windows amd64 构建失败。${NC}"
|
||||
fi
|
||||
else
|
||||
echo -e "${YELLOW} ⚠️ 未找到 MinGW 工具 (x86_64-w64-mingw32-gcc),跳过 Windows 构建。${NC}"
|
||||
echo -e "${YELLOW} ⚠️ 未找到 MinGW 工具 (x86_64-w64-mingw32-gcc),跳过 Windows amd64 构建。${NC}"
|
||||
fi
|
||||
|
||||
# --- Windows ARM64 构建 ---
|
||||
echo -e "${GREEN}🪟 正在构建 Windows (arm64)...${NC}"
|
||||
if command -v aarch64-w64-mingw32-gcc &> /dev/null; then
|
||||
wails build -platform windows/arm64 -clean
|
||||
if [ $? -eq 0 ]; then
|
||||
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}.exe" "$DIST_DIR/${APP_NAME}-${VERSION}-windows-arm64.exe"
|
||||
echo " ✅ 已生成 ${APP_NAME}-${VERSION}-windows-arm64.exe"
|
||||
else
|
||||
echo -e "${RED} ❌ Windows arm64 构建失败。${NC}"
|
||||
fi
|
||||
else
|
||||
echo -e "${YELLOW} ⚠️ 未找到 MinGW ARM64 工具 (aarch64-w64-mingw32-gcc),跳过 Windows arm64 构建。${NC}"
|
||||
echo " 安装命令: brew install mingw-w64 (需要支持 ARM64 的版本)"
|
||||
fi
|
||||
|
||||
# --- Linux AMD64 构建 ---
|
||||
echo -e "${GREEN}🐧 正在构建 Linux (amd64)...${NC}"
|
||||
# 检测当前系统
|
||||
CURRENT_OS=$(uname -s)
|
||||
CURRENT_ARCH=$(uname -m)
|
||||
|
||||
if [ "$CURRENT_OS" = "Linux" ] && [ "$CURRENT_ARCH" = "x86_64" ]; then
|
||||
# 本机 Linux amd64,直接构建
|
||||
wails build -platform linux/amd64 -clean
|
||||
if [ $? -eq 0 ]; then
|
||||
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64"
|
||||
chmod +x "$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64"
|
||||
# 打包为 tar.gz
|
||||
cd "$DIST_DIR"
|
||||
tar -czvf "${APP_NAME}-${VERSION}-linux-amd64.tar.gz" "${APP_NAME}-${VERSION}-linux-amd64"
|
||||
rm "${APP_NAME}-${VERSION}-linux-amd64"
|
||||
cd ..
|
||||
echo " ✅ 已生成 ${APP_NAME}-${VERSION}-linux-amd64.tar.gz"
|
||||
else
|
||||
echo -e "${RED} ❌ Linux amd64 构建失败。${NC}"
|
||||
fi
|
||||
elif command -v x86_64-linux-gnu-gcc &> /dev/null; then
|
||||
# macOS 或其他系统,尝试交叉编译
|
||||
export CC=x86_64-linux-gnu-gcc
|
||||
export CXX=x86_64-linux-gnu-g++
|
||||
export CGO_ENABLED=1
|
||||
wails build -platform linux/amd64 -clean
|
||||
if [ $? -eq 0 ]; then
|
||||
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64"
|
||||
chmod +x "$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64"
|
||||
cd "$DIST_DIR"
|
||||
tar -czvf "${APP_NAME}-${VERSION}-linux-amd64.tar.gz" "${APP_NAME}-${VERSION}-linux-amd64"
|
||||
rm "${APP_NAME}-${VERSION}-linux-amd64"
|
||||
cd ..
|
||||
echo " ✅ 已生成 ${APP_NAME}-${VERSION}-linux-amd64.tar.gz"
|
||||
else
|
||||
echo -e "${RED} ❌ Linux amd64 交叉编译失败。${NC}"
|
||||
fi
|
||||
unset CC CXX CGO_ENABLED
|
||||
else
|
||||
echo -e "${YELLOW} ⚠️ 非 Linux 系统且未找到交叉编译工具,跳过 Linux amd64 构建。${NC}"
|
||||
echo " 在 Linux 上运行此脚本可直接构建,或安装交叉编译工具链。"
|
||||
fi
|
||||
|
||||
# --- Linux ARM64 构建 ---
|
||||
echo -e "${GREEN}🐧 正在构建 Linux (arm64)...${NC}"
|
||||
if [ "$CURRENT_OS" = "Linux" ] && [ "$CURRENT_ARCH" = "aarch64" ]; then
|
||||
# 本机 Linux arm64,直接构建
|
||||
wails build -platform linux/arm64 -clean
|
||||
if [ $? -eq 0 ]; then
|
||||
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64"
|
||||
chmod +x "$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64"
|
||||
cd "$DIST_DIR"
|
||||
tar -czvf "${APP_NAME}-${VERSION}-linux-arm64.tar.gz" "${APP_NAME}-${VERSION}-linux-arm64"
|
||||
rm "${APP_NAME}-${VERSION}-linux-arm64"
|
||||
cd ..
|
||||
echo " ✅ 已生成 ${APP_NAME}-${VERSION}-linux-arm64.tar.gz"
|
||||
else
|
||||
echo -e "${RED} ❌ Linux arm64 构建失败。${NC}"
|
||||
fi
|
||||
elif command -v aarch64-linux-gnu-gcc &> /dev/null; then
|
||||
# 交叉编译
|
||||
export CC=aarch64-linux-gnu-gcc
|
||||
export CXX=aarch64-linux-gnu-g++
|
||||
export CGO_ENABLED=1
|
||||
wails build -platform linux/arm64 -clean
|
||||
if [ $? -eq 0 ]; then
|
||||
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64"
|
||||
chmod +x "$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64"
|
||||
cd "$DIST_DIR"
|
||||
tar -czvf "${APP_NAME}-${VERSION}-linux-arm64.tar.gz" "${APP_NAME}-${VERSION}-linux-arm64"
|
||||
rm "${APP_NAME}-${VERSION}-linux-arm64"
|
||||
cd ..
|
||||
echo " ✅ 已生成 ${APP_NAME}-${VERSION}-linux-arm64.tar.gz"
|
||||
else
|
||||
echo -e "${RED} ❌ Linux arm64 交叉编译失败。${NC}"
|
||||
fi
|
||||
unset CC CXX CGO_ENABLED
|
||||
else
|
||||
echo -e "${YELLOW} ⚠️ 非 Linux ARM64 系统且未找到交叉编译工具,跳过 Linux arm64 构建。${NC}"
|
||||
echo " 安装命令 (Ubuntu): sudo apt install gcc-aarch64-linux-gnu g++-aarch64-linux-gnu"
|
||||
echo " 安装命令 (macOS): brew install aarch64-linux-gnu-gcc (需要第三方 tap)"
|
||||
fi
|
||||
|
||||
# 清理中间构建目录
|
||||
rm -rf "build/bin"
|
||||
|
||||
echo ""
|
||||
echo -e "${GREEN}🎉 所有任务完成!构建产物在 'dist/' 目录下:${NC}"
|
||||
ls -1 "$DIST_DIR"
|
||||
ls -lh "$DIST_DIR"
|
||||
echo ""
|
||||
echo -e "${GREEN}📋 支持的平台:${NC}"
|
||||
echo " • macOS (Intel/Apple Silicon): .dmg"
|
||||
echo " • Windows (x64/ARM64): .exe"
|
||||
echo " • Linux (x64/ARM64): .tar.gz"
|
||||
echo ""
|
||||
echo -e "${YELLOW}💡 提示:Linux AppImage 包请使用 GitHub Actions CI/CD 构建。${NC}"
|
||||
|
||||
@@ -1 +1 @@
|
||||
d0f9366af59a6367ad3c7e2d4185ead4
|
||||
5b8157374dae5f9340e31b2d0bd2c00e
|
||||
@@ -285,12 +285,12 @@ function App() {
|
||||
title="拖动调整宽度"
|
||||
/>
|
||||
</Sider>
|
||||
<Content style={{ background: darkMode ? '#141414' : '#fff', overflow: 'hidden', display: 'flex', flexDirection: 'column' }}>
|
||||
<div style={{ flex: 1, overflow: 'hidden' }}>
|
||||
<TabManager />
|
||||
</div>
|
||||
{isLogPanelOpen && (
|
||||
<LogPanel
|
||||
<Content style={{ background: darkMode ? '#141414' : '#fff', overflow: 'hidden', display: 'flex', flexDirection: 'column' }}>
|
||||
<div style={{ flex: 1, minHeight: 0, overflow: 'hidden', display: 'flex', flexDirection: 'column' }}>
|
||||
<TabManager />
|
||||
</div>
|
||||
{isLogPanelOpen && (
|
||||
<LogPanel
|
||||
height={logPanelHeight}
|
||||
onClose={() => setIsLogPanelOpen(false)}
|
||||
onResizeStart={handleLogResizeStart}
|
||||
@@ -343,4 +343,4 @@ function App() {
|
||||
);
|
||||
}
|
||||
|
||||
export default App;
|
||||
export default App;
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { Modal, Form, Input, InputNumber, Button, message, Checkbox, Divider, Select, Alert, Card, Row, Col, Typography, Collapse } from 'antd';
|
||||
import { DatabaseOutlined, ConsoleSqlOutlined, FileTextOutlined, CloudServerOutlined, AppstoreAddOutlined } from '@ant-design/icons';
|
||||
import { DatabaseOutlined, ConsoleSqlOutlined, FileTextOutlined, CloudServerOutlined, AppstoreAddOutlined, CloudOutlined } from '@ant-design/icons';
|
||||
import { useStore } from '../store';
|
||||
import { DBConnect, DBGetDatabases, TestConnection } from '../../wailsjs/go/app/App';
|
||||
import { DBConnect, DBGetDatabases, TestConnection, RedisConnect } from '../../wailsjs/go/app/App';
|
||||
import { SavedConnection } from '../types';
|
||||
|
||||
const { Meta } = Card;
|
||||
@@ -16,6 +16,7 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
|
||||
const [step, setStep] = useState(1); // 1: Select Type, 2: Configure
|
||||
const [testResult, setTestResult] = useState<{ type: 'success' | 'error', message: string } | null>(null);
|
||||
const [dbList, setDbList] = useState<string[]>([]);
|
||||
const [redisDbList, setRedisDbList] = useState<number[]>([]); // Redis databases 0-15
|
||||
const addConnection = useStore((state) => state.addConnection);
|
||||
const updateConnection = useStore((state) => state.updateConnection);
|
||||
|
||||
@@ -23,6 +24,7 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
|
||||
if (open) {
|
||||
setTestResult(null); // Reset test result
|
||||
setDbList([]);
|
||||
setRedisDbList([]);
|
||||
if (initialValues) {
|
||||
// Edit mode: Go directly to step 2
|
||||
setStep(2);
|
||||
@@ -35,6 +37,7 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
|
||||
password: initialValues.config.password,
|
||||
database: initialValues.config.database,
|
||||
includeDatabases: initialValues.includeDatabases,
|
||||
includeRedisDatabases: initialValues.includeRedisDatabases,
|
||||
useSSH: initialValues.config.useSSH,
|
||||
sshHost: initialValues.config.ssh?.host,
|
||||
sshPort: initialValues.config.ssh?.port,
|
||||
@@ -47,6 +50,10 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
|
||||
});
|
||||
setUseSSH(initialValues.config.useSSH || false);
|
||||
setDbType(initialValues.config.type);
|
||||
// 如果是 Redis 编辑模式,设置已保存的 Redis 数据库列表
|
||||
if (initialValues.config.type === 'redis') {
|
||||
setRedisDbList(Array.from({ length: 16 }, (_, i) => i));
|
||||
}
|
||||
} else {
|
||||
// Create mode: Start at step 1
|
||||
setStep(1);
|
||||
@@ -61,18 +68,24 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
|
||||
try {
|
||||
const values = await form.validateFields();
|
||||
setLoading(true);
|
||||
|
||||
|
||||
const config = await buildConfig(values);
|
||||
|
||||
const res = await DBConnect(config as any);
|
||||
|
||||
// Use different API for Redis
|
||||
const isRedisType = values.type === 'redis';
|
||||
const res = isRedisType
|
||||
? await RedisConnect(config as any)
|
||||
: await DBConnect(config as any);
|
||||
|
||||
setLoading(false);
|
||||
|
||||
|
||||
if (res.success) {
|
||||
const newConn = {
|
||||
id: initialValues ? initialValues.id : Date.now().toString(),
|
||||
name: values.name || (values.type === 'sqlite' ? 'SQLite DB' : values.host),
|
||||
name: values.name || (values.type === 'sqlite' ? 'SQLite DB' : (values.type === 'redis' ? `Redis ${values.host}` : values.host)),
|
||||
config: config,
|
||||
includeDatabases: values.includeDatabases
|
||||
includeDatabases: values.includeDatabases,
|
||||
includeRedisDatabases: isRedisType ? values.includeRedisDatabases : undefined
|
||||
};
|
||||
|
||||
if (initialValues) {
|
||||
@@ -82,7 +95,7 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
|
||||
addConnection(newConn);
|
||||
message.success('连接已保存!');
|
||||
}
|
||||
|
||||
|
||||
form.resetFields();
|
||||
setUseSSH(false);
|
||||
setDbType('mysql');
|
||||
@@ -102,14 +115,26 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
|
||||
setLoading(true);
|
||||
setTestResult(null);
|
||||
const config = await buildConfig(values);
|
||||
const res = await TestConnection(config as any);
|
||||
|
||||
// Use different API for Redis
|
||||
const isRedisType = values.type === 'redis';
|
||||
const res = isRedisType
|
||||
? await RedisConnect(config as any)
|
||||
: await TestConnection(config as any);
|
||||
|
||||
setLoading(false);
|
||||
if (res.success) {
|
||||
setTestResult({ type: 'success', message: res.message });
|
||||
const dbRes = await DBGetDatabases(config as any);
|
||||
if (dbRes.success) {
|
||||
const dbs = (dbRes.data as any[]).map((row: any) => row.Database || row.database);
|
||||
setDbList(dbs);
|
||||
if (isRedisType) {
|
||||
// Redis: generate database list 0-15
|
||||
setRedisDbList(Array.from({ length: 16 }, (_, i) => i));
|
||||
} else {
|
||||
// Other databases: fetch database list
|
||||
const dbRes = await DBGetDatabases(config as any);
|
||||
if (dbRes.success) {
|
||||
const dbs = (dbRes.data as any[]).map((row: any) => row.Database || row.database);
|
||||
setDbList(dbs);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
setTestResult({ type: 'error', message: "测试失败: " + res.message });
|
||||
@@ -128,7 +153,7 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
|
||||
keyPath: values.sshKeyPath || ""
|
||||
} : { host: "", port: 22, user: "", password: "", keyPath: "" };
|
||||
|
||||
return {
|
||||
return {
|
||||
type: values.type,
|
||||
host: values.host || "",
|
||||
port: Number(values.port || 0),
|
||||
@@ -146,12 +171,13 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
|
||||
const handleTypeSelect = (type: string) => {
|
||||
setDbType(type);
|
||||
form.setFieldsValue({ type: type });
|
||||
|
||||
|
||||
// Auto-fill default port
|
||||
let defaultPort = 3306;
|
||||
switch (type) {
|
||||
case 'mysql': defaultPort = 3306; break;
|
||||
case 'postgres': defaultPort = 5432; break;
|
||||
case 'redis': defaultPort = 6379; break;
|
||||
case 'oracle': defaultPort = 1521; break;
|
||||
case 'dameng': defaultPort = 5236; break;
|
||||
case 'kingbase': defaultPort = 54321; break;
|
||||
@@ -166,10 +192,12 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
|
||||
|
||||
const isSqlite = dbType === 'sqlite';
|
||||
const isCustom = dbType === 'custom';
|
||||
const isRedis = dbType === 'redis';
|
||||
|
||||
const dbTypes = [
|
||||
{ key: 'mysql', name: 'MySQL', icon: <ConsoleSqlOutlined style={{ fontSize: 24, color: '#00758F' }} /> },
|
||||
{ key: 'postgres', name: 'PostgreSQL', icon: <DatabaseOutlined style={{ fontSize: 24, color: '#336791' }} /> },
|
||||
{ key: 'redis', name: 'Redis', icon: <CloudOutlined style={{ fontSize: 24, color: '#DC382D' }} /> },
|
||||
{ key: 'sqlite', name: 'SQLite', icon: <FileTextOutlined style={{ fontSize: 24, color: '#003B57' }} /> },
|
||||
{ key: 'oracle', name: 'Oracle', icon: <DatabaseOutlined style={{ fontSize: 24, color: '#F80000' }} /> },
|
||||
{ key: 'dameng', name: 'Dameng (达梦)', icon: <CloudServerOutlined style={{ fontSize: 24, color: '#1890ff' }} /> },
|
||||
@@ -235,7 +263,22 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
|
||||
)}
|
||||
</div>
|
||||
|
||||
{!isSqlite && (
|
||||
{/* Redis specific: password only, no username */}
|
||||
{isRedis && (
|
||||
<>
|
||||
<Form.Item name="password" label="密码 (可选)">
|
||||
<Input.Password placeholder="Redis 密码(如果设置了 requirepass)" />
|
||||
</Form.Item>
|
||||
<Form.Item name="includeRedisDatabases" label="显示数据库 (留空显示全部)" help="连接测试成功后可选择">
|
||||
<Select mode="multiple" placeholder="选择显示的数据库 (0-15)" allowClear>
|
||||
{redisDbList.map(db => <Select.Option key={db} value={db}>db{db}</Select.Option>)}
|
||||
</Select>
|
||||
</Form.Item>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Non-Redis, non-SQLite: username and password */}
|
||||
{!isSqlite && !isRedis && (
|
||||
<div style={{ display: 'flex', gap: 16 }}>
|
||||
<Form.Item name="user" label="用户名" rules={[{ required: true, message: '请输入用户名' }]} style={{ flex: 1 }}>
|
||||
<Input />
|
||||
@@ -245,8 +288,8 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
|
||||
</Form.Item>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!isSqlite && (
|
||||
|
||||
{!isSqlite && !isRedis && (
|
||||
<Form.Item name="includeDatabases" label="显示数据库 (留空显示全部)" help="连接测试成功后可选择">
|
||||
<Select mode="multiple" placeholder="选择显示的数据库" allowClear>
|
||||
{dbList.map(db => <Select.Option key={db} value={db}>{db}</Select.Option>)}
|
||||
@@ -264,8 +307,8 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
|
||||
{useSSH && (
|
||||
<div style={{ padding: '12px', background: '#f5f5f5', borderRadius: 6, marginTop: 12 }}>
|
||||
<div style={{ display: 'flex', gap: 16 }}>
|
||||
<Form.Item name="sshHost" label="SSH 主机" rules={[{ required: useSSH, message: '请输入SSH主机' }]} style={{ flex: 1 }}>
|
||||
<Input placeholder="ssh.example.com" />
|
||||
<Form.Item name="sshHost" label="SSH 主机 (域名或IP)" rules={[{ required: useSSH, message: '请输入SSH主机' }]} style={{ flex: 1 }}>
|
||||
<Input placeholder="例如: ssh.example.com 或 192.168.1.100" />
|
||||
</Form.Item>
|
||||
<Form.Item name="sshPort" label="端口" rules={[{ required: useSSH, message: '请输入SSH端口' }]} style={{ width: 100 }}>
|
||||
<InputNumber style={{ width: '100%' }} />
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import React, { useState, useEffect, useRef, useContext, useMemo, useCallback } from 'react';
|
||||
import { Table, message, Input, Button, Dropdown, MenuProps, Form, Pagination, Select, Modal } from 'antd';
|
||||
import type { SortOrder } from 'antd/es/table/interface';
|
||||
import { ReloadOutlined, ImportOutlined, ExportOutlined, DownOutlined, PlusOutlined, DeleteOutlined, SaveOutlined, UndoOutlined, FilterOutlined, CloseOutlined, ConsoleSqlOutlined, FileTextOutlined, CopyOutlined, ClearOutlined } from '@ant-design/icons';
|
||||
import { ImportData, ExportTable, ExportData, ApplyChanges } from '../../wailsjs/go/app/App';
|
||||
import { ReloadOutlined, ImportOutlined, ExportOutlined, DownOutlined, PlusOutlined, DeleteOutlined, SaveOutlined, UndoOutlined, FilterOutlined, CloseOutlined, ConsoleSqlOutlined, FileTextOutlined, CopyOutlined, ClearOutlined, EditOutlined } from '@ant-design/icons';
|
||||
import Editor from '@monaco-editor/react';
|
||||
import { ImportData, ExportTable, ExportData, ExportQuery, ApplyChanges } from '../../wailsjs/go/app/App';
|
||||
import { useStore } from '../store';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import 'react-resizable/css/styles.css';
|
||||
import { buildWhereSQL, escapeLiteral, quoteIdentPart, quoteQualifiedIdent } from '../utils/sql';
|
||||
|
||||
// 内部行标识字段:避免与真实业务字段(如 `key` 列)冲突。
|
||||
export const GONAVI_ROW_KEY = '__gonavi_row_key__';
|
||||
@@ -27,16 +29,47 @@ const formatCellValue = (val: any) => {
|
||||
return String(val);
|
||||
};
|
||||
|
||||
const toEditableText = (val: any): string => {
|
||||
if (val === null || val === undefined) return '';
|
||||
if (typeof val === 'string') return val;
|
||||
try {
|
||||
return JSON.stringify(val, null, 2);
|
||||
} catch {
|
||||
return String(val);
|
||||
}
|
||||
};
|
||||
|
||||
const toFormText = (val: any): string => {
|
||||
if (val === null || val === undefined) return '';
|
||||
if (typeof val === 'string') return normalizeDateTimeString(val);
|
||||
return toEditableText(val);
|
||||
};
|
||||
|
||||
const looksLikeJsonText = (text: string): boolean => {
|
||||
const raw = (text || '').trim();
|
||||
if (!raw) return false;
|
||||
const first = raw[0];
|
||||
const last = raw[raw.length - 1];
|
||||
return (first === '{' && last === '}') || (first === '[' && last === ']');
|
||||
};
|
||||
|
||||
// --- Resizable Header (Native Implementation) ---
|
||||
const ResizableTitle = (props: any) => {
|
||||
const { onResizeStart, width, ...restProps } = props;
|
||||
|
||||
if (!width) {
|
||||
return <th {...restProps} />;
|
||||
const nextStyle = { ...(restProps.style || {}) } as React.CSSProperties;
|
||||
if (width) {
|
||||
nextStyle.width = width;
|
||||
}
|
||||
|
||||
// 注意:virtual table 模式下,rc-table 会依赖 header cell 的 width 样式来渲染选择列。
|
||||
// 若这里丢失 width,可能导致左上角“全选”checkbox 不显示。
|
||||
if (!width || typeof onResizeStart !== 'function') {
|
||||
return <th {...restProps} style={nextStyle} />;
|
||||
}
|
||||
|
||||
return (
|
||||
<th {...restProps} style={{ ...restProps.style, position: 'relative' }}>
|
||||
<th {...restProps} style={{ ...nextStyle, position: 'relative' }}>
|
||||
{restProps.children}
|
||||
<span
|
||||
className="react-resizable-handle"
|
||||
@@ -85,6 +118,7 @@ interface EditableCellProps {
|
||||
dataIndex: string;
|
||||
record: Item;
|
||||
handleSave: (record: Item) => void;
|
||||
focusCell?: (record: Item, dataIndex: string, title: React.ReactNode) => void;
|
||||
[key: string]: any;
|
||||
}
|
||||
|
||||
@@ -95,6 +129,7 @@ const EditableCell: React.FC<EditableCellProps> = React.memo(({
|
||||
dataIndex,
|
||||
record,
|
||||
handleSave,
|
||||
focusCell,
|
||||
...restProps
|
||||
}) => {
|
||||
const [editing, setEditing] = useState(false);
|
||||
@@ -139,7 +174,26 @@ const EditableCell: React.FC<EditableCellProps> = React.memo(({
|
||||
);
|
||||
}
|
||||
|
||||
return <td {...restProps} onDoubleClick={editable ? toggleEdit : undefined}>{childNode}</td>;
|
||||
const handleDoubleClick = () => {
|
||||
if (!editable) return;
|
||||
toggleEdit();
|
||||
};
|
||||
|
||||
const handleClick = (e: React.MouseEvent) => {
|
||||
restProps?.onClick?.(e);
|
||||
if (!editable) return;
|
||||
if (typeof focusCell === 'function') focusCell(record, dataIndex, title);
|
||||
};
|
||||
|
||||
return (
|
||||
<td
|
||||
{...restProps}
|
||||
onClick={editable ? handleClick : restProps?.onClick}
|
||||
onDoubleClick={editable ? handleDoubleClick : restProps?.onDoubleClick}
|
||||
>
|
||||
{childNode}
|
||||
</td>
|
||||
);
|
||||
});
|
||||
|
||||
const ContextMenuRow = React.memo(({ children, record, ...props }: any) => {
|
||||
@@ -221,9 +275,23 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
}) => {
|
||||
const { connections } = useStore();
|
||||
const addSqlLog = useStore(state => state.addSqlLog);
|
||||
const darkMode = useStore(state => state.darkMode);
|
||||
const selectionColumnWidth = 46;
|
||||
const [form] = Form.useForm();
|
||||
const [modal, contextHolder] = Modal.useModal();
|
||||
const gridId = useMemo(() => `grid-${uuidv4()}`, []);
|
||||
const [cellEditorOpen, setCellEditorOpen] = useState(false);
|
||||
const [cellEditorValue, setCellEditorValue] = useState('');
|
||||
const [cellEditorIsJson, setCellEditorIsJson] = useState(false);
|
||||
const [cellEditorMeta, setCellEditorMeta] = useState<{ record: Item; dataIndex: string; title: string } | null>(null);
|
||||
const cellEditorApplyRef = useRef<((val: string) => void) | null>(null);
|
||||
const [activeCell, setActiveCell] = useState<{ rowKey: string; dataIndex: string; title: string } | null>(null);
|
||||
const [rowEditorOpen, setRowEditorOpen] = useState(false);
|
||||
const [rowEditorRowKey, setRowEditorRowKey] = useState<string>('');
|
||||
const rowEditorBaseRef = useRef<Record<string, string>>({});
|
||||
const rowEditorDisplayRef = useRef<Record<string, string>>({});
|
||||
const rowEditorNullColsRef = useRef<Set<string>>(new Set());
|
||||
const [rowEditorForm] = Form.useForm();
|
||||
|
||||
// Helper to export specific data
|
||||
const exportData = async (rows: any[], format: string) => {
|
||||
@@ -237,32 +305,66 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
|
||||
const [sortInfo, setSortInfo] = useState<{ columnKey: string, order: string } | null>(null);
|
||||
const [columnWidths, setColumnWidths] = useState<Record<string, number>>({});
|
||||
|
||||
const closeCellEditor = useCallback(() => {
|
||||
setCellEditorOpen(false);
|
||||
setCellEditorMeta(null);
|
||||
setCellEditorValue('');
|
||||
setCellEditorIsJson(false);
|
||||
cellEditorApplyRef.current = null;
|
||||
}, []);
|
||||
|
||||
const openCellEditor = useCallback((record: Item, dataIndex: string, title: React.ReactNode, onApplyValue?: (val: string) => void) => {
|
||||
if (!record || !dataIndex) return;
|
||||
const raw = record?.[dataIndex];
|
||||
const text = toEditableText(raw);
|
||||
const isJson = looksLikeJsonText(text);
|
||||
const titleText = typeof title === 'string' ? title : (typeof title === 'number' ? String(title) : String(dataIndex));
|
||||
|
||||
setCellEditorMeta({ record, dataIndex, title: titleText });
|
||||
setCellEditorValue(text);
|
||||
setCellEditorIsJson(isJson);
|
||||
setCellEditorOpen(true);
|
||||
cellEditorApplyRef.current = typeof onApplyValue === 'function' ? onApplyValue : null;
|
||||
}, []);
|
||||
|
||||
// Dynamic Height
|
||||
const [tableHeight, setTableHeight] = useState(500);
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (!containerRef.current) return;
|
||||
|
||||
let rafId: number;
|
||||
const el = containerRef.current;
|
||||
if (!el) return;
|
||||
|
||||
let rafId: number | null = null;
|
||||
|
||||
const resizeObserver = new ResizeObserver(entries => {
|
||||
if (rafId !== null) cancelAnimationFrame(rafId);
|
||||
rafId = requestAnimationFrame(() => {
|
||||
for (let entry of entries) {
|
||||
// Use boundingClientRect for more accurate render size (including padding if any)
|
||||
const height = entry.contentRect.height;
|
||||
if (height < 50) return;
|
||||
// Subtract header (~42px) and a buffer
|
||||
const h = Math.max(100, height - 42);
|
||||
setTableHeight(h);
|
||||
}
|
||||
const target = (entries[0]?.target as HTMLElement | undefined) || containerRef.current;
|
||||
if (!target) return;
|
||||
|
||||
const height = target.getBoundingClientRect().height;
|
||||
if (!Number.isFinite(height) || height < 50) return;
|
||||
|
||||
const headerEl =
|
||||
(target.querySelector('.ant-table-header') as HTMLElement | null) ||
|
||||
(target.querySelector('.ant-table-thead') as HTMLElement | null);
|
||||
const rawHeaderHeight = headerEl ? headerEl.getBoundingClientRect().height : NaN;
|
||||
const headerHeight =
|
||||
Number.isFinite(rawHeaderHeight) && rawHeaderHeight >= 24 && rawHeaderHeight <= 120 ? rawHeaderHeight : 42;
|
||||
|
||||
// 留一点余量,避免底部(边框/滚动条)遮挡最后一行
|
||||
const extraBottom = 16;
|
||||
const nextHeight = Math.max(100, Math.floor(height - headerHeight - extraBottom));
|
||||
setTableHeight(nextHeight);
|
||||
});
|
||||
});
|
||||
|
||||
resizeObserver.observe(containerRef.current);
|
||||
|
||||
resizeObserver.observe(el);
|
||||
return () => {
|
||||
resizeObserver.disconnect();
|
||||
cancelAnimationFrame(rafId);
|
||||
if (rafId !== null) cancelAnimationFrame(rafId);
|
||||
};
|
||||
}, []);
|
||||
|
||||
@@ -272,7 +374,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
const [deletedRowKeys, setDeletedRowKeys] = useState<Set<string>>(new Set());
|
||||
|
||||
// Filter State
|
||||
const [filterConditions, setFilterConditions] = useState<{ id: number, column: string, op: string, value: string }[]>([]);
|
||||
const [filterConditions, setFilterConditions] = useState<{ id: number, column: string, op: string, value: string, value2?: string }[]>([]);
|
||||
const [nextFilterId, setNextFilterId] = useState(1);
|
||||
|
||||
const selectedRowKeysRef = useRef(selectedRowKeys);
|
||||
@@ -286,6 +388,14 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
setModifiedRows({});
|
||||
setDeletedRowKeys(new Set());
|
||||
setSelectedRowKeys([]);
|
||||
setActiveCell(null);
|
||||
setRowEditorOpen(false);
|
||||
setRowEditorRowKey('');
|
||||
rowEditorBaseRef.current = {};
|
||||
rowEditorDisplayRef.current = {};
|
||||
rowEditorNullColsRef.current = new Set();
|
||||
rowEditorForm.resetFields();
|
||||
closeCellEditor();
|
||||
form.resetFields();
|
||||
}, [tableName, dbName, connectionId]); // Reset on context change
|
||||
|
||||
@@ -440,6 +550,29 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
}
|
||||
}, [addedRows]);
|
||||
|
||||
const handleCellEditorSave = useCallback(() => {
|
||||
if (!cellEditorMeta) return;
|
||||
const apply = cellEditorApplyRef.current;
|
||||
if (apply) {
|
||||
apply(cellEditorValue);
|
||||
closeCellEditor();
|
||||
return;
|
||||
}
|
||||
const nextRow: any = { ...cellEditorMeta.record, [cellEditorMeta.dataIndex]: cellEditorValue };
|
||||
handleCellSave(nextRow);
|
||||
closeCellEditor();
|
||||
}, [cellEditorMeta, cellEditorValue, handleCellSave, closeCellEditor]);
|
||||
|
||||
const handleFormatJsonInEditor = useCallback(() => {
|
||||
if (!cellEditorIsJson) return;
|
||||
try {
|
||||
const obj = JSON.parse(cellEditorValue);
|
||||
setCellEditorValue(JSON.stringify(obj, null, 2));
|
||||
} catch (e: any) {
|
||||
message.error("JSON 格式无效:" + (e?.message || String(e)));
|
||||
}
|
||||
}, [cellEditorIsJson, cellEditorValue]);
|
||||
|
||||
// Merge Data for Display
|
||||
// 'displayData' already merges addedRows.
|
||||
// We need to merge modifiedRows into it for rendering.
|
||||
@@ -453,6 +586,110 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
});
|
||||
}, [displayData, modifiedRows]);
|
||||
|
||||
const focusCell = useCallback((record: Item, dataIndex: string, title: React.ReactNode) => {
|
||||
const k = record?.[GONAVI_ROW_KEY];
|
||||
if (k === undefined) return;
|
||||
const titleText = typeof title === 'string' ? title : (typeof title === 'number' ? String(title) : String(dataIndex));
|
||||
setActiveCell({ rowKey: rowKeyStr(k), dataIndex, title: titleText });
|
||||
}, [rowKeyStr]);
|
||||
|
||||
const closeRowEditor = useCallback(() => {
|
||||
setRowEditorOpen(false);
|
||||
setRowEditorRowKey('');
|
||||
rowEditorBaseRef.current = {};
|
||||
rowEditorDisplayRef.current = {};
|
||||
rowEditorNullColsRef.current = new Set();
|
||||
rowEditorForm.resetFields();
|
||||
}, [rowEditorForm]);
|
||||
|
||||
const openRowEditor = useCallback(() => {
|
||||
if (readOnly || !tableName) return;
|
||||
if (selectedRowKeys.length > 1) {
|
||||
message.info('一次只能编辑一行,请仅选择一行');
|
||||
return;
|
||||
}
|
||||
|
||||
const keyStr =
|
||||
selectedRowKeys.length === 1 ? rowKeyStr(selectedRowKeys[0]) : activeCell?.rowKey;
|
||||
if (!keyStr) {
|
||||
message.info('请先选择一行(勾选一行或点击任意单元格)');
|
||||
return;
|
||||
}
|
||||
|
||||
const displayRow = mergedDisplayData.find(r => rowKeyStr(r?.[GONAVI_ROW_KEY]) === keyStr);
|
||||
if (!displayRow) {
|
||||
message.error('未找到目标行,请刷新后重试');
|
||||
return;
|
||||
}
|
||||
|
||||
const baseRow =
|
||||
data.find(r => rowKeyStr(r?.[GONAVI_ROW_KEY]) === keyStr) ||
|
||||
addedRows.find(r => rowKeyStr(r?.[GONAVI_ROW_KEY]) === keyStr) ||
|
||||
displayRow;
|
||||
|
||||
const baseMap: Record<string, string> = {};
|
||||
const displayMap: Record<string, string> = {};
|
||||
const nullCols = new Set<string>();
|
||||
|
||||
columnNames.forEach((col) => {
|
||||
const baseVal = (baseRow as any)?.[col];
|
||||
const displayVal = (displayRow as any)?.[col];
|
||||
baseMap[col] = toFormText(baseVal);
|
||||
displayMap[col] = toFormText(displayVal);
|
||||
if (baseVal === null || baseVal === undefined) nullCols.add(col);
|
||||
});
|
||||
|
||||
rowEditorBaseRef.current = baseMap;
|
||||
rowEditorDisplayRef.current = displayMap;
|
||||
rowEditorNullColsRef.current = nullCols;
|
||||
|
||||
rowEditorForm.setFieldsValue(displayMap);
|
||||
setRowEditorRowKey(keyStr);
|
||||
setRowEditorOpen(true);
|
||||
}, [readOnly, tableName, selectedRowKeys, activeCell, mergedDisplayData, data, addedRows, columnNames, rowEditorForm, rowKeyStr]);
|
||||
|
||||
const openRowEditorFieldEditor = useCallback((dataIndex: string) => {
|
||||
if (!dataIndex) return;
|
||||
const val = rowEditorForm.getFieldValue(dataIndex);
|
||||
openCellEditor(
|
||||
{ [dataIndex]: val ?? '' },
|
||||
dataIndex,
|
||||
dataIndex,
|
||||
(nextVal) => rowEditorForm.setFieldsValue({ [dataIndex]: nextVal }),
|
||||
);
|
||||
}, [rowEditorForm, openCellEditor]);
|
||||
|
||||
const applyRowEditor = useCallback(() => {
|
||||
const keyStr = rowEditorRowKey;
|
||||
if (!keyStr) return;
|
||||
const values = rowEditorForm.getFieldsValue(true) || {};
|
||||
|
||||
const isAdded = addedRows.some(r => rowKeyStr(r?.[GONAVI_ROW_KEY]) === keyStr);
|
||||
if (isAdded) {
|
||||
setAddedRows(prev => prev.map(r => rowKeyStr(r?.[GONAVI_ROW_KEY]) === keyStr ? { ...r, ...values } : r));
|
||||
closeRowEditor();
|
||||
return;
|
||||
}
|
||||
|
||||
const baseMap = rowEditorBaseRef.current || {};
|
||||
const patch: Record<string, any> = {};
|
||||
columnNames.forEach((col) => {
|
||||
const nextVal = values[col];
|
||||
const nextStr = toFormText(nextVal);
|
||||
const baseStr = baseMap[col] ?? '';
|
||||
if (nextStr !== baseStr) patch[col] = nextStr;
|
||||
});
|
||||
|
||||
setModifiedRows(prev => {
|
||||
const next = { ...prev };
|
||||
if (Object.keys(patch).length === 0) delete next[keyStr];
|
||||
else next[keyStr] = patch;
|
||||
return next;
|
||||
});
|
||||
|
||||
closeRowEditor();
|
||||
}, [rowEditorRowKey, rowEditorForm, addedRows, columnNames, rowKeyStr, closeRowEditor]);
|
||||
|
||||
const columns = useMemo(() => {
|
||||
return columnNames.map(key => ({
|
||||
title: key,
|
||||
@@ -481,9 +718,13 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
dataIndex: col.dataIndex,
|
||||
title: col.title,
|
||||
handleSave: handleCellSave,
|
||||
focusCell,
|
||||
className: (activeCell && rowKeyStr(record?.[GONAVI_ROW_KEY]) === activeCell.rowKey && col.dataIndex === activeCell.dataIndex)
|
||||
? 'gonavi-active-cell'
|
||||
: undefined,
|
||||
}),
|
||||
};
|
||||
}), [columns, handleCellSave]);
|
||||
}), [columns, handleCellSave, openCellEditor, focusCell, activeCell, rowKeyStr]);
|
||||
|
||||
const handleAddRow = () => {
|
||||
const newKey = `new-${Date.now()}`;
|
||||
@@ -633,11 +874,98 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
copyToClipboard(lines.join('\n'));
|
||||
}, [getTargets, copyToClipboard]);
|
||||
|
||||
const buildConnConfig = useCallback(() => {
|
||||
if (!connectionId) return null;
|
||||
const conn = connections.find(c => c.id === connectionId);
|
||||
if (!conn) return null;
|
||||
return {
|
||||
...conn.config,
|
||||
port: Number(conn.config.port),
|
||||
password: conn.config.password || "",
|
||||
database: conn.config.database || "",
|
||||
useSSH: conn.config.useSSH || false,
|
||||
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
};
|
||||
}, [connections, connectionId]);
|
||||
|
||||
const exportByQuery = useCallback(async (sql: string, format: string, defaultName: string) => {
|
||||
const config = buildConnConfig();
|
||||
if (!config) return;
|
||||
const hide = message.loading(`正在导出...`, 0);
|
||||
const res = await ExportQuery(config as any, dbName || '', sql, defaultName || 'export', format);
|
||||
hide();
|
||||
if (res.success) {
|
||||
message.success("导出成功");
|
||||
} else if (res.message !== "Cancelled") {
|
||||
message.error("导出失败: " + res.message);
|
||||
}
|
||||
}, [buildConnConfig, dbName]);
|
||||
|
||||
const buildPkWhereSql = useCallback((rows: any[], dbType: string) => {
|
||||
if (!tableName || pkColumns.length === 0) return '';
|
||||
const targets = (rows || []).filter(Boolean);
|
||||
if (targets.length === 0) return '';
|
||||
|
||||
const clauses: string[] = [];
|
||||
for (const r of targets) {
|
||||
const andParts: string[] = [];
|
||||
for (const pk of pkColumns) {
|
||||
const col = quoteIdentPart(dbType, pk);
|
||||
const v = r?.[pk];
|
||||
if (v === null || v === undefined) return '';
|
||||
andParts.push(`${col} = '${escapeLiteral(String(v))}'`);
|
||||
}
|
||||
if (andParts.length === pkColumns.length) {
|
||||
clauses.push(`(${andParts.join(' AND ')})`);
|
||||
}
|
||||
}
|
||||
if (clauses.length === 0) return '';
|
||||
return clauses.join(' OR ');
|
||||
}, [pkColumns, tableName]);
|
||||
|
||||
const buildCurrentPageSql = useCallback((dbType: string) => {
|
||||
if (!tableName || !pagination) return '';
|
||||
const whereSQL = buildWhereSQL(dbType, filterConditions);
|
||||
let sql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
|
||||
if (sortInfo && sortInfo.order) {
|
||||
sql += ` ORDER BY ${quoteIdentPart(dbType, sortInfo.columnKey)} ${sortInfo.order === 'ascend' ? 'ASC' : 'DESC'}`;
|
||||
}
|
||||
const offset = (pagination.current - 1) * pagination.pageSize;
|
||||
sql += ` LIMIT ${pagination.pageSize} OFFSET ${offset}`;
|
||||
return sql;
|
||||
}, [tableName, pagination, filterConditions, sortInfo]);
|
||||
|
||||
// Context Menu Export
|
||||
const handleExportSelected = useCallback(async (format: string, record: any) => {
|
||||
const records = getTargets(record);
|
||||
await exportData(records, format);
|
||||
}, [getTargets]);
|
||||
if (!connectionId || !tableName) {
|
||||
await exportData(records, format);
|
||||
return;
|
||||
}
|
||||
|
||||
// 有未提交修改时,优先按界面数据导出,避免与数据库不一致。
|
||||
if (hasChanges) {
|
||||
message.warning("当前存在未提交修改,导出将按界面数据生成;如需完整长字段建议先提交后再导出。");
|
||||
await exportData(records, format);
|
||||
return;
|
||||
}
|
||||
|
||||
const config = buildConnConfig();
|
||||
if (!config) {
|
||||
await exportData(records, format);
|
||||
return;
|
||||
}
|
||||
|
||||
const dbType = config.type || '';
|
||||
const pkWhere = buildPkWhereSql(records, dbType);
|
||||
if (!pkWhere) {
|
||||
await exportData(records, format);
|
||||
return;
|
||||
}
|
||||
|
||||
const sql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} WHERE ${pkWhere}`;
|
||||
await exportByQuery(sql, format, tableName || 'export');
|
||||
}, [getTargets, connectionId, tableName, hasChanges, exportData, buildConnConfig, buildPkWhereSql, exportByQuery]);
|
||||
|
||||
// Export
|
||||
const handleExport = async (format: string) => {
|
||||
@@ -646,7 +974,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
// 1. Export Selected
|
||||
if (selectedRowKeys.length > 0) {
|
||||
const selectedRows = displayData.filter(d => selectedRowKeys.includes(d?.[GONAVI_ROW_KEY]));
|
||||
await exportData(selectedRows, format);
|
||||
await handleExportSelected(format, selectedRows[0]);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -655,9 +983,8 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
let instance: any;
|
||||
const handleAll = async () => {
|
||||
instance.destroy();
|
||||
const conn = connections.find(c => c.id === connectionId);
|
||||
if (!conn) return;
|
||||
const config = { ...conn.config, port: Number(conn.config.port), password: conn.config.password || "", database: conn.config.database || "", useSSH: conn.config.useSSH || false, ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" } };
|
||||
const config = buildConnConfig();
|
||||
if (!config) return;
|
||||
const hide = message.loading(`正在导出全部数据...`, 0);
|
||||
const res = await ExportTable(config as any, dbName || '', tableName, format);
|
||||
hide();
|
||||
@@ -665,7 +992,25 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
};
|
||||
const handlePage = async () => {
|
||||
instance.destroy();
|
||||
await exportData(displayData, format);
|
||||
if (hasChanges) {
|
||||
message.warning("当前存在未提交修改,导出将按界面数据生成;如需完整长字段建议先提交后再导出。");
|
||||
await exportData(displayData, format);
|
||||
return;
|
||||
}
|
||||
|
||||
const config = buildConnConfig();
|
||||
if (!config) {
|
||||
await exportData(displayData, format);
|
||||
return;
|
||||
}
|
||||
|
||||
const sql = buildCurrentPageSql(config.type || '');
|
||||
if (!sql) {
|
||||
await exportData(displayData, format);
|
||||
return;
|
||||
}
|
||||
|
||||
await exportByQuery(sql, format, tableName || 'export');
|
||||
};
|
||||
|
||||
instance = modal.info({
|
||||
@@ -688,21 +1033,64 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
|
||||
const handleImport = async () => {
|
||||
if (!connectionId || !tableName) return;
|
||||
const conn = connections.find(c => c.id === connectionId);
|
||||
if (!conn) return;
|
||||
const config = { ...conn.config, port: Number(conn.config.port), password: conn.config.password || "", database: conn.config.database || "", useSSH: conn.config.useSSH || false, ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" } };
|
||||
const config = buildConnConfig();
|
||||
if (!config) return;
|
||||
|
||||
const res = await ImportData(config as any, dbName || '', tableName);
|
||||
if (res.success) { message.success(res.message); if (onReload) onReload(); } else if (res.message !== "Cancelled") { message.error("Import Failed: " + res.message); }
|
||||
};
|
||||
|
||||
// Filters
|
||||
const filterOpOptions = useMemo(() => ([
|
||||
{ value: '=', label: '=' },
|
||||
{ value: '!=', label: '!=' },
|
||||
{ value: '<', label: '<' },
|
||||
{ value: '<=', label: '<=' },
|
||||
{ value: '>', label: '>' },
|
||||
{ value: '>=', label: '>=' },
|
||||
{ value: 'CONTAINS', label: '包含' },
|
||||
{ value: 'NOT_CONTAINS', label: '不包含' },
|
||||
{ value: 'STARTS_WITH', label: '开始以' },
|
||||
{ value: 'NOT_STARTS_WITH', label: '不是开始于' },
|
||||
{ value: 'ENDS_WITH', label: '结束以' },
|
||||
{ value: 'NOT_ENDS_WITH', label: '不是结束于' },
|
||||
{ value: 'IS_NULL', label: '是 null' },
|
||||
{ value: 'IS_NOT_NULL', label: '不是 null' },
|
||||
{ value: 'IS_EMPTY', label: '是空的' },
|
||||
{ value: 'IS_NOT_EMPTY', label: '不是空的' },
|
||||
{ value: 'BETWEEN', label: '介于' },
|
||||
{ value: 'NOT_BETWEEN', label: '不介于' },
|
||||
{ value: 'IN', label: '在列表' },
|
||||
{ value: 'NOT_IN', label: '不在列表' },
|
||||
{ value: 'CUSTOM', label: '[自定义]' },
|
||||
]), []);
|
||||
|
||||
const isNoValueOp = useCallback((op: string) => (
|
||||
op === 'IS_NULL' || op === 'IS_NOT_NULL' || op === 'IS_EMPTY' || op === 'IS_NOT_EMPTY'
|
||||
), []);
|
||||
const isBetweenOp = useCallback((op: string) => op === 'BETWEEN' || op === 'NOT_BETWEEN', []);
|
||||
const isListOp = useCallback((op: string) => op === 'IN' || op === 'NOT_IN', []);
|
||||
|
||||
const addFilter = () => {
|
||||
setFilterConditions([...filterConditions, { id: nextFilterId, column: columnNames[0] || '', op: '=', value: '' }]);
|
||||
setFilterConditions([...filterConditions, { id: nextFilterId, column: columnNames[0] || '', op: '=', value: '', value2: '' }]);
|
||||
setNextFilterId(nextFilterId + 1);
|
||||
};
|
||||
const updateFilter = (id: number, field: string, val: string) => {
|
||||
setFilterConditions(prev => prev.map(c => c.id === id ? { ...c, [field]: val } : c));
|
||||
setFilterConditions(prev => prev.map(c => {
|
||||
if (c.id !== id) return c;
|
||||
const next: any = { ...c, [field]: val };
|
||||
if (field === 'op') {
|
||||
if (isNoValueOp(val)) {
|
||||
next.value = '';
|
||||
next.value2 = '';
|
||||
} else if (isBetweenOp(val)) {
|
||||
if (typeof next.value2 !== 'string') next.value2 = '';
|
||||
} else {
|
||||
next.value2 = '';
|
||||
}
|
||||
}
|
||||
return next;
|
||||
}));
|
||||
};
|
||||
const removeFilter = (id: number) => {
|
||||
setFilterConditions(prev => prev.filter(c => c.id !== id));
|
||||
@@ -723,33 +1111,41 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
header: { cell: ResizableTitle }
|
||||
}), []);
|
||||
|
||||
const totalWidth = columns.reduce((sum, col) => sum + (col.width as number || 200), 0);
|
||||
const totalWidth = columns.reduce((sum, col) => sum + (Number(col.width) || 200), 0) + selectionColumnWidth;
|
||||
const enableVirtual = mergedDisplayData.length >= 200;
|
||||
|
||||
return (
|
||||
<div className={gridId} style={{ height: '100%', overflow: 'hidden', padding: 0, display: 'flex', flexDirection: 'column', minHeight: 0 }}>
|
||||
{/* Toolbar */}
|
||||
<div style={{ padding: '8px', borderBottom: '1px solid #eee', display: 'flex', gap: 8, alignItems: 'center' }}>
|
||||
{onReload && <Button icon={<ReloadOutlined />} onClick={() => {
|
||||
setAddedRows([]);
|
||||
setModifiedRows({});
|
||||
setDeletedRowKeys(new Set());
|
||||
setSelectedRowKeys([]);
|
||||
onReload();
|
||||
}}>刷新</Button>}
|
||||
{tableName && <Button icon={<ImportOutlined />} onClick={handleImport}>导入</Button>}
|
||||
{tableName && <Dropdown menu={{ items: exportMenu }}><Button icon={<ExportOutlined />}>导出 <DownOutlined /></Button></Dropdown>}
|
||||
|
||||
{!readOnly && tableName && (
|
||||
<>
|
||||
<div style={{ width: 1, background: '#eee', height: 20, margin: '0 8px' }} />
|
||||
<Button icon={<PlusOutlined />} onClick={handleAddRow}>添加行</Button>
|
||||
<Button icon={<DeleteOutlined />} danger disabled={selectedRowKeys.length === 0} onClick={handleDeleteSelected}>删除选中</Button>
|
||||
{selectedRowKeys.length > 0 && <span style={{ fontSize: '12px', color: '#888' }}>已选 {selectedRowKeys.length}</span>}
|
||||
<div style={{ width: 1, background: '#eee', height: 20, margin: '0 8px' }} />
|
||||
<Button icon={<SaveOutlined />} type="primary" disabled={!hasChanges} onClick={handleCommit}>提交事务 ({addedRows.length + Object.keys(modifiedRows).length + deletedRowKeys.size})</Button>
|
||||
{hasChanges && (<Button icon={<UndoOutlined />} onClick={() => {
|
||||
setAddedRows([]);
|
||||
<div className={gridId} style={{ flex: '1 1 auto', height: '100%', overflow: 'hidden', padding: 0, display: 'flex', flexDirection: 'column', minHeight: 0 }}>
|
||||
{/* Toolbar */}
|
||||
<div style={{ padding: '8px', borderBottom: '1px solid #eee', display: 'flex', gap: 8, alignItems: 'center' }}>
|
||||
{onReload && <Button icon={<ReloadOutlined />} onClick={() => {
|
||||
setAddedRows([]);
|
||||
setModifiedRows({});
|
||||
setDeletedRowKeys(new Set());
|
||||
setSelectedRowKeys([]);
|
||||
setActiveCell(null);
|
||||
onReload();
|
||||
}}>刷新</Button>}
|
||||
{tableName && <Button icon={<ImportOutlined />} onClick={handleImport}>导入</Button>}
|
||||
{tableName && <Dropdown menu={{ items: exportMenu }}><Button icon={<ExportOutlined />}>导出 <DownOutlined /></Button></Dropdown>}
|
||||
|
||||
{!readOnly && tableName && (
|
||||
<>
|
||||
<div style={{ width: 1, background: '#eee', height: 20, margin: '0 8px' }} />
|
||||
<Button icon={<PlusOutlined />} onClick={handleAddRow}>添加行</Button>
|
||||
<Button
|
||||
icon={<EditOutlined />}
|
||||
disabled={selectedRowKeys.length > 1 || (selectedRowKeys.length !== 1 && !activeCell)}
|
||||
onClick={openRowEditor}
|
||||
>
|
||||
编辑行
|
||||
</Button>
|
||||
<Button icon={<DeleteOutlined />} danger disabled={selectedRowKeys.length === 0} onClick={handleDeleteSelected}>删除选中</Button>
|
||||
{selectedRowKeys.length > 0 && <span style={{ fontSize: '12px', color: '#888' }}>已选 {selectedRowKeys.length}</span>}
|
||||
<div style={{ width: 1, background: '#eee', height: 20, margin: '0 8px' }} />
|
||||
<Button icon={<SaveOutlined />} type="primary" disabled={!hasChanges} onClick={handleCommit}>提交事务 ({addedRows.length + Object.keys(modifiedRows).length + deletedRowKeys.size})</Button>
|
||||
{hasChanges && (<Button icon={<UndoOutlined />} onClick={() => {
|
||||
setAddedRows([]);
|
||||
setModifiedRows({});
|
||||
setDeletedRowKeys(new Set());
|
||||
}}>回滚</Button>)}
|
||||
@@ -771,10 +1167,62 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
{showFilter && (
|
||||
<div style={{ padding: '8px', background: '#f5f5f5', borderBottom: '1px solid #eee' }}>
|
||||
{filterConditions.map(cond => (
|
||||
<div key={cond.id} style={{ display: 'flex', gap: 8, marginBottom: 8 }}>
|
||||
<Select style={{ width: 150 }} value={cond.column} onChange={v => updateFilter(cond.id, 'column', v)} options={columnNames.map(c => ({ value: c, label: c }))} />
|
||||
<Select style={{ width: 100 }} value={cond.op} onChange={v => updateFilter(cond.id, 'op', v)} options={[{ value: '=', label: '=' }, { value: 'LIKE', label: '包含' }]} />
|
||||
<Input style={{ width: 200 }} value={cond.value} onChange={e => updateFilter(cond.id, 'value', e.target.value)} />
|
||||
<div key={cond.id} style={{ display: 'flex', gap: 8, marginBottom: 8, alignItems: 'flex-start' }}>
|
||||
<Select
|
||||
style={{ width: 180 }}
|
||||
value={cond.column}
|
||||
onChange={v => updateFilter(cond.id, 'column', v)}
|
||||
options={columnNames.map(c => ({ value: c, label: c }))}
|
||||
disabled={cond.op === 'CUSTOM'}
|
||||
/>
|
||||
<Select
|
||||
style={{ width: 140 }}
|
||||
value={cond.op}
|
||||
onChange={v => updateFilter(cond.id, 'op', v)}
|
||||
options={filterOpOptions as any}
|
||||
/>
|
||||
|
||||
{cond.op === 'CUSTOM' ? (
|
||||
<Input.TextArea
|
||||
style={{ flex: 1 }}
|
||||
autoSize={{ minRows: 1, maxRows: 4 }}
|
||||
value={cond.value}
|
||||
onChange={e => updateFilter(cond.id, 'value', e.target.value)}
|
||||
placeholder="输入自定义 WHERE 表达式(不需要再写 WHERE),例如:status IN ('A','B')"
|
||||
/>
|
||||
) : isListOp(cond.op) ? (
|
||||
<Input.TextArea
|
||||
style={{ flex: 1 }}
|
||||
autoSize={{ minRows: 1, maxRows: 4 }}
|
||||
value={cond.value}
|
||||
onChange={e => updateFilter(cond.id, 'value', e.target.value)}
|
||||
placeholder="多个值用逗号或换行分隔"
|
||||
/>
|
||||
) : isBetweenOp(cond.op) ? (
|
||||
<>
|
||||
<Input
|
||||
style={{ width: 220 }}
|
||||
value={cond.value}
|
||||
onChange={e => updateFilter(cond.id, 'value', e.target.value)}
|
||||
placeholder="开始值"
|
||||
/>
|
||||
<Input
|
||||
style={{ width: 220 }}
|
||||
value={cond.value2 || ''}
|
||||
onChange={e => updateFilter(cond.id, 'value2', e.target.value)}
|
||||
placeholder="结束值"
|
||||
/>
|
||||
</>
|
||||
) : isNoValueOp(cond.op) ? (
|
||||
<Input style={{ width: 220 }} value="" disabled placeholder="无需输入值" />
|
||||
) : (
|
||||
<Input
|
||||
style={{ width: 280 }}
|
||||
value={cond.value}
|
||||
onChange={e => updateFilter(cond.id, 'value', e.target.value)}
|
||||
/>
|
||||
)}
|
||||
|
||||
<Button icon={<CloseOutlined />} onClick={() => removeFilter(cond.id)} type="text" danger />
|
||||
</div>
|
||||
))}
|
||||
@@ -789,8 +1237,90 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div ref={containerRef} style={{ flex: 1, overflow: 'hidden', position: 'relative', minHeight: 0 }}>
|
||||
{contextHolder}
|
||||
<div ref={containerRef} style={{ flex: 1, overflow: 'hidden', position: 'relative', minHeight: 0 }}>
|
||||
{contextHolder}
|
||||
<Modal
|
||||
title="编辑行"
|
||||
open={rowEditorOpen}
|
||||
onCancel={closeRowEditor}
|
||||
width={980}
|
||||
destroyOnClose
|
||||
maskClosable={false}
|
||||
footer={[
|
||||
<Button key="cancel" onClick={closeRowEditor}>取消</Button>,
|
||||
<Button key="ok" type="primary" onClick={applyRowEditor}>应用</Button>,
|
||||
]}
|
||||
>
|
||||
<div style={{ marginBottom: 8, color: '#888', fontSize: 12, display: 'flex', justifyContent: 'space-between', gap: 8 }}>
|
||||
<span>{tableName ? `${tableName}` : ''}</span>
|
||||
<span>{rowEditorRowKey ? `rowKey: ${rowEditorRowKey}` : ''}</span>
|
||||
</div>
|
||||
<Form form={rowEditorForm} layout="vertical">
|
||||
<div style={{ maxHeight: '62vh', overflow: 'auto', paddingRight: 8 }}>
|
||||
{columnNames.map((col) => {
|
||||
const sample = rowEditorDisplayRef.current?.[col] ?? '';
|
||||
const placeholder = rowEditorNullColsRef.current?.has(col) ? '(NULL)' : undefined;
|
||||
const isJson = looksLikeJsonText(sample);
|
||||
const useArea = isJson || sample.includes('\n') || sample.length >= 160;
|
||||
|
||||
return (
|
||||
<Form.Item key={col} label={col} style={{ marginBottom: 12 }}>
|
||||
<div style={{ display: 'flex', gap: 8, alignItems: 'flex-start' }}>
|
||||
<Form.Item name={col} noStyle>
|
||||
{useArea ? (
|
||||
<Input.TextArea
|
||||
style={{ flex: 1 }}
|
||||
autoSize={{ minRows: isJson ? 4 : 1, maxRows: 10 }}
|
||||
placeholder={placeholder}
|
||||
/>
|
||||
) : (
|
||||
<Input style={{ flex: 1 }} placeholder={placeholder} />
|
||||
)}
|
||||
</Form.Item>
|
||||
<Button size="small" onClick={() => openRowEditorFieldEditor(col)} title="弹窗编辑">...</Button>
|
||||
</div>
|
||||
</Form.Item>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</Form>
|
||||
</Modal>
|
||||
<Modal
|
||||
title={cellEditorMeta ? `编辑单元格:${cellEditorMeta.title}` : '编辑单元格'}
|
||||
open={cellEditorOpen}
|
||||
onCancel={closeCellEditor}
|
||||
width={960}
|
||||
destroyOnClose
|
||||
maskClosable={false}
|
||||
footer={[
|
||||
<Button key="format" onClick={handleFormatJsonInEditor} disabled={!cellEditorIsJson}>
|
||||
格式化 JSON
|
||||
</Button>,
|
||||
<Button key="cancel" onClick={closeCellEditor}>取消</Button>,
|
||||
<Button key="ok" type="primary" onClick={handleCellEditorSave}>保存</Button>,
|
||||
]}
|
||||
>
|
||||
<div style={{ marginBottom: 8, color: '#888', fontSize: 12 }}>
|
||||
{cellEditorMeta ? `${tableName || ''}${tableName ? '.' : ''}${cellEditorMeta.dataIndex}` : ''}
|
||||
</div>
|
||||
{cellEditorOpen && (
|
||||
<Editor
|
||||
height="56vh"
|
||||
language={cellEditorIsJson ? "json" : "plaintext"}
|
||||
theme={darkMode ? "vs-dark" : "light"}
|
||||
value={cellEditorValue}
|
||||
onChange={(val) => setCellEditorValue(val || '')}
|
||||
options={{
|
||||
minimap: { enabled: false },
|
||||
scrollBeyondLastLine: false,
|
||||
wordWrap: "on",
|
||||
fontSize: 14,
|
||||
tabSize: 2,
|
||||
automaticLayout: true,
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</Modal>
|
||||
<Form component={false} form={form}>
|
||||
<DataContext.Provider value={{ selectedRowKeysRef, displayDataRef, handleCopyInsert, handleCopyJson, handleCopyCsv, handleExportSelected, copyToClipboard, tableName }}>
|
||||
<EditableContext.Provider value={form}>
|
||||
@@ -799,6 +1329,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
dataSource={mergedDisplayData}
|
||||
columns={mergedColumns}
|
||||
size="small"
|
||||
tableLayout="fixed"
|
||||
scroll={{ x: Math.max(totalWidth, 1000), y: tableHeight }}
|
||||
virtual={enableVirtual}
|
||||
loading={loading}
|
||||
@@ -809,6 +1340,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
rowSelection={{
|
||||
selectedRowKeys,
|
||||
onChange: setSelectedRowKeys,
|
||||
columnWidth: selectionColumnWidth,
|
||||
}}
|
||||
rowClassName={(record) => {
|
||||
const k = record?.[GONAVI_ROW_KEY];
|
||||
@@ -842,14 +1374,14 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
</div>
|
||||
)}
|
||||
|
||||
<style>{`
|
||||
.${gridId} .row-added td { background-color: #f6ffed !important; }
|
||||
.${gridId} .row-modified td { background-color: #e6f7ff !important; }
|
||||
.${gridId} .ant-table-body {
|
||||
height: ${tableHeight}px !important;
|
||||
max-height: ${tableHeight}px !important;
|
||||
}
|
||||
`}</style>
|
||||
<style>{`
|
||||
.${gridId} .row-added td { background-color: #f6ffed !important; }
|
||||
.${gridId} .row-modified td { background-color: #e6f7ff !important; }
|
||||
.${gridId} td.gonavi-active-cell {
|
||||
outline: 2px solid #1677ff;
|
||||
outline-offset: -2px;
|
||||
}
|
||||
`}</style>
|
||||
|
||||
{/* Ghost Resize Line for Columns */}
|
||||
<div
|
||||
|
||||
@@ -4,6 +4,7 @@ import { TabData, ColumnDefinition } from '../types';
|
||||
import { useStore } from '../store';
|
||||
import { DBQuery, DBGetColumns } from '../../wailsjs/go/app/App';
|
||||
import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
|
||||
import { buildWhereSQL, quoteIdentPart, quoteQualifiedIdent } from '../utils/sql';
|
||||
|
||||
const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
const [data, setData] = useState<any[]>([]);
|
||||
@@ -14,6 +15,8 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
const fetchSeqRef = useRef(0);
|
||||
const countSeqRef = useRef(0);
|
||||
const countKeyRef = useRef<string>('');
|
||||
const pkSeqRef = useRef(0);
|
||||
const pkKeyRef = useRef<string>('');
|
||||
|
||||
const [pagination, setPagination] = useState({
|
||||
current: 1,
|
||||
@@ -27,6 +30,13 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
const [showFilter, setShowFilter] = useState(false);
|
||||
const [filterConditions, setFilterConditions] = useState<any[]>([]);
|
||||
|
||||
useEffect(() => {
|
||||
setPkColumns([]);
|
||||
pkKeyRef.current = '';
|
||||
countKeyRef.current = '';
|
||||
setPagination(prev => ({ ...prev, current: 1, total: 0, totalKnown: false }));
|
||||
}, [tab.connectionId, tab.dbName, tab.tableName]);
|
||||
|
||||
const fetchData = useCallback(async (page = pagination.current, size = pagination.pageSize) => {
|
||||
const seq = ++fetchSeqRef.current;
|
||||
setLoading(true);
|
||||
@@ -46,40 +56,18 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
};
|
||||
|
||||
const quoteIdentPart = (ident: string) => {
|
||||
if (!ident) return ident;
|
||||
if (config.type === 'mysql') return `\`${ident.replace(/`/g, '``')}\``;
|
||||
return `"${ident.replace(/"/g, '""')}"`;
|
||||
};
|
||||
const quoteQualifiedIdent = (ident: string) => {
|
||||
const raw = (ident || '').trim();
|
||||
if (!raw) return raw;
|
||||
const parts = raw.split('.').filter(Boolean);
|
||||
if (parts.length <= 1) return quoteIdentPart(raw);
|
||||
return parts.map(quoteIdentPart).join('.');
|
||||
};
|
||||
const escapeLiteral = (val: string) => val.replace(/'/g, "''");
|
||||
const dbType = config.type || '';
|
||||
|
||||
const dbName = tab.dbName || '';
|
||||
const tableName = tab.tableName || '';
|
||||
|
||||
const whereParts: string[] = [];
|
||||
filterConditions.forEach(cond => {
|
||||
if (cond.column && cond.value) {
|
||||
if (cond.op === 'LIKE') {
|
||||
whereParts.push(`${quoteIdentPart(cond.column)} LIKE '%${escapeLiteral(cond.value)}%'`);
|
||||
} else {
|
||||
whereParts.push(`${quoteIdentPart(cond.column)} ${cond.op} '${escapeLiteral(cond.value)}'`);
|
||||
}
|
||||
}
|
||||
});
|
||||
const whereSQL = whereParts.length > 0 ? `WHERE ${whereParts.join(' AND ')}` : "";
|
||||
const whereSQL = buildWhereSQL(dbType, filterConditions);
|
||||
|
||||
const countSql = `SELECT COUNT(*) as total FROM ${quoteQualifiedIdent(tableName)} ${whereSQL}`;
|
||||
const countSql = `SELECT COUNT(*) as total FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
|
||||
|
||||
let sql = `SELECT * FROM ${quoteQualifiedIdent(tableName)} ${whereSQL}`;
|
||||
let sql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
|
||||
if (sortInfo && sortInfo.order) {
|
||||
sql += ` ORDER BY ${quoteIdentPart(sortInfo.columnKey)} ${sortInfo.order === 'ascend' ? 'ASC' : 'DESC'}`;
|
||||
sql += ` ORDER BY ${quoteIdentPart(dbType, sortInfo.columnKey)} ${sortInfo.order === 'ascend' ? 'ASC' : 'DESC'}`;
|
||||
}
|
||||
const offset = (page - 1) * size;
|
||||
// 大表性能:打开表不阻塞在 COUNT(*),先通过多取 1 条判断是否还有下一页;总数在后台统计并异步回填。
|
||||
@@ -89,11 +77,6 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
try {
|
||||
const pData = DBQuery(config as any, dbName, sql);
|
||||
|
||||
let pCols: Promise<any> | null = null;
|
||||
if (pkColumns.length === 0) {
|
||||
pCols = DBGetColumns(config as any, dbName, tableName);
|
||||
}
|
||||
|
||||
const resData = await pData;
|
||||
const duration = Date.now() - startTime;
|
||||
|
||||
@@ -109,11 +92,23 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
dbName
|
||||
});
|
||||
|
||||
if (pCols) {
|
||||
const resCols = await pCols;
|
||||
if (resCols.success) {
|
||||
const pks = (resCols.data as ColumnDefinition[]).filter(c => c.key === 'PRI').map(c => c.name);
|
||||
setPkColumns(pks);
|
||||
if (pkColumns.length === 0) {
|
||||
const pkKey = `${tab.connectionId}|${dbName}|${tableName}`;
|
||||
if (pkKeyRef.current !== pkKey) {
|
||||
pkKeyRef.current = pkKey;
|
||||
const pkSeq = ++pkSeqRef.current;
|
||||
DBGetColumns(config as any, dbName, tableName)
|
||||
.then((resCols: any) => {
|
||||
if (pkSeqRef.current !== pkSeq) return;
|
||||
if (pkKeyRef.current !== pkKey) return;
|
||||
if (!resCols?.success) return;
|
||||
const pks = (resCols.data as ColumnDefinition[]).filter((c: any) => c.key === 'PRI').map((c: any) => c.name);
|
||||
setPkColumns(pks);
|
||||
})
|
||||
.catch(() => {
|
||||
if (pkSeqRef.current !== pkSeq) return;
|
||||
if (pkKeyRef.current !== pkKey) return;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -227,7 +222,7 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
}, [tab, sortInfo, filterConditions]); // Initial load and re-load on sort/filter
|
||||
|
||||
return (
|
||||
<div style={{ height: '100%', width: '100%', overflow: 'hidden' }}>
|
||||
<div style={{ flex: '1 1 auto', minHeight: 0, height: '100%', width: '100%', overflow: 'hidden', display: 'flex', flexDirection: 'column' }}>
|
||||
<DataGrid
|
||||
data={data}
|
||||
columnNames={columnNames}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
205
frontend/src/components/RedisCommandEditor.tsx
Normal file
205
frontend/src/components/RedisCommandEditor.tsx
Normal file
@@ -0,0 +1,205 @@
|
||||
import React, { useState, useCallback, useRef } from 'react';
|
||||
import { Button, Space, message } from 'antd';
|
||||
import { PlayCircleOutlined, ClearOutlined } from '@ant-design/icons';
|
||||
import { useStore } from '../store';
|
||||
import Editor, { OnMount } from '@monaco-editor/react';
|
||||
|
||||
interface RedisCommandEditorProps {
|
||||
connectionId: string;
|
||||
redisDB: number;
|
||||
}
|
||||
|
||||
interface CommandResult {
|
||||
command: string;
|
||||
result: any;
|
||||
error?: string;
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
const RedisCommandEditor: React.FC<RedisCommandEditorProps> = ({ connectionId, redisDB }) => {
|
||||
const { connections } = useStore();
|
||||
const connection = connections.find(c => c.id === connectionId);
|
||||
|
||||
const [command, setCommand] = useState('');
|
||||
const [results, setResults] = useState<CommandResult[]>([]);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const editorRef = useRef<any>(null);
|
||||
|
||||
const getConfig = useCallback(() => {
|
||||
if (!connection) return null;
|
||||
return {
|
||||
...connection.config,
|
||||
port: Number(connection.config.port),
|
||||
password: connection.config.password || "",
|
||||
useSSH: connection.config.useSSH || false,
|
||||
ssh: connection.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" },
|
||||
redisDB: redisDB
|
||||
};
|
||||
}, [connection, redisDB]);
|
||||
|
||||
const handleEditorMount: OnMount = (editor) => {
|
||||
editorRef.current = editor;
|
||||
// Add keyboard shortcut for execute
|
||||
editor.addCommand(
|
||||
// Ctrl/Cmd + Enter
|
||||
2048 | 3, // KeyMod.CtrlCmd | KeyCode.Enter
|
||||
() => handleExecute()
|
||||
);
|
||||
};
|
||||
|
||||
const handleExecute = async () => {
|
||||
const config = getConfig();
|
||||
if (!config) return;
|
||||
|
||||
const cmdToExecute = command.trim();
|
||||
if (!cmdToExecute) {
|
||||
message.warning('请输入命令');
|
||||
return;
|
||||
}
|
||||
|
||||
// Support multiple commands separated by newlines
|
||||
const commands = cmdToExecute.split('\n').filter(c => c.trim() && !c.trim().startsWith('//') && !c.trim().startsWith('#'));
|
||||
|
||||
setLoading(true);
|
||||
const newResults: CommandResult[] = [];
|
||||
|
||||
for (const cmd of commands) {
|
||||
const trimmedCmd = cmd.trim();
|
||||
if (!trimmedCmd) continue;
|
||||
|
||||
try {
|
||||
const res = await (window as any).go.app.App.RedisExecuteCommand(config, trimmedCmd);
|
||||
newResults.push({
|
||||
command: trimmedCmd,
|
||||
result: res.success ? res.data : null,
|
||||
error: res.success ? undefined : res.message,
|
||||
timestamp: Date.now()
|
||||
});
|
||||
} catch (e: any) {
|
||||
newResults.push({
|
||||
command: trimmedCmd,
|
||||
result: null,
|
||||
error: e?.message || String(e),
|
||||
timestamp: Date.now()
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
setResults(prev => [...newResults, ...prev]);
|
||||
setLoading(false);
|
||||
};
|
||||
|
||||
const handleClear = () => {
|
||||
setResults([]);
|
||||
};
|
||||
|
||||
const formatResult = (result: any): string => {
|
||||
if (result === null || result === undefined) {
|
||||
return '(nil)';
|
||||
}
|
||||
if (typeof result === 'string') {
|
||||
return `"${result}"`;
|
||||
}
|
||||
if (typeof result === 'number') {
|
||||
return `(integer) ${result}`;
|
||||
}
|
||||
if (Array.isArray(result)) {
|
||||
if (result.length === 0) {
|
||||
return '(empty array)';
|
||||
}
|
||||
return result.map((item, index) => `${index + 1}) ${formatResult(item)}`).join('\n');
|
||||
}
|
||||
if (typeof result === 'object') {
|
||||
return JSON.stringify(result, null, 2);
|
||||
}
|
||||
return String(result);
|
||||
};
|
||||
|
||||
if (!connection) {
|
||||
return <div style={{ padding: 20 }}>连接不存在</div>;
|
||||
}
|
||||
|
||||
return (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', height: '100%' }}>
|
||||
{/* Command Input */}
|
||||
<div style={{ borderBottom: '1px solid #f0f0f0' }}>
|
||||
<div style={{ padding: '8px 12px', borderBottom: '1px solid #f0f0f0', display: 'flex', justifyContent: 'space-between', alignItems: 'center' }}>
|
||||
<Space>
|
||||
<span style={{ fontWeight: 500 }}>Redis 命令</span>
|
||||
<span style={{ color: '#999', fontSize: 12 }}>db{redisDB}</span>
|
||||
</Space>
|
||||
<Space>
|
||||
<Button
|
||||
type="primary"
|
||||
icon={<PlayCircleOutlined />}
|
||||
onClick={handleExecute}
|
||||
loading={loading}
|
||||
>
|
||||
执行 (Ctrl+Enter)
|
||||
</Button>
|
||||
<Button icon={<ClearOutlined />} onClick={handleClear}>清空结果</Button>
|
||||
</Space>
|
||||
</div>
|
||||
<Editor
|
||||
height="150px"
|
||||
defaultLanguage="plaintext"
|
||||
value={command}
|
||||
onChange={(value) => setCommand(value || '')}
|
||||
onMount={handleEditorMount}
|
||||
options={{
|
||||
minimap: { enabled: false },
|
||||
lineNumbers: 'on',
|
||||
fontSize: 14,
|
||||
wordWrap: 'on',
|
||||
scrollBeyondLastLine: false,
|
||||
automaticLayout: true,
|
||||
tabSize: 2
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Results */}
|
||||
<div style={{ flex: 1, overflow: 'auto', background: '#1e1e1e', color: '#d4d4d4', fontFamily: 'monospace' }}>
|
||||
{results.length === 0 ? (
|
||||
<div style={{ padding: 20, color: '#666', textAlign: 'center' }}>
|
||||
输入 Redis 命令并按 Ctrl+Enter 执行
|
||||
<br />
|
||||
<span style={{ fontSize: 12 }}>支持多行命令,每行一个命令</span>
|
||||
</div>
|
||||
) : (
|
||||
results.map((item, index) => (
|
||||
<div key={item.timestamp + index} style={{ padding: '8px 12px', borderBottom: '1px solid #333' }}>
|
||||
<div style={{ color: '#569cd6', marginBottom: 4 }}>
|
||||
> {item.command}
|
||||
</div>
|
||||
{item.error ? (
|
||||
<div style={{ color: '#f14c4c', whiteSpace: 'pre-wrap' }}>
|
||||
(error) {item.error}
|
||||
</div>
|
||||
) : (
|
||||
<div style={{ color: '#ce9178', whiteSpace: 'pre-wrap' }}>
|
||||
{formatResult(item.result)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
))
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Common Commands Help */}
|
||||
<div style={{ padding: '8px 12px', borderTop: '1px solid #f0f0f0', background: '#fafafa', fontSize: 12, color: '#666' }}>
|
||||
常用命令:
|
||||
<span style={{ marginLeft: 8 }}>
|
||||
<code>KEYS *</code> |
|
||||
<code style={{ marginLeft: 8 }}>GET key</code> |
|
||||
<code style={{ marginLeft: 8 }}>SET key value</code> |
|
||||
<code style={{ marginLeft: 8 }}>HGETALL key</code> |
|
||||
<code style={{ marginLeft: 8 }}>INFO</code> |
|
||||
<code style={{ marginLeft: 8 }}>DBSIZE</code>
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default RedisCommandEditor;
|
||||
1293
frontend/src/components/RedisViewer.tsx
Normal file
1293
frontend/src/components/RedisViewer.tsx
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,11 +1,11 @@
|
||||
import React, { useEffect, useState, useMemo, useRef } from 'react';
|
||||
import { Tree, message, Dropdown, MenuProps, Input, Button, Modal, Form, Badge } from 'antd';
|
||||
import {
|
||||
DatabaseOutlined,
|
||||
TableOutlined,
|
||||
ConsoleSqlOutlined,
|
||||
HddOutlined,
|
||||
FolderOpenOutlined,
|
||||
DatabaseOutlined,
|
||||
TableOutlined,
|
||||
ConsoleSqlOutlined,
|
||||
HddOutlined,
|
||||
FolderOpenOutlined,
|
||||
FileTextOutlined,
|
||||
CopyOutlined,
|
||||
ExportOutlined,
|
||||
@@ -22,7 +22,8 @@ import { Tree, message, Dropdown, MenuProps, Input, Button, Modal, Form, Badge }
|
||||
PlusOutlined,
|
||||
ReloadOutlined,
|
||||
DeleteOutlined,
|
||||
DisconnectOutlined
|
||||
DisconnectOutlined,
|
||||
CloudOutlined
|
||||
} from '@ant-design/icons';
|
||||
import { useStore } from '../store';
|
||||
import { SavedConnection } from '../types';
|
||||
@@ -37,7 +38,7 @@ interface TreeNode {
|
||||
children?: TreeNode[];
|
||||
icon?: React.ReactNode;
|
||||
dataRef?: any;
|
||||
type?: 'connection' | 'database' | 'table' | 'queries-folder' | 'saved-query' | 'folder-columns' | 'folder-indexes' | 'folder-fks' | 'folder-triggers';
|
||||
type?: 'connection' | 'database' | 'table' | 'queries-folder' | 'saved-query' | 'folder-columns' | 'folder-indexes' | 'folder-fks' | 'folder-triggers' | 'redis-db';
|
||||
}
|
||||
|
||||
const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> = ({ onEditConnection }) => {
|
||||
@@ -47,6 +48,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
const [expandedKeys, setExpandedKeys] = useState<React.Key[]>([]);
|
||||
const [autoExpandParent, setAutoExpandParent] = useState(true);
|
||||
const [loadedKeys, setLoadedKeys] = useState<React.Key[]>([]);
|
||||
const [selectedKeys, setSelectedKeys] = useState<React.Key[]>([]);
|
||||
const [selectedNodes, setSelectedNodes] = useState<any[]>([]);
|
||||
const [contextMenu, setContextMenu] = useState<{ x: number, y: number, items: MenuProps['items'] } | null>(null);
|
||||
|
||||
// Virtual Scroll State
|
||||
@@ -97,7 +100,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
setTreeData(connections.map(conn => ({
|
||||
title: conn.name,
|
||||
key: conn.id,
|
||||
icon: <HddOutlined />,
|
||||
icon: conn.config.type === 'redis' ? <CloudOutlined style={{ color: '#DC382D' }} /> : <HddOutlined />,
|
||||
type: 'connection',
|
||||
dataRef: conn,
|
||||
isLeaf: false,
|
||||
@@ -118,14 +121,46 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
|
||||
const loadDatabases = async (node: any) => {
|
||||
const conn = node.dataRef as SavedConnection;
|
||||
const config = {
|
||||
...conn.config,
|
||||
const config = {
|
||||
...conn.config,
|
||||
port: Number(conn.config.port),
|
||||
password: conn.config.password || "",
|
||||
database: conn.config.database || "",
|
||||
useSSH: conn.config.useSSH || false,
|
||||
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
};
|
||||
|
||||
// Handle Redis connections differently
|
||||
if (conn.config.type === 'redis') {
|
||||
try {
|
||||
const res = await (window as any).go.app.App.RedisGetDatabases(config);
|
||||
if (res.success) {
|
||||
setConnectionStates(prev => ({ ...prev, [conn.id]: 'success' }));
|
||||
let dbs = (res.data as any[]).map((db: any) => ({
|
||||
title: `db${db.index}${db.keys > 0 ? ` (${db.keys})` : ''}`,
|
||||
key: `${conn.id}-db${db.index}`,
|
||||
icon: <DatabaseOutlined style={{ color: '#DC382D' }} />,
|
||||
type: 'redis-db' as const,
|
||||
dataRef: { ...conn, redisDB: db.index },
|
||||
isLeaf: true,
|
||||
dbIndex: db.index,
|
||||
}));
|
||||
// Filter Redis databases if configured
|
||||
if (conn.includeRedisDatabases && conn.includeRedisDatabases.length > 0) {
|
||||
dbs = dbs.filter(db => conn.includeRedisDatabases!.includes(db.dbIndex));
|
||||
}
|
||||
setTreeData(origin => updateTreeData(origin, node.key, dbs));
|
||||
} else {
|
||||
setConnectionStates(prev => ({ ...prev, [conn.id]: 'error' }));
|
||||
message.error(res.message);
|
||||
}
|
||||
} catch (e: any) {
|
||||
setConnectionStates(prev => ({ ...prev, [conn.id]: 'error' }));
|
||||
message.error('连接失败: ' + (e?.message || String(e)));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const res = await DBGetDatabases(config as any);
|
||||
if (res.success) {
|
||||
setConnectionStates(prev => ({ ...prev, [conn.id]: 'success' }));
|
||||
@@ -283,13 +318,17 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
};
|
||||
|
||||
const onSelect = (keys: React.Key[], info: any) => {
|
||||
if (!info.node.selected) {
|
||||
setSelectedKeys(keys);
|
||||
setSelectedNodes(info.selectedNodes || []);
|
||||
|
||||
if (keys.length === 0) {
|
||||
setActiveContext(null);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!info.selected) return;
|
||||
|
||||
const { type, dataRef, key, title } = info.node;
|
||||
|
||||
|
||||
// Update active context
|
||||
if (type === 'connection') {
|
||||
setActiveContext({ connectionId: key, dbName: '' });
|
||||
@@ -299,6 +338,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
setActiveContext({ connectionId: dataRef.id, dbName: dataRef.dbName });
|
||||
} else if (type === 'saved-query') {
|
||||
setActiveContext({ connectionId: dataRef.connectionId, dbName: dataRef.dbName });
|
||||
} else if (type === 'redis-db') {
|
||||
setActiveContext({ connectionId: dataRef.id, dbName: `db${dataRef.redisDB}` });
|
||||
}
|
||||
|
||||
if (type === 'folder-columns') openDesign(info.node, 'columns', true);
|
||||
@@ -313,15 +354,6 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
};
|
||||
|
||||
const onDoubleClick = (e: any, node: any) => {
|
||||
const key = node.key;
|
||||
const isExpanded = expandedKeys.includes(key);
|
||||
const newExpandedKeys = isExpanded
|
||||
? expandedKeys.filter(k => k !== key)
|
||||
: [...expandedKeys, key];
|
||||
|
||||
setExpandedKeys(newExpandedKeys);
|
||||
if (!isExpanded) setAutoExpandParent(false);
|
||||
|
||||
if (node.type === 'table') {
|
||||
const { tableName, dbName, id } = node.dataRef;
|
||||
addTab({
|
||||
@@ -332,6 +364,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
dbName,
|
||||
tableName,
|
||||
});
|
||||
return;
|
||||
} else if (node.type === 'saved-query') {
|
||||
const q = node.dataRef;
|
||||
addTab({
|
||||
@@ -342,7 +375,27 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
dbName: q.dbName,
|
||||
query: q.sql
|
||||
});
|
||||
return;
|
||||
} else if (node.type === 'redis-db') {
|
||||
const { id, redisDB } = node.dataRef;
|
||||
addTab({
|
||||
id: `redis-keys-${id}-db${redisDB}`,
|
||||
title: `db${redisDB}`,
|
||||
type: 'redis-keys',
|
||||
connectionId: id,
|
||||
redisDB: redisDB
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const key = node.key;
|
||||
const isExpanded = expandedKeys.includes(key);
|
||||
const newExpandedKeys = isExpanded
|
||||
? expandedKeys.filter(k => k !== key)
|
||||
: [...expandedKeys, key];
|
||||
|
||||
setExpandedKeys(newExpandedKeys);
|
||||
if (!isExpanded) setAutoExpandParent(false);
|
||||
};
|
||||
|
||||
const handleCopyStructure = async (node: any) => {
|
||||
@@ -382,6 +435,60 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
}
|
||||
};
|
||||
|
||||
const normalizeConnConfig = (raw: any) => ({
|
||||
...raw,
|
||||
port: Number(raw.port),
|
||||
password: raw.password || "",
|
||||
database: raw.database || "",
|
||||
useSSH: raw.useSSH || false,
|
||||
ssh: raw.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
});
|
||||
|
||||
const handleExportDatabaseSQL = async (node: any, includeData: boolean) => {
|
||||
const conn = node.dataRef;
|
||||
const dbName = conn.dbName || node.title;
|
||||
const hide = message.loading(includeData ? `正在备份数据库 ${dbName} (结构+数据)...` : `正在导出数据库 ${dbName} 表结构...`, 0);
|
||||
try {
|
||||
const res = await (window as any).go.app.App.ExportDatabaseSQL(normalizeConnConfig(conn.config), dbName, includeData);
|
||||
hide();
|
||||
if (res.success) {
|
||||
message.success('导出成功');
|
||||
} else if (res.message !== 'Cancelled') {
|
||||
message.error('导出失败: ' + res.message);
|
||||
}
|
||||
} catch (e: any) {
|
||||
hide();
|
||||
message.error('导出失败: ' + (e?.message || String(e)));
|
||||
}
|
||||
};
|
||||
|
||||
const handleExportTablesSQL = async (nodes: any[], includeData: boolean) => {
|
||||
if (!nodes || nodes.length === 0) return;
|
||||
const first = nodes[0].dataRef;
|
||||
const dbName = first.dbName;
|
||||
const connId = first.id;
|
||||
const allSame = nodes.every(n => n?.dataRef?.id === connId && n?.dataRef?.dbName === dbName);
|
||||
if (!allSame) {
|
||||
message.error('请在同一连接、同一数据库下选择多张表进行导出');
|
||||
return;
|
||||
}
|
||||
|
||||
const tableNames = nodes.map(n => n.dataRef.tableName).filter(Boolean);
|
||||
const hide = message.loading(includeData ? `正在备份选中表 (${tableNames.length})...` : `正在导出选中表结构 (${tableNames.length})...`, 0);
|
||||
try {
|
||||
const res = await (window as any).go.app.App.ExportTablesSQL(normalizeConnConfig(first.config), dbName, tableNames, includeData);
|
||||
hide();
|
||||
if (res.success) {
|
||||
message.success('导出成功');
|
||||
} else if (res.message !== 'Cancelled') {
|
||||
message.error('导出失败: ' + res.message);
|
||||
}
|
||||
} catch (e: any) {
|
||||
hide();
|
||||
message.error('导出失败: ' + (e?.message || String(e)));
|
||||
}
|
||||
};
|
||||
|
||||
const handleRunSQLFile = async (node: any) => {
|
||||
const res = await (window as any).go.app.App.OpenSQLFile();
|
||||
if (res.success) {
|
||||
@@ -457,7 +564,80 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
}, [searchValue, treeData]);
|
||||
|
||||
const getNodeMenuItems = (node: any): MenuProps['items'] => {
|
||||
const conn = node.dataRef as SavedConnection;
|
||||
const isRedis = conn?.config?.type === 'redis';
|
||||
|
||||
if (node.type === 'connection') {
|
||||
// Redis connection menu
|
||||
if (isRedis) {
|
||||
return [
|
||||
{
|
||||
key: 'refresh',
|
||||
label: '刷新',
|
||||
icon: <ReloadOutlined />,
|
||||
onClick: () => loadDatabases(node)
|
||||
},
|
||||
{ type: 'divider' },
|
||||
{
|
||||
key: 'new-command',
|
||||
label: '新建命令窗口',
|
||||
icon: <ConsoleSqlOutlined />,
|
||||
onClick: () => {
|
||||
addTab({
|
||||
id: `redis-cmd-${node.key}-${Date.now()}`,
|
||||
title: `命令 - ${node.title}`,
|
||||
type: 'redis-command',
|
||||
connectionId: node.key,
|
||||
redisDB: 0
|
||||
});
|
||||
}
|
||||
},
|
||||
{ type: 'divider' },
|
||||
{
|
||||
key: 'edit',
|
||||
label: '编辑连接',
|
||||
icon: <EditOutlined />,
|
||||
onClick: () => {
|
||||
if (onEditConnection) onEditConnection(node.dataRef);
|
||||
}
|
||||
},
|
||||
{
|
||||
key: 'disconnect',
|
||||
label: '断开连接',
|
||||
icon: <DisconnectOutlined />,
|
||||
onClick: () => {
|
||||
setConnectionStates(prev => {
|
||||
const next = { ...prev };
|
||||
Object.keys(next).forEach(k => {
|
||||
if (k === node.key || k.startsWith(`${node.key}-`)) {
|
||||
delete next[k];
|
||||
}
|
||||
});
|
||||
return next;
|
||||
});
|
||||
setExpandedKeys(prev => prev.filter(k => k !== node.key && !k.toString().startsWith(`${node.key}-`)));
|
||||
setLoadedKeys(prev => prev.filter(k => k !== node.key && !k.toString().startsWith(`${node.key}-`)));
|
||||
setTreeData(origin => updateTreeData(origin, node.key, undefined));
|
||||
message.success("已断开连接");
|
||||
}
|
||||
},
|
||||
{
|
||||
key: 'delete',
|
||||
label: '删除连接',
|
||||
icon: <DeleteOutlined />,
|
||||
danger: true,
|
||||
onClick: () => {
|
||||
Modal.confirm({
|
||||
title: '确认删除',
|
||||
content: `确定要删除连接 "${node.title}" 吗?`,
|
||||
onOk: () => removeConnection(node.key)
|
||||
});
|
||||
}
|
||||
}
|
||||
];
|
||||
}
|
||||
|
||||
// Regular database connection menu
|
||||
return [
|
||||
{
|
||||
key: 'new-db',
|
||||
@@ -475,9 +655,9 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
onClick: () => loadDatabases(node)
|
||||
},
|
||||
{ type: 'divider' },
|
||||
{
|
||||
key: 'new-query',
|
||||
label: '新建查询',
|
||||
{
|
||||
key: 'new-query',
|
||||
label: '新建查询',
|
||||
icon: <ConsoleSqlOutlined />,
|
||||
onClick: () => {
|
||||
addTab({
|
||||
@@ -536,6 +716,39 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
}
|
||||
}
|
||||
];
|
||||
} else if (node.type === 'redis-db') {
|
||||
// Redis database menu
|
||||
const { id, redisDB } = node.dataRef;
|
||||
return [
|
||||
{
|
||||
key: 'open-keys',
|
||||
label: '浏览 Key',
|
||||
icon: <KeyOutlined />,
|
||||
onClick: () => {
|
||||
addTab({
|
||||
id: `redis-keys-${id}-db${redisDB}`,
|
||||
title: `db${redisDB}`,
|
||||
type: 'redis-keys',
|
||||
connectionId: id,
|
||||
redisDB: redisDB
|
||||
});
|
||||
}
|
||||
},
|
||||
{
|
||||
key: 'new-command',
|
||||
label: '新建命令窗口',
|
||||
icon: <ConsoleSqlOutlined />,
|
||||
onClick: () => {
|
||||
addTab({
|
||||
id: `redis-cmd-${id}-db${redisDB}-${Date.now()}`,
|
||||
title: `命令 - db${redisDB}`,
|
||||
type: 'redis-command',
|
||||
connectionId: id,
|
||||
redisDB: redisDB
|
||||
});
|
||||
}
|
||||
}
|
||||
];
|
||||
} else if (node.type === 'database') {
|
||||
return [
|
||||
{
|
||||
@@ -550,6 +763,18 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
icon: <ReloadOutlined />,
|
||||
onClick: () => loadTables(node)
|
||||
},
|
||||
{
|
||||
key: 'export-db-schema',
|
||||
label: '导出全部表结构 (SQL)',
|
||||
icon: <ExportOutlined />,
|
||||
onClick: () => handleExportDatabaseSQL(node, false)
|
||||
},
|
||||
{
|
||||
key: 'backup-db-sql',
|
||||
label: '备份全部表 (结构+数据 SQL)',
|
||||
icon: <SaveOutlined />,
|
||||
onClick: () => handleExportDatabaseSQL(node, true)
|
||||
},
|
||||
{ type: 'divider' },
|
||||
{
|
||||
key: 'disconnect-db',
|
||||
@@ -588,7 +813,25 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
}
|
||||
];
|
||||
} else if (node.type === 'table') {
|
||||
const sameContextSelectedTables = (selectedNodes || []).filter((n: any) => n?.type === 'table' && n?.dataRef?.id === node?.dataRef?.id && n?.dataRef?.dbName === node?.dataRef?.dbName);
|
||||
const selectedForAction = sameContextSelectedTables.some((n: any) => n?.key === node.key) ? sameContextSelectedTables : [node];
|
||||
|
||||
return [
|
||||
...(selectedForAction.length > 1 ? ([
|
||||
{
|
||||
key: 'export-selected-schema',
|
||||
label: `导出选中表结构 (${selectedForAction.length}) (SQL)`,
|
||||
icon: <ExportOutlined />,
|
||||
onClick: () => handleExportTablesSQL(selectedForAction, false)
|
||||
},
|
||||
{
|
||||
key: 'backup-selected-sql',
|
||||
label: `备份选中表 (${selectedForAction.length}) (结构+数据 SQL)`,
|
||||
icon: <SaveOutlined />,
|
||||
onClick: () => handleExportTablesSQL(selectedForAction, true)
|
||||
},
|
||||
{ type: 'divider' as const }
|
||||
]) : []),
|
||||
{
|
||||
key: 'new-query',
|
||||
label: '新建查询',
|
||||
@@ -684,6 +927,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
loadedKeys={loadedKeys}
|
||||
onLoad={setLoadedKeys}
|
||||
autoExpandParent={autoExpandParent}
|
||||
multiple
|
||||
selectedKeys={selectedKeys}
|
||||
blockNode
|
||||
height={treeHeight}
|
||||
onRightClick={onRightClick}
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
import React, { useMemo } from 'react';
|
||||
import { Tabs, Button } from 'antd';
|
||||
import { Tabs, Dropdown } from 'antd';
|
||||
import type { MenuProps } from 'antd';
|
||||
import { useStore } from '../store';
|
||||
import DataViewer from './DataViewer';
|
||||
import QueryEditor from './QueryEditor';
|
||||
import TableDesigner from './TableDesigner';
|
||||
import RedisViewer from './RedisViewer';
|
||||
import RedisCommandEditor from './RedisCommandEditor';
|
||||
|
||||
const TabManager: React.FC = () => {
|
||||
const { tabs, activeTabId, setActiveTab, closeTab } = useStore();
|
||||
const { tabs, activeTabId, setActiveTab, closeTab, closeOtherTabs, closeTabsToLeft, closeTabsToRight, closeAllTabs } = useStore();
|
||||
|
||||
const onChange = (newActiveKey: string) => {
|
||||
setActiveTab(newActiveKey);
|
||||
@@ -18,7 +21,7 @@ const TabManager: React.FC = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const items = useMemo(() => tabs.map(tab => {
|
||||
const items = useMemo(() => tabs.map((tab, index) => {
|
||||
let content;
|
||||
if (tab.type === 'query') {
|
||||
content = <QueryEditor tab={tab} />;
|
||||
@@ -26,28 +29,100 @@ const TabManager: React.FC = () => {
|
||||
content = <DataViewer tab={tab} />;
|
||||
} else if (tab.type === 'design') {
|
||||
content = <TableDesigner tab={tab} />;
|
||||
} else if (tab.type === 'redis-keys') {
|
||||
content = <RedisViewer connectionId={tab.connectionId} redisDB={tab.redisDB ?? 0} />;
|
||||
} else if (tab.type === 'redis-command') {
|
||||
content = <RedisCommandEditor connectionId={tab.connectionId} redisDB={tab.redisDB ?? 0} />;
|
||||
}
|
||||
|
||||
const menuItems: MenuProps['items'] = [
|
||||
{
|
||||
key: 'close-other',
|
||||
label: '关闭其他页',
|
||||
disabled: tabs.length <= 1,
|
||||
onClick: () => closeOtherTabs(tab.id),
|
||||
},
|
||||
{
|
||||
key: 'close-left',
|
||||
label: '关闭左侧',
|
||||
disabled: index === 0,
|
||||
onClick: () => closeTabsToLeft(tab.id),
|
||||
},
|
||||
{
|
||||
key: 'close-right',
|
||||
label: '关闭右侧',
|
||||
disabled: index === tabs.length - 1,
|
||||
onClick: () => closeTabsToRight(tab.id),
|
||||
},
|
||||
{ type: 'divider' },
|
||||
{
|
||||
key: 'close-all',
|
||||
label: '关闭所有',
|
||||
disabled: tabs.length === 0,
|
||||
onClick: () => closeAllTabs(),
|
||||
},
|
||||
];
|
||||
|
||||
return {
|
||||
label: tab.title,
|
||||
label: (
|
||||
<Dropdown menu={{ items: menuItems }} trigger={['contextMenu']}>
|
||||
<span onContextMenu={(e) => e.preventDefault()}>{tab.title}</span>
|
||||
</Dropdown>
|
||||
),
|
||||
key: tab.id,
|
||||
children: content,
|
||||
};
|
||||
}), [tabs]);
|
||||
}), [tabs, closeOtherTabs, closeTabsToLeft, closeTabsToRight, closeAllTabs]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<style>{`
|
||||
.ant-tabs-content { height: 100%; }
|
||||
.ant-tabs-tabpane { height: 100%; }
|
||||
.main-tabs {
|
||||
height: 100%;
|
||||
flex: 1 1 auto;
|
||||
min-height: 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
overflow: hidden;
|
||||
}
|
||||
.main-tabs .ant-tabs-nav {
|
||||
flex: 0 0 auto;
|
||||
}
|
||||
.main-tabs .ant-tabs-content-holder {
|
||||
flex: 1 1 auto;
|
||||
min-height: 0;
|
||||
overflow: hidden;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
.main-tabs .ant-tabs-content {
|
||||
flex: 1 1 auto;
|
||||
min-height: 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
.main-tabs .ant-tabs-tabpane {
|
||||
flex: 1 1 auto;
|
||||
min-height: 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
overflow: hidden;
|
||||
}
|
||||
.main-tabs .ant-tabs-tabpane > div {
|
||||
flex: 1 1 auto;
|
||||
min-height: 0;
|
||||
}
|
||||
.main-tabs .ant-tabs-tabpane-hidden {
|
||||
display: none !important;
|
||||
}
|
||||
`}</style>
|
||||
<Tabs
|
||||
className="main-tabs"
|
||||
type="editable-card"
|
||||
onChange={onChange}
|
||||
activeKey={activeTabId || undefined}
|
||||
onEdit={onEdit}
|
||||
items={items}
|
||||
style={{ height: '100%' }}
|
||||
hideAdd
|
||||
/>
|
||||
</>
|
||||
|
||||
@@ -550,7 +550,6 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
<div ref={containerRef} className="table-designer-wrapper" style={{ height: '100%', overflow: 'hidden', position: 'relative' }}>
|
||||
<style>{`
|
||||
.table-designer-wrapper .ant-table-body {
|
||||
height: ${tableHeight}px !important;
|
||||
max-height: ${tableHeight}px !important;
|
||||
}
|
||||
`}</style>
|
||||
|
||||
@@ -21,6 +21,7 @@ interface AppState {
|
||||
savedQueries: SavedQuery[];
|
||||
darkMode: boolean;
|
||||
sqlFormatOptions: { keywordCase: 'upper' | 'lower' };
|
||||
queryOptions: { maxRows: number };
|
||||
sqlLogs: SqlLog[];
|
||||
|
||||
addConnection: (conn: SavedConnection) => void;
|
||||
@@ -29,6 +30,10 @@ interface AppState {
|
||||
|
||||
addTab: (tab: TabData) => void;
|
||||
closeTab: (id: string) => void;
|
||||
closeOtherTabs: (id: string) => void;
|
||||
closeTabsToLeft: (id: string) => void;
|
||||
closeTabsToRight: (id: string) => void;
|
||||
closeAllTabs: () => void;
|
||||
setActiveTab: (id: string) => void;
|
||||
setActiveContext: (context: { connectionId: string; dbName: string } | null) => void;
|
||||
|
||||
@@ -37,6 +42,7 @@ interface AppState {
|
||||
|
||||
toggleDarkMode: () => void;
|
||||
setSqlFormatOptions: (options: { keywordCase: 'upper' | 'lower' }) => void;
|
||||
setQueryOptions: (options: Partial<{ maxRows: number }>) => void;
|
||||
|
||||
addSqlLog: (log: SqlLog) => void;
|
||||
clearSqlLogs: () => void;
|
||||
@@ -52,6 +58,7 @@ export const useStore = create<AppState>()(
|
||||
savedQueries: [],
|
||||
darkMode: false,
|
||||
sqlFormatOptions: { keywordCase: 'upper' },
|
||||
queryOptions: { maxRows: 5000 },
|
||||
sqlLogs: [],
|
||||
|
||||
addConnection: (conn) => set((state) => ({ connections: [...state.connections, conn] })),
|
||||
@@ -79,6 +86,30 @@ export const useStore = create<AppState>()(
|
||||
}
|
||||
return { tabs: newTabs, activeTabId: newActiveId };
|
||||
}),
|
||||
|
||||
closeOtherTabs: (id) => set((state) => {
|
||||
const keep = state.tabs.find(t => t.id === id);
|
||||
if (!keep) return state;
|
||||
return { tabs: [keep], activeTabId: id };
|
||||
}),
|
||||
|
||||
closeTabsToLeft: (id) => set((state) => {
|
||||
const index = state.tabs.findIndex(t => t.id === id);
|
||||
if (index === -1) return state;
|
||||
const newTabs = state.tabs.slice(index);
|
||||
const activeStillExists = state.activeTabId ? newTabs.some(t => t.id === state.activeTabId) : false;
|
||||
return { tabs: newTabs, activeTabId: activeStillExists ? state.activeTabId : id };
|
||||
}),
|
||||
|
||||
closeTabsToRight: (id) => set((state) => {
|
||||
const index = state.tabs.findIndex(t => t.id === id);
|
||||
if (index === -1) return state;
|
||||
const newTabs = state.tabs.slice(0, index + 1);
|
||||
const activeStillExists = state.activeTabId ? newTabs.some(t => t.id === state.activeTabId) : false;
|
||||
return { tabs: newTabs, activeTabId: activeStillExists ? state.activeTabId : id };
|
||||
}),
|
||||
|
||||
closeAllTabs: () => set(() => ({ tabs: [], activeTabId: null })),
|
||||
|
||||
setActiveTab: (id) => set({ activeTabId: id }),
|
||||
setActiveContext: (context) => set({ activeContext: context }),
|
||||
@@ -96,13 +127,14 @@ export const useStore = create<AppState>()(
|
||||
|
||||
toggleDarkMode: () => set((state) => ({ darkMode: !state.darkMode })),
|
||||
setSqlFormatOptions: (options) => set({ sqlFormatOptions: options }),
|
||||
setQueryOptions: (options) => set((state) => ({ queryOptions: { ...state.queryOptions, ...options } })),
|
||||
|
||||
addSqlLog: (log) => set((state) => ({ sqlLogs: [log, ...state.sqlLogs].slice(0, 1000) })), // Keep last 1000 logs
|
||||
clearSqlLogs: () => set({ sqlLogs: [] }),
|
||||
}),
|
||||
{
|
||||
name: 'lite-db-storage', // name of the item in the storage (must be unique)
|
||||
partialize: (state) => ({ connections: state.connections, savedQueries: state.savedQueries, darkMode: state.darkMode, sqlFormatOptions: state.sqlFormatOptions }), // Don't persist logs
|
||||
partialize: (state) => ({ connections: state.connections, savedQueries: state.savedQueries, darkMode: state.darkMode, sqlFormatOptions: state.sqlFormatOptions, queryOptions: state.queryOptions }), // Don't persist logs
|
||||
}
|
||||
)
|
||||
);
|
||||
);
|
||||
|
||||
@@ -15,6 +15,7 @@ export interface ConnectionConfig {
|
||||
database?: string;
|
||||
useSSH?: boolean;
|
||||
ssh?: SSHConfig;
|
||||
redisDB?: number; // Redis database index (0-15)
|
||||
}
|
||||
|
||||
export interface SavedConnection {
|
||||
@@ -22,6 +23,7 @@ export interface SavedConnection {
|
||||
name: string;
|
||||
config: ConnectionConfig;
|
||||
includeDatabases?: string[];
|
||||
includeRedisDatabases?: number[]; // Redis databases to show (0-15)
|
||||
}
|
||||
|
||||
export interface ColumnDefinition {
|
||||
@@ -60,13 +62,14 @@ export interface TriggerDefinition {
|
||||
export interface TabData {
|
||||
id: string;
|
||||
title: string;
|
||||
type: 'query' | 'table' | 'design';
|
||||
type: 'query' | 'table' | 'design' | 'redis-keys' | 'redis-command';
|
||||
connectionId: string;
|
||||
dbName?: string;
|
||||
tableName?: string;
|
||||
query?: string;
|
||||
initialTab?: string;
|
||||
readOnly?: boolean;
|
||||
redisDB?: number; // Redis database index for redis tabs
|
||||
}
|
||||
|
||||
export interface DatabaseNode {
|
||||
@@ -85,3 +88,32 @@ export interface SavedQuery {
|
||||
dbName: string;
|
||||
createdAt: number;
|
||||
}
|
||||
|
||||
// Redis types
|
||||
export interface RedisKeyInfo {
|
||||
key: string;
|
||||
type: string;
|
||||
ttl: number;
|
||||
}
|
||||
|
||||
export interface RedisScanResult {
|
||||
keys: RedisKeyInfo[];
|
||||
cursor: number;
|
||||
}
|
||||
|
||||
export interface RedisValue {
|
||||
type: 'string' | 'hash' | 'list' | 'set' | 'zset';
|
||||
ttl: number;
|
||||
value: any;
|
||||
length: number;
|
||||
}
|
||||
|
||||
export interface RedisDBInfo {
|
||||
index: number;
|
||||
keys: number;
|
||||
}
|
||||
|
||||
export interface ZSetMember {
|
||||
member: string;
|
||||
score: number;
|
||||
}
|
||||
|
||||
198
frontend/src/utils/sql.ts
Normal file
198
frontend/src/utils/sql.ts
Normal file
@@ -0,0 +1,198 @@
|
||||
export type FilterCondition = {
|
||||
id?: number;
|
||||
column?: string;
|
||||
op?: string;
|
||||
value?: string;
|
||||
value2?: string;
|
||||
};
|
||||
|
||||
const normalizeIdentPart = (ident: string) => {
|
||||
let raw = (ident || '').trim();
|
||||
if (!raw) return raw;
|
||||
const first = raw[0];
|
||||
const last = raw[raw.length - 1];
|
||||
if ((first === '"' && last === '"') || (first === '`' && last === '`')) {
|
||||
raw = raw.slice(1, -1).trim();
|
||||
}
|
||||
raw = raw.replace(/["`]/g, '').trim();
|
||||
return raw;
|
||||
};
|
||||
|
||||
// 检查标识符是否需要引号(包含特殊字符或是保留字)
|
||||
const needsQuote = (ident: string): boolean => {
|
||||
if (!ident) return false;
|
||||
// 如果包含特殊字符(非字母、数字、下划线)则需要引号
|
||||
if (!/^[a-zA-Z_][a-zA-Z0-9_]*$/.test(ident)) return true;
|
||||
// 常见 SQL 保留字列表(简化版)
|
||||
const reserved = ['select', 'from', 'where', 'table', 'index', 'user', 'order', 'group', 'by', 'limit', 'offset', 'and', 'or', 'not', 'null', 'true', 'false', 'key', 'primary', 'foreign', 'references', 'default', 'constraint', 'create', 'drop', 'alter', 'insert', 'update', 'delete', 'set', 'values', 'into', 'join', 'left', 'right', 'inner', 'outer', 'on', 'as', 'is', 'in', 'like', 'between', 'case', 'when', 'then', 'else', 'end', 'having', 'distinct', 'all', 'any', 'exists', 'union', 'except', 'intersect'];
|
||||
return reserved.includes(ident.toLowerCase());
|
||||
};
|
||||
|
||||
export const quoteIdentPart = (dbType: string, ident: string) => {
|
||||
const raw = normalizeIdentPart(ident);
|
||||
if (!raw) return raw;
|
||||
const dbTypeLower = (dbType || '').toLowerCase();
|
||||
|
||||
if (dbTypeLower === 'mysql') {
|
||||
return `\`${raw.replace(/`/g, '``')}\``;
|
||||
}
|
||||
|
||||
// 对于 KingBase/PostgreSQL,只在必要时加引号
|
||||
if (dbTypeLower === 'kingbase' || dbTypeLower === 'postgres') {
|
||||
if (needsQuote(raw)) {
|
||||
return `"${raw.replace(/"/g, '""')}"`;
|
||||
}
|
||||
// 不加引号,保持原样(数据库会自动转小写处理)
|
||||
return raw;
|
||||
}
|
||||
|
||||
// 其他数据库默认加双引号
|
||||
return `"${raw.replace(/"/g, '""')}"`;
|
||||
};
|
||||
|
||||
export const quoteQualifiedIdent = (dbType: string, ident: string) => {
|
||||
const raw = (ident || '').trim();
|
||||
if (!raw) return raw;
|
||||
const parts = raw.split('.').map(normalizeIdentPart).filter(Boolean);
|
||||
if (parts.length <= 1) return quoteIdentPart(dbType, raw);
|
||||
return parts.map(p => quoteIdentPart(dbType, p)).join('.');
|
||||
};
|
||||
|
||||
export const escapeLiteral = (val: string) => (val || '').replace(/'/g, "''");
|
||||
|
||||
export const parseListValues = (val: string) => {
|
||||
const raw = (val || '').trim();
|
||||
if (!raw) return [];
|
||||
return raw
|
||||
.split(/[\n,,]+/)
|
||||
.map(s => s.trim())
|
||||
.filter(Boolean);
|
||||
};
|
||||
|
||||
export const buildWhereSQL = (dbType: string, conditions: FilterCondition[]) => {
|
||||
const whereParts: string[] = [];
|
||||
|
||||
(conditions || []).forEach((cond) => {
|
||||
const op = (cond?.op || '').trim();
|
||||
const column = (cond?.column || '').trim();
|
||||
const value = (cond?.value ?? '').toString();
|
||||
const value2 = (cond?.value2 ?? '').toString();
|
||||
|
||||
if (op === 'CUSTOM') {
|
||||
const expr = value.trim();
|
||||
if (expr) whereParts.push(`(${expr})`);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!column) return;
|
||||
|
||||
const col = quoteIdentPart(dbType, column);
|
||||
|
||||
switch (op) {
|
||||
case 'IS_NULL':
|
||||
whereParts.push(`${col} IS NULL`);
|
||||
return;
|
||||
case 'IS_NOT_NULL':
|
||||
whereParts.push(`${col} IS NOT NULL`);
|
||||
return;
|
||||
case 'IS_EMPTY':
|
||||
// 兼容:空值通常理解为 NULL 或空字符串
|
||||
whereParts.push(`(${col} IS NULL OR ${col} = '')`);
|
||||
return;
|
||||
case 'IS_NOT_EMPTY':
|
||||
whereParts.push(`(${col} IS NOT NULL AND ${col} <> '')`);
|
||||
return;
|
||||
case 'BETWEEN': {
|
||||
const v1 = value.trim();
|
||||
const v2 = value2.trim();
|
||||
if (!v1 || !v2) return;
|
||||
whereParts.push(`${col} BETWEEN '${escapeLiteral(v1)}' AND '${escapeLiteral(v2)}'`);
|
||||
return;
|
||||
}
|
||||
case 'NOT_BETWEEN': {
|
||||
const v1 = value.trim();
|
||||
const v2 = value2.trim();
|
||||
if (!v1 || !v2) return;
|
||||
whereParts.push(`${col} NOT BETWEEN '${escapeLiteral(v1)}' AND '${escapeLiteral(v2)}'`);
|
||||
return;
|
||||
}
|
||||
case 'IN': {
|
||||
const items = parseListValues(value);
|
||||
if (items.length === 0) return;
|
||||
const list = items.map(v => `'${escapeLiteral(v)}'`).join(', ');
|
||||
whereParts.push(`${col} IN (${list})`);
|
||||
return;
|
||||
}
|
||||
case 'NOT_IN': {
|
||||
const items = parseListValues(value);
|
||||
if (items.length === 0) return;
|
||||
const list = items.map(v => `'${escapeLiteral(v)}'`).join(', ');
|
||||
whereParts.push(`${col} NOT IN (${list})`);
|
||||
return;
|
||||
}
|
||||
case 'CONTAINS': {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} LIKE '%${escapeLiteral(v)}%'`);
|
||||
return;
|
||||
}
|
||||
case 'NOT_CONTAINS': {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} NOT LIKE '%${escapeLiteral(v)}%'`);
|
||||
return;
|
||||
}
|
||||
case 'STARTS_WITH': {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} LIKE '${escapeLiteral(v)}%'`);
|
||||
return;
|
||||
}
|
||||
case 'NOT_STARTS_WITH': {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} NOT LIKE '${escapeLiteral(v)}%'`);
|
||||
return;
|
||||
}
|
||||
case 'ENDS_WITH': {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} LIKE '%${escapeLiteral(v)}'`);
|
||||
return;
|
||||
}
|
||||
case 'NOT_ENDS_WITH': {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} NOT LIKE '%${escapeLiteral(v)}'`);
|
||||
return;
|
||||
}
|
||||
case '=':
|
||||
case '!=':
|
||||
case '<':
|
||||
case '<=':
|
||||
case '>':
|
||||
case '>=': {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} ${op} '${escapeLiteral(v)}'`);
|
||||
return;
|
||||
}
|
||||
default: {
|
||||
// 兼容旧值:LIKE
|
||||
if (op.toUpperCase() === 'LIKE') {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} LIKE '%${escapeLiteral(v)}%'`);
|
||||
return;
|
||||
}
|
||||
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} ${op} '${escapeLiteral(v)}'`);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return whereParts.length > 0 ? `WHERE ${whereParts.join(' AND ')}` : '';
|
||||
};
|
||||
|
||||
49
frontend/wailsjs/go/app/App.d.ts
vendored
49
frontend/wailsjs/go/app/App.d.ts
vendored
@@ -2,6 +2,7 @@
|
||||
// This file is automatically generated. DO NOT EDIT
|
||||
import {connection} from '../models';
|
||||
import {sync} from '../models';
|
||||
import {redis} from '../models';
|
||||
|
||||
export function ApplyChanges(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:connection.ChangeSet):Promise<connection.QueryResult>;
|
||||
|
||||
@@ -35,8 +36,14 @@ export function DataSyncPreview(arg1:sync.SyncConfig,arg2:string,arg3:number):Pr
|
||||
|
||||
export function ExportData(arg1:Array<Record<string, any>>,arg2:Array<string>,arg3:string,arg4:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function ExportDatabaseSQL(arg1:connection.ConnectionConfig,arg2:string,arg3:boolean):Promise<connection.QueryResult>;
|
||||
|
||||
export function ExportQuery(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string,arg5:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function ExportTable(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function ExportTablesSQL(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>,arg4:boolean):Promise<connection.QueryResult>;
|
||||
|
||||
export function ImportConfigFile():Promise<connection.QueryResult>;
|
||||
|
||||
export function ImportData(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise<connection.QueryResult>;
|
||||
@@ -53,4 +60,46 @@ export function MySQLShowCreateTable(arg1:connection.ConnectionConfig,arg2:strin
|
||||
|
||||
export function OpenSQLFile():Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisConnect(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisDeleteHashField(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisDeleteKeys(arg1:connection.ConnectionConfig,arg2:Array<string>):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisExecuteCommand(arg1:connection.ConnectionConfig,arg2:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisFlushDB(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisGetDatabases(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisGetServerInfo(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisGetValue(arg1:connection.ConnectionConfig,arg2:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisListPush(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisListSet(arg1:connection.ConnectionConfig,arg2:string,arg3:number,arg4:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisRenameKey(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisScanKeys(arg1:connection.ConnectionConfig,arg2:string,arg3:number,arg4:number):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisSelectDB(arg1:connection.ConnectionConfig,arg2:number):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisSetAdd(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisSetHashField(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisSetRemove(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisSetString(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:number):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisSetTTL(arg1:connection.ConnectionConfig,arg2:string,arg3:number):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisTestConnection(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisZSetAdd(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<redis.ZSetMember>):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisZSetRemove(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>):Promise<connection.QueryResult>;
|
||||
|
||||
export function TestConnection(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;
|
||||
|
||||
@@ -66,10 +66,22 @@ export function ExportData(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['ExportData'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function ExportDatabaseSQL(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['ExportDatabaseSQL'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function ExportQuery(arg1, arg2, arg3, arg4, arg5) {
|
||||
return window['go']['app']['App']['ExportQuery'](arg1, arg2, arg3, arg4, arg5);
|
||||
}
|
||||
|
||||
export function ExportTable(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['ExportTable'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function ExportTablesSQL(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['ExportTablesSQL'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function ImportConfigFile() {
|
||||
return window['go']['app']['App']['ImportConfigFile']();
|
||||
}
|
||||
@@ -102,6 +114,90 @@ export function OpenSQLFile() {
|
||||
return window['go']['app']['App']['OpenSQLFile']();
|
||||
}
|
||||
|
||||
export function RedisConnect(arg1) {
|
||||
return window['go']['app']['App']['RedisConnect'](arg1);
|
||||
}
|
||||
|
||||
export function RedisDeleteHashField(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['RedisDeleteHashField'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function RedisDeleteKeys(arg1, arg2) {
|
||||
return window['go']['app']['App']['RedisDeleteKeys'](arg1, arg2);
|
||||
}
|
||||
|
||||
export function RedisExecuteCommand(arg1, arg2) {
|
||||
return window['go']['app']['App']['RedisExecuteCommand'](arg1, arg2);
|
||||
}
|
||||
|
||||
export function RedisFlushDB(arg1) {
|
||||
return window['go']['app']['App']['RedisFlushDB'](arg1);
|
||||
}
|
||||
|
||||
export function RedisGetDatabases(arg1) {
|
||||
return window['go']['app']['App']['RedisGetDatabases'](arg1);
|
||||
}
|
||||
|
||||
export function RedisGetServerInfo(arg1) {
|
||||
return window['go']['app']['App']['RedisGetServerInfo'](arg1);
|
||||
}
|
||||
|
||||
export function RedisGetValue(arg1, arg2) {
|
||||
return window['go']['app']['App']['RedisGetValue'](arg1, arg2);
|
||||
}
|
||||
|
||||
export function RedisListPush(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['RedisListPush'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function RedisListSet(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['RedisListSet'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function RedisRenameKey(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['RedisRenameKey'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function RedisScanKeys(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['RedisScanKeys'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function RedisSelectDB(arg1, arg2) {
|
||||
return window['go']['app']['App']['RedisSelectDB'](arg1, arg2);
|
||||
}
|
||||
|
||||
export function RedisSetAdd(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['RedisSetAdd'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function RedisSetHashField(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['RedisSetHashField'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function RedisSetRemove(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['RedisSetRemove'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function RedisSetString(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['RedisSetString'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function RedisSetTTL(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['RedisSetTTL'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function RedisTestConnection(arg1) {
|
||||
return window['go']['app']['App']['RedisTestConnection'](arg1);
|
||||
}
|
||||
|
||||
export function RedisZSetAdd(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['RedisZSetAdd'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function RedisZSetRemove(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['RedisZSetRemove'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function TestConnection(arg1) {
|
||||
return window['go']['app']['App']['TestConnection'](arg1);
|
||||
}
|
||||
|
||||
@@ -80,6 +80,7 @@ export namespace connection {
|
||||
driver?: string;
|
||||
dsn?: string;
|
||||
timeout?: number;
|
||||
redisDB?: number;
|
||||
|
||||
static createFrom(source: any = {}) {
|
||||
return new ConnectionConfig(source);
|
||||
@@ -98,6 +99,7 @@ export namespace connection {
|
||||
this.driver = source["driver"];
|
||||
this.dsn = source["dsn"];
|
||||
this.timeout = source["timeout"];
|
||||
this.redisDB = source["redisDB"];
|
||||
}
|
||||
|
||||
convertValues(a: any, classs: any, asMap: boolean = false): any {
|
||||
@@ -140,6 +142,25 @@ export namespace connection {
|
||||
|
||||
}
|
||||
|
||||
export namespace redis {
|
||||
|
||||
export class ZSetMember {
|
||||
member: string;
|
||||
score: number;
|
||||
|
||||
static createFrom(source: any = {}) {
|
||||
return new ZSetMember(source);
|
||||
}
|
||||
|
||||
constructor(source: any = {}) {
|
||||
if ('string' === typeof source) source = JSON.parse(source);
|
||||
this.member = source["member"];
|
||||
this.score = source["score"];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
export namespace sync {
|
||||
|
||||
export class TableOptions {
|
||||
|
||||
3
go.mod
3
go.mod
@@ -7,6 +7,7 @@ require (
|
||||
gitee.com/chunanyong/dm v1.8.22
|
||||
github.com/go-sql-driver/mysql v1.9.3
|
||||
github.com/lib/pq v1.11.1
|
||||
github.com/redis/go-redis/v9 v9.17.3
|
||||
github.com/sijms/go-ora/v2 v2.9.0
|
||||
github.com/wailsapp/wails/v2 v2.11.0
|
||||
golang.org/x/crypto v0.47.0
|
||||
@@ -16,6 +17,8 @@ require (
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/bep/debounce v1.2.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/go-ole/go-ole v1.3.0 // indirect
|
||||
github.com/godbus/dbus/v5 v5.1.0 // indirect
|
||||
|
||||
10
go.sum
10
go.sum
@@ -6,8 +6,16 @@ gitee.com/chunanyong/dm v1.8.22 h1:H7fsrnUIvEA0jlDWew7vwELry1ff+tLMIu2Fk2cIBSg=
|
||||
gitee.com/chunanyong/dm v1.8.22/go.mod h1:EPRJnuPFgbyOFgJ0TRYCTGzhq+ZT4wdyaj/GW/LLcNg=
|
||||
github.com/bep/debounce v1.2.1 h1:v67fRdBA9UQu2NhLFXrSg0Brw7CexQekrBwDMM8bzeY=
|
||||
github.com/bep/debounce v1.2.1/go.mod h1:H8yggRPQKLUhUoqrJC1bO2xNya7vanpDl7xR3ISbCJ0=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
|
||||
@@ -61,6 +69,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/redis/go-redis/v9 v9.17.3 h1:fN29NdNrE17KttK5Ndf20buqfDZwGNgoUr9qjl1DQx4=
|
||||
github.com/redis/go-redis/v9 v9.17.3/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
|
||||
@@ -48,6 +48,8 @@ func (a *App) Shutdown(ctx context.Context) {
|
||||
logger.Error(err, "关闭数据库连接失败")
|
||||
}
|
||||
}
|
||||
// Close all Redis connections
|
||||
CloseAllRedisClients()
|
||||
logger.Infof("资源释放完成,应用已关闭")
|
||||
logger.Close()
|
||||
}
|
||||
@@ -141,9 +143,6 @@ func (a *App) getDatabase(config connection.ConnectionConfig) (db.Database, erro
|
||||
if len(shortKey) > 12 {
|
||||
shortKey = shortKey[:12]
|
||||
}
|
||||
if config.UseSSH && config.Type != "mysql" {
|
||||
logger.Warnf("当前仅 MySQL 支持内置 SSH 直连,其他类型请使用本地端口转发:%s", formatConnSummary(config))
|
||||
}
|
||||
logger.Infof("获取数据库连接:%s 缓存Key=%s", formatConnSummary(config), shortKey)
|
||||
|
||||
a.mu.Lock()
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
)
|
||||
|
||||
// Generic DB Methods
|
||||
@@ -91,16 +94,39 @@ func (a *App) DBQuery(config connection.ConnectionConfig, dbName string, query s
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
query = sanitizeSQLForPgLike(runConfig.Type, query)
|
||||
timeoutSeconds := runConfig.Timeout
|
||||
if timeoutSeconds <= 0 {
|
||||
timeoutSeconds = 30
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(time.Duration(timeoutSeconds) * time.Second)
|
||||
defer cancel()
|
||||
|
||||
lowerQuery := strings.TrimSpace(strings.ToLower(query))
|
||||
if strings.HasPrefix(lowerQuery, "select") || strings.HasPrefix(lowerQuery, "show") || strings.HasPrefix(lowerQuery, "describe") || strings.HasPrefix(lowerQuery, "explain") {
|
||||
data, columns, err := dbInst.Query(query)
|
||||
var data []map[string]interface{}
|
||||
var columns []string
|
||||
if q, ok := dbInst.(interface {
|
||||
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
|
||||
}); ok {
|
||||
data, columns, err = q.QueryContext(ctx, query)
|
||||
} else {
|
||||
data, columns, err = dbInst.Query(query)
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQuery 查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Data: data, Fields: columns}
|
||||
} else {
|
||||
affected, err := dbInst.Exec(query)
|
||||
var affected int64
|
||||
if e, ok := dbInst.(interface {
|
||||
ExecContext(context.Context, string) (int64, error)
|
||||
}); ok {
|
||||
affected, err = e.ExecContext(ctx, query)
|
||||
} else {
|
||||
affected, err = dbInst.Exec(query)
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQuery 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/csv"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/db"
|
||||
@@ -213,12 +218,36 @@ func (a *App) ExportTable(config connection.ConnectionConfig, dbName string, tab
|
||||
}
|
||||
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
|
||||
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
format = strings.ToLower(format)
|
||||
if format == "sql" {
|
||||
f, err := os.Create(filename)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
w := bufio.NewWriterSize(f, 1024*1024)
|
||||
defer w.Flush()
|
||||
|
||||
if err := writeSQLHeader(w, runConfig, dbName); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
if err := dumpTableSQL(w, dbInst, runConfig, dbName, tableName, true); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
if err := writeSQLFooter(w, runConfig); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "Export successful"}
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(runConfig.Type, tableName))
|
||||
|
||||
data, columns, err := dbInst.Query(query)
|
||||
@@ -231,71 +260,129 @@ data, columns, err := dbInst.Query(query)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
defer f.Close()
|
||||
if err := writeRowsToFile(f, data, columns, format); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
|
||||
}
|
||||
|
||||
format = strings.ToLower(format)
|
||||
var csvWriter *csv.Writer
|
||||
var jsonEncoder *json.Encoder
|
||||
var isJsonFirstRow = true
|
||||
return connection.QueryResult{Success: true, Message: "Export successful"}
|
||||
}
|
||||
|
||||
switch format {
|
||||
case "csv", "xlsx":
|
||||
f.Write([]byte{0xEF, 0xBB, 0xBF})
|
||||
csvWriter = csv.NewWriter(f)
|
||||
defer csvWriter.Flush()
|
||||
if err := csvWriter.Write(columns); err != nil {
|
||||
func (a *App) ExportTablesSQL(config connection.ConnectionConfig, dbName string, tableNames []string, includeData bool) connection.QueryResult {
|
||||
safeDbName := strings.TrimSpace(dbName)
|
||||
if safeDbName == "" {
|
||||
safeDbName = "export"
|
||||
}
|
||||
suffix := "schema"
|
||||
if includeData {
|
||||
suffix = "backup"
|
||||
}
|
||||
defaultFilename := fmt.Sprintf("%s_%s_%dtables.sql", safeDbName, suffix, len(tableNames))
|
||||
if len(tableNames) == 1 && strings.TrimSpace(tableNames[0]) != "" {
|
||||
defaultFilename = fmt.Sprintf("%s_%s.sql", strings.TrimSpace(tableNames[0]), suffix)
|
||||
}
|
||||
|
||||
filename, err := runtime.SaveFileDialog(a.ctx, runtime.SaveDialogOptions{
|
||||
Title: "Export Tables (SQL)",
|
||||
DefaultFilename: defaultFilename,
|
||||
})
|
||||
if err != nil || filename == "" {
|
||||
return connection.QueryResult{Success: false, Message: "Cancelled"}
|
||||
}
|
||||
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
tables := make([]string, 0, len(tableNames))
|
||||
seen := make(map[string]struct{}, len(tableNames))
|
||||
for _, t := range tableNames {
|
||||
t = strings.TrimSpace(t)
|
||||
if t == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[t]; ok {
|
||||
continue
|
||||
}
|
||||
seen[t] = struct{}{}
|
||||
tables = append(tables, t)
|
||||
}
|
||||
sort.Strings(tables)
|
||||
|
||||
f, err := os.Create(filename)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
w := bufio.NewWriterSize(f, 1024*1024)
|
||||
defer w.Flush()
|
||||
|
||||
if err := writeSQLHeader(w, runConfig, dbName); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
for _, t := range tables {
|
||||
if err := dumpTableSQL(w, dbInst, runConfig, dbName, t, includeData); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
case "json":
|
||||
f.WriteString("[\n")
|
||||
jsonEncoder = json.NewEncoder(f)
|
||||
jsonEncoder.SetIndent(" ", " ")
|
||||
case "md":
|
||||
fmt.Fprintf(f, "| %s |\n", strings.Join(columns, " | "))
|
||||
seps := make([]string, len(columns))
|
||||
for i := range seps {
|
||||
seps[i] = "---"
|
||||
}
|
||||
fmt.Fprintf(f, "| %s |\n", strings.Join(seps, " | "))
|
||||
default:
|
||||
return connection.QueryResult{Success: false, Message: "Unsupported format: " + format}
|
||||
}
|
||||
if err := writeSQLFooter(w, runConfig); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
for _, rowMap := range data {
|
||||
record := make([]string, len(columns))
|
||||
for i, col := range columns {
|
||||
val := rowMap[col]
|
||||
if val == nil {
|
||||
record[i] = "NULL"
|
||||
} else {
|
||||
s := fmt.Sprintf("%v", val)
|
||||
if format == "md" {
|
||||
s = strings.ReplaceAll(s, "|", "\\|")
|
||||
s = strings.ReplaceAll(s, "\n", "<br>")
|
||||
}
|
||||
record[i] = s
|
||||
}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Message: "Export successful"}
|
||||
}
|
||||
|
||||
switch format {
|
||||
case "csv", "xlsx":
|
||||
if err := csvWriter.Write(record); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
|
||||
}
|
||||
case "json":
|
||||
if !isJsonFirstRow {
|
||||
f.WriteString(",\n")
|
||||
}
|
||||
if err := jsonEncoder.Encode(rowMap); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
|
||||
}
|
||||
isJsonFirstRow = false
|
||||
case "md":
|
||||
fmt.Fprintf(f, "| %s |\n", strings.Join(record, " | "))
|
||||
}
|
||||
func (a *App) ExportDatabaseSQL(config connection.ConnectionConfig, dbName string, includeData bool) connection.QueryResult {
|
||||
safeDbName := strings.TrimSpace(dbName)
|
||||
if safeDbName == "" {
|
||||
return connection.QueryResult{Success: false, Message: "dbName required"}
|
||||
}
|
||||
suffix := "schema"
|
||||
if includeData {
|
||||
suffix = "backup"
|
||||
}
|
||||
|
||||
if format == "json" {
|
||||
f.WriteString("\n]")
|
||||
filename, err := runtime.SaveFileDialog(a.ctx, runtime.SaveDialogOptions{
|
||||
Title: fmt.Sprintf("Export %s (SQL)", safeDbName),
|
||||
DefaultFilename: fmt.Sprintf("%s_%s.sql", safeDbName, suffix),
|
||||
})
|
||||
if err != nil || filename == "" {
|
||||
return connection.QueryResult{Success: false, Message: "Cancelled"}
|
||||
}
|
||||
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
tables, err := dbInst.GetTables(dbName)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
sort.Strings(tables)
|
||||
|
||||
f, err := os.Create(filename)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
w := bufio.NewWriterSize(f, 1024*1024)
|
||||
defer w.Flush()
|
||||
|
||||
if err := writeSQLHeader(w, runConfig, dbName); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
for _, t := range tables {
|
||||
if err := dumpTableSQL(w, dbInst, runConfig, dbName, t, includeData); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
}
|
||||
if err := writeSQLFooter(w, runConfig); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "Export successful"}
|
||||
@@ -340,6 +427,173 @@ func quoteQualifiedIdentByType(dbType string, ident string) string {
|
||||
return strings.Join(quotedParts, ".")
|
||||
}
|
||||
|
||||
func writeSQLHeader(w *bufio.Writer, config connection.ConnectionConfig, dbName string) error {
|
||||
now := time.Now().Format("2006-01-02 15:04:05")
|
||||
if _, err := w.WriteString(fmt.Sprintf("-- GoNavi SQL Export\n-- Time: %s\n", now)); err != nil {
|
||||
return err
|
||||
}
|
||||
if strings.TrimSpace(dbName) != "" {
|
||||
if _, err := w.WriteString(fmt.Sprintf("-- Database: %s\n\n", dbName)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if strings.ToLower(strings.TrimSpace(config.Type)) == "mysql" && strings.TrimSpace(dbName) != "" {
|
||||
if _, err := w.WriteString(fmt.Sprintf("USE %s;\n\n", quoteIdentByType("mysql", dbName))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.WriteString("SET FOREIGN_KEY_CHECKS=0;\n\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeSQLFooter(w *bufio.Writer, config connection.ConnectionConfig) error {
|
||||
if strings.ToLower(strings.TrimSpace(config.Type)) == "mysql" {
|
||||
if _, err := w.WriteString("\nSET FOREIGN_KEY_CHECKS=1;\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func qualifyTable(schemaName, tableName string) string {
|
||||
schemaName = strings.TrimSpace(schemaName)
|
||||
tableName = strings.TrimSpace(tableName)
|
||||
if schemaName == "" {
|
||||
return tableName
|
||||
}
|
||||
return schemaName + "." + tableName
|
||||
}
|
||||
|
||||
func ensureSQLTerminator(sql string) string {
|
||||
trimmed := strings.TrimSpace(sql)
|
||||
if trimmed == "" {
|
||||
return sql
|
||||
}
|
||||
if strings.HasSuffix(trimmed, ";") {
|
||||
return sql
|
||||
}
|
||||
return sql + ";"
|
||||
}
|
||||
|
||||
func isMySQLHexLiteral(s string) bool {
|
||||
if len(s) < 3 || !(strings.HasPrefix(s, "0x") || strings.HasPrefix(s, "0X")) {
|
||||
return false
|
||||
}
|
||||
for i := 2; i < len(s); i++ {
|
||||
c := s[i]
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func formatSQLValue(dbType string, v interface{}) string {
|
||||
if v == nil {
|
||||
return "NULL"
|
||||
}
|
||||
|
||||
switch val := v.(type) {
|
||||
case bool:
|
||||
if val {
|
||||
return "1"
|
||||
}
|
||||
return "0"
|
||||
case int:
|
||||
return strconv.Itoa(val)
|
||||
case int8, int16, int32, int64:
|
||||
return fmt.Sprintf("%d", val)
|
||||
case uint, uint8, uint16, uint32, uint64:
|
||||
return fmt.Sprintf("%d", val)
|
||||
case float32:
|
||||
f := float64(val)
|
||||
if math.IsNaN(f) || math.IsInf(f, 0) {
|
||||
return "NULL"
|
||||
}
|
||||
return strconv.FormatFloat(f, 'f', -1, 32)
|
||||
case float64:
|
||||
if math.IsNaN(val) || math.IsInf(val, 0) {
|
||||
return "NULL"
|
||||
}
|
||||
return strconv.FormatFloat(val, 'f', -1, 64)
|
||||
case time.Time:
|
||||
return "'" + val.Format("2006-01-02 15:04:05") + "'"
|
||||
case string:
|
||||
if strings.ToLower(strings.TrimSpace(dbType)) == "mysql" && isMySQLHexLiteral(val) {
|
||||
return val
|
||||
}
|
||||
escaped := strings.ReplaceAll(val, "'", "''")
|
||||
return "'" + escaped + "'"
|
||||
default:
|
||||
escaped := strings.ReplaceAll(fmt.Sprintf("%v", v), "'", "''")
|
||||
return "'" + escaped + "'"
|
||||
}
|
||||
}
|
||||
|
||||
func dumpTableSQL(w *bufio.Writer, dbInst db.Database, config connection.ConnectionConfig, dbName, tableName string, includeData bool) error {
|
||||
schemaName, pureTableName := normalizeSchemaAndTable(config, dbName, tableName)
|
||||
|
||||
if _, err := w.WriteString("\n-- ----------------------------\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.WriteString(fmt.Sprintf("-- Table: %s\n", qualifyTable(schemaName, pureTableName))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.WriteString("-- ----------------------------\n\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
createSQL, err := dbInst.GetCreateStatement(schemaName, pureTableName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.WriteString(ensureSQLTerminator(createSQL)); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.WriteString("\n\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !includeData {
|
||||
return nil
|
||||
}
|
||||
|
||||
qualified := qualifyTable(schemaName, pureTableName)
|
||||
selectSQL := fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(config.Type, qualified))
|
||||
data, columns, err := dbInst.Query(selectSQL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(data) == 0 {
|
||||
if _, err := w.WriteString("-- (0 rows)\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
quotedCols := make([]string, 0, len(columns))
|
||||
for _, c := range columns {
|
||||
quotedCols = append(quotedCols, quoteIdentByType(config.Type, c))
|
||||
}
|
||||
quotedTable := quoteQualifiedIdentByType(config.Type, qualified)
|
||||
|
||||
for _, row := range data {
|
||||
values := make([]string, 0, len(columns))
|
||||
for _, c := range columns {
|
||||
values = append(values, formatSQLValue(config.Type, row[c]))
|
||||
}
|
||||
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s);\n", quotedTable, strings.Join(quotedCols, ", "), strings.Join(values, ", "))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExportData exports provided data to a file
|
||||
func (a *App) ExportData(data []map[string]interface{}, columns []string, defaultName string, format string) connection.QueryResult {
|
||||
if defaultName == "" {
|
||||
@@ -359,33 +613,101 @@ func (a *App) ExportData(data []map[string]interface{}, columns []string, defaul
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
defer f.Close()
|
||||
if err := writeRowsToFile(f, data, columns, format); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "Export successful"}
|
||||
}
|
||||
|
||||
// ExportQuery exports by executing the provided SELECT query on backend side.
|
||||
// This avoids frontend IPC payload limits when exporting very large/long-text columns (e.g. base64).
|
||||
func (a *App) ExportQuery(config connection.ConnectionConfig, dbName string, query string, defaultName string, format string) connection.QueryResult {
|
||||
query = strings.TrimSpace(query)
|
||||
if query == "" {
|
||||
return connection.QueryResult{Success: false, Message: "query required"}
|
||||
}
|
||||
|
||||
if defaultName == "" {
|
||||
defaultName = "export"
|
||||
}
|
||||
|
||||
filename, err := runtime.SaveFileDialog(a.ctx, runtime.SaveDialogOptions{
|
||||
Title: "Export Query Result",
|
||||
DefaultFilename: fmt.Sprintf("%s.%s", defaultName, strings.ToLower(format)),
|
||||
})
|
||||
if err != nil || filename == "" {
|
||||
return connection.QueryResult{Success: false, Message: "Cancelled"}
|
||||
}
|
||||
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
query = sanitizeSQLForPgLike(runConfig.Type, query)
|
||||
lowerQuery := strings.ToLower(strings.TrimSpace(query))
|
||||
if !(strings.HasPrefix(lowerQuery, "select") || strings.HasPrefix(lowerQuery, "with")) {
|
||||
return connection.QueryResult{Success: false, Message: "Only SELECT/WITH queries are supported"}
|
||||
}
|
||||
|
||||
data, columns, err := dbInst.Query(query)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
f, err := os.Create(filename)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := writeRowsToFile(f, data, columns, format); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "Export successful"}
|
||||
}
|
||||
|
||||
func writeRowsToFile(f *os.File, data []map[string]interface{}, columns []string, format string) error {
|
||||
format = strings.ToLower(strings.TrimSpace(format))
|
||||
if f == nil {
|
||||
return fmt.Errorf("file required")
|
||||
}
|
||||
|
||||
format = strings.ToLower(format)
|
||||
var csvWriter *csv.Writer
|
||||
var jsonEncoder *json.Encoder
|
||||
var isJsonFirstRow = true
|
||||
isJsonFirstRow := true
|
||||
|
||||
switch format {
|
||||
case "csv", "xlsx":
|
||||
f.Write([]byte{0xEF, 0xBB, 0xBF})
|
||||
if _, err := f.Write([]byte{0xEF, 0xBB, 0xBF}); err != nil {
|
||||
return err
|
||||
}
|
||||
csvWriter = csv.NewWriter(f)
|
||||
defer csvWriter.Flush()
|
||||
if err := csvWriter.Write(columns); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
return err
|
||||
}
|
||||
case "json":
|
||||
f.WriteString("[\n")
|
||||
if _, err := f.WriteString("[\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
jsonEncoder = json.NewEncoder(f)
|
||||
jsonEncoder.SetIndent(" ", " ")
|
||||
case "md":
|
||||
fmt.Fprintf(f, "| %s |\n", strings.Join(columns, " | "))
|
||||
if _, err := fmt.Fprintf(f, "| %s |\n", strings.Join(columns, " | ")); err != nil {
|
||||
return err
|
||||
}
|
||||
seps := make([]string, len(columns))
|
||||
for i := range seps {
|
||||
seps[i] = "---"
|
||||
}
|
||||
fmt.Fprintf(f, "| %s |\n", strings.Join(seps, " | "))
|
||||
if _, err := fmt.Fprintf(f, "| %s |\n", strings.Join(seps, " | ")); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return connection.QueryResult{Success: false, Message: "Unsupported format: " + format}
|
||||
return fmt.Errorf("unsupported format: %s", format)
|
||||
}
|
||||
|
||||
for _, rowMap := range data {
|
||||
@@ -394,37 +716,51 @@ func (a *App) ExportData(data []map[string]interface{}, columns []string, defaul
|
||||
val := rowMap[col]
|
||||
if val == nil {
|
||||
record[i] = "NULL"
|
||||
} else {
|
||||
s := fmt.Sprintf("%v", val)
|
||||
if format == "md" {
|
||||
s = strings.ReplaceAll(s, "|", "\\|")
|
||||
s = strings.ReplaceAll(s, "\n", "<br>")
|
||||
}
|
||||
record[i] = s
|
||||
continue
|
||||
}
|
||||
|
||||
s := fmt.Sprintf("%v", val)
|
||||
if format == "md" {
|
||||
s = strings.ReplaceAll(s, "|", "\\|")
|
||||
s = strings.ReplaceAll(s, "\n", "<br>")
|
||||
}
|
||||
record[i] = s
|
||||
}
|
||||
|
||||
switch format {
|
||||
case "csv", "xlsx":
|
||||
if err := csvWriter.Write(record); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
|
||||
return err
|
||||
}
|
||||
case "json":
|
||||
if !isJsonFirstRow {
|
||||
f.WriteString(",\n")
|
||||
if _, err := f.WriteString(",\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := jsonEncoder.Encode(rowMap); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
|
||||
return err
|
||||
}
|
||||
isJsonFirstRow = false
|
||||
case "md":
|
||||
fmt.Fprintf(f, "| %s |\n", strings.Join(record, " | "))
|
||||
if _, err := fmt.Fprintf(f, "| %s |\n", strings.Join(record, " | ")); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if format == "csv" || format == "xlsx" {
|
||||
csvWriter.Flush()
|
||||
if err := csvWriter.Error(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if format == "json" {
|
||||
f.WriteString("\n]")
|
||||
if _, err := f.WriteString("\n]"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "Export successful"}
|
||||
return nil
|
||||
}
|
||||
|
||||
481
internal/app/methods_redis.go
Normal file
481
internal/app/methods_redis.go
Normal file
@@ -0,0 +1,481 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/redis"
|
||||
)
|
||||
|
||||
// Redis client cache
|
||||
var (
|
||||
redisCache = make(map[string]redis.RedisClient)
|
||||
redisCacheMu sync.Mutex
|
||||
)
|
||||
|
||||
// getRedisClient gets or creates a Redis client from cache
|
||||
func (a *App) getRedisClient(config connection.ConnectionConfig) (redis.RedisClient, error) {
|
||||
key := getRedisClientCacheKey(config)
|
||||
shortKey := key
|
||||
if len(shortKey) > 12 {
|
||||
shortKey = shortKey[:12]
|
||||
}
|
||||
logger.Infof("获取 Redis 连接:%s 缓存Key=%s", formatRedisConnSummary(config), shortKey)
|
||||
|
||||
redisCacheMu.Lock()
|
||||
defer redisCacheMu.Unlock()
|
||||
|
||||
if client, ok := redisCache[key]; ok {
|
||||
logger.Infof("命中 Redis 连接缓存,开始检测可用性:缓存Key=%s", shortKey)
|
||||
if err := client.Ping(); err == nil {
|
||||
logger.Infof("缓存 Redis 连接可用:缓存Key=%s", shortKey)
|
||||
return client, nil
|
||||
} else {
|
||||
logger.Error(err, "缓存 Redis 连接不可用,准备重建:缓存Key=%s", shortKey)
|
||||
}
|
||||
client.Close()
|
||||
delete(redisCache, key)
|
||||
}
|
||||
|
||||
logger.Infof("创建 Redis 客户端实例:缓存Key=%s", shortKey)
|
||||
client := redis.NewRedisClient()
|
||||
if err := client.Connect(config); err != nil {
|
||||
logger.Error(err, "Redis 连接失败:%s 缓存Key=%s", formatRedisConnSummary(config), shortKey)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
redisCache[key] = client
|
||||
logger.Infof("Redis 连接成功并写入缓存:%s 缓存Key=%s", formatRedisConnSummary(config), shortKey)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func getRedisClientCacheKey(config connection.ConnectionConfig) string {
|
||||
if !config.UseSSH {
|
||||
config.SSH = connection.SSHConfig{}
|
||||
}
|
||||
b, _ := json.Marshal(config)
|
||||
sum := sha256.Sum256(b)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func formatRedisConnSummary(config connection.ConnectionConfig) string {
|
||||
timeoutSeconds := config.Timeout
|
||||
if timeoutSeconds <= 0 {
|
||||
timeoutSeconds = 30
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("类型=redis 地址=")
|
||||
b.WriteString(config.Host)
|
||||
b.WriteString(":")
|
||||
b.WriteString(string(rune(config.Port + '0')))
|
||||
b.WriteString(" DB=")
|
||||
b.WriteString(string(rune(config.RedisDB + '0')))
|
||||
|
||||
if config.UseSSH {
|
||||
b.WriteString(" SSH=")
|
||||
b.WriteString(config.SSH.Host)
|
||||
b.WriteString(":")
|
||||
b.WriteString(string(rune(config.SSH.Port + '0')))
|
||||
b.WriteString(" 用户=")
|
||||
b.WriteString(config.SSH.User)
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// RedisConnect tests a Redis connection
|
||||
func (a *App) RedisConnect(config connection.ConnectionConfig) connection.QueryResult {
|
||||
config.Type = "redis"
|
||||
_, err := a.getRedisClient(config)
|
||||
if err != nil {
|
||||
logger.Error(err, "RedisConnect 连接失败:%s", formatRedisConnSummary(config))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
logger.Infof("RedisConnect 连接成功:%s", formatRedisConnSummary(config))
|
||||
return connection.QueryResult{Success: true, Message: "连接成功"}
|
||||
}
|
||||
|
||||
// RedisTestConnection tests a Redis connection (alias for RedisConnect)
|
||||
func (a *App) RedisTestConnection(config connection.ConnectionConfig) connection.QueryResult {
|
||||
return a.RedisConnect(config)
|
||||
}
|
||||
|
||||
// RedisScanKeys scans keys matching a pattern
|
||||
func (a *App) RedisScanKeys(config connection.ConnectionConfig, pattern string, cursor uint64, count int64) connection.QueryResult {
|
||||
config.Type = "redis"
|
||||
client, err := a.getRedisClient(config)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
result, err := client.ScanKeys(pattern, cursor, count)
|
||||
if err != nil {
|
||||
logger.Error(err, "RedisScanKeys 扫描失败:pattern=%s", pattern)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Data: result}
|
||||
}
|
||||
|
||||
// RedisGetValue gets the value of a key
|
||||
func (a *App) RedisGetValue(config connection.ConnectionConfig, key string) connection.QueryResult {
|
||||
config.Type = "redis"
|
||||
client, err := a.getRedisClient(config)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
value, err := client.GetValue(key)
|
||||
if err != nil {
|
||||
logger.Error(err, "RedisGetValue 获取失败:key=%s", key)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Data: value}
|
||||
}
|
||||
|
||||
// RedisSetString sets a string value
|
||||
func (a *App) RedisSetString(config connection.ConnectionConfig, key, value string, ttl int64) connection.QueryResult {
|
||||
config.Type = "redis"
|
||||
client, err := a.getRedisClient(config)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
if err := client.SetString(key, value, ttl); err != nil {
|
||||
logger.Error(err, "RedisSetString 设置失败:key=%s", key)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "设置成功"}
|
||||
}
|
||||
|
||||
// RedisSetHashField sets a field in a hash
|
||||
func (a *App) RedisSetHashField(config connection.ConnectionConfig, key, field, value string) connection.QueryResult {
|
||||
config.Type = "redis"
|
||||
client, err := a.getRedisClient(config)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
if err := client.SetHashField(key, field, value); err != nil {
|
||||
logger.Error(err, "RedisSetHashField 设置失败:key=%s field=%s", key, field)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "设置成功"}
|
||||
}
|
||||
|
||||
// RedisDeleteKeys deletes one or more keys
|
||||
func (a *App) RedisDeleteKeys(config connection.ConnectionConfig, keys []string) connection.QueryResult {
|
||||
config.Type = "redis"
|
||||
client, err := a.getRedisClient(config)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
deleted, err := client.DeleteKeys(keys)
|
||||
if err != nil {
|
||||
logger.Error(err, "RedisDeleteKeys 删除失败:keys=%v", keys)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Data: map[string]int64{"deleted": deleted}}
|
||||
}
|
||||
|
||||
// RedisSetTTL sets the TTL of a key
|
||||
func (a *App) RedisSetTTL(config connection.ConnectionConfig, key string, ttl int64) connection.QueryResult {
|
||||
config.Type = "redis"
|
||||
client, err := a.getRedisClient(config)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
if err := client.SetTTL(key, ttl); err != nil {
|
||||
logger.Error(err, "RedisSetTTL 设置失败:key=%s ttl=%d", key, ttl)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "设置成功"}
|
||||
}
|
||||
|
||||
// RedisExecuteCommand executes a raw Redis command
|
||||
func (a *App) RedisExecuteCommand(config connection.ConnectionConfig, command string) connection.QueryResult {
|
||||
config.Type = "redis"
|
||||
client, err := a.getRedisClient(config)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
// Parse command string into args
|
||||
args := parseRedisCommand(command)
|
||||
if len(args) == 0 {
|
||||
return connection.QueryResult{Success: false, Message: "命令不能为空"}
|
||||
}
|
||||
|
||||
result, err := client.ExecuteCommand(args)
|
||||
if err != nil {
|
||||
logger.Error(err, "RedisExecuteCommand 执行失败:command=%s", command)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Data: result}
|
||||
}
|
||||
|
||||
// parseRedisCommand parses a Redis command string into arguments
|
||||
func parseRedisCommand(command string) []string {
|
||||
command = strings.TrimSpace(command)
|
||||
if command == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var args []string
|
||||
var current strings.Builder
|
||||
inQuote := false
|
||||
quoteChar := rune(0)
|
||||
|
||||
for _, ch := range command {
|
||||
if inQuote {
|
||||
if ch == quoteChar {
|
||||
inQuote = false
|
||||
args = append(args, current.String())
|
||||
current.Reset()
|
||||
} else {
|
||||
current.WriteRune(ch)
|
||||
}
|
||||
} else {
|
||||
if ch == '"' || ch == '\'' {
|
||||
inQuote = true
|
||||
quoteChar = ch
|
||||
} else if ch == ' ' || ch == '\t' {
|
||||
if current.Len() > 0 {
|
||||
args = append(args, current.String())
|
||||
current.Reset()
|
||||
}
|
||||
} else {
|
||||
current.WriteRune(ch)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if current.Len() > 0 {
|
||||
args = append(args, current.String())
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
// RedisGetServerInfo returns server information
|
||||
func (a *App) RedisGetServerInfo(config connection.ConnectionConfig) connection.QueryResult {
|
||||
config.Type = "redis"
|
||||
client, err := a.getRedisClient(config)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
info, err := client.GetServerInfo()
|
||||
if err != nil {
|
||||
logger.Error(err, "RedisGetServerInfo 获取失败")
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Data: info}
|
||||
}
|
||||
|
||||
// RedisGetDatabases returns information about all databases
|
||||
func (a *App) RedisGetDatabases(config connection.ConnectionConfig) connection.QueryResult {
|
||||
config.Type = "redis"
|
||||
client, err := a.getRedisClient(config)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
dbs, err := client.GetDatabases()
|
||||
if err != nil {
|
||||
logger.Error(err, "RedisGetDatabases 获取失败")
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Data: dbs}
|
||||
}
|
||||
|
||||
// RedisSelectDB selects a database
|
||||
func (a *App) RedisSelectDB(config connection.ConnectionConfig, dbIndex int) connection.QueryResult {
|
||||
config.Type = "redis"
|
||||
config.RedisDB = dbIndex
|
||||
client, err := a.getRedisClient(config)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
if err := client.SelectDB(dbIndex); err != nil {
|
||||
logger.Error(err, "RedisSelectDB 切换失败:db=%d", dbIndex)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "切换成功"}
|
||||
}
|
||||
|
||||
// RedisRenameKey renames a key
|
||||
func (a *App) RedisRenameKey(config connection.ConnectionConfig, oldKey, newKey string) connection.QueryResult {
|
||||
config.Type = "redis"
|
||||
client, err := a.getRedisClient(config)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
if err := client.RenameKey(oldKey, newKey); err != nil {
|
||||
logger.Error(err, "RedisRenameKey 重命名失败:%s -> %s", oldKey, newKey)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "重命名成功"}
|
||||
}
|
||||
|
||||
// RedisDeleteHashField deletes fields from a hash
|
||||
func (a *App) RedisDeleteHashField(config connection.ConnectionConfig, key string, fields []string) connection.QueryResult {
|
||||
config.Type = "redis"
|
||||
client, err := a.getRedisClient(config)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
if err := client.DeleteHashField(key, fields...); err != nil {
|
||||
logger.Error(err, "RedisDeleteHashField 删除失败:key=%s fields=%v", key, fields)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "删除成功"}
|
||||
}
|
||||
|
||||
// RedisListPush pushes values to a list
|
||||
func (a *App) RedisListPush(config connection.ConnectionConfig, key string, values []string) connection.QueryResult {
|
||||
config.Type = "redis"
|
||||
client, err := a.getRedisClient(config)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
if err := client.ListPush(key, values...); err != nil {
|
||||
logger.Error(err, "RedisListPush 添加失败:key=%s", key)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "添加成功"}
|
||||
}
|
||||
|
||||
// RedisListSet sets a value at an index in a list
|
||||
func (a *App) RedisListSet(config connection.ConnectionConfig, key string, index int64, value string) connection.QueryResult {
|
||||
config.Type = "redis"
|
||||
client, err := a.getRedisClient(config)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
if err := client.ListSet(key, index, value); err != nil {
|
||||
logger.Error(err, "RedisListSet 设置失败:key=%s index=%d", key, index)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "设置成功"}
|
||||
}
|
||||
|
||||
// RedisSetAdd adds members to a set
|
||||
func (a *App) RedisSetAdd(config connection.ConnectionConfig, key string, members []string) connection.QueryResult {
|
||||
config.Type = "redis"
|
||||
client, err := a.getRedisClient(config)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
if err := client.SetAdd(key, members...); err != nil {
|
||||
logger.Error(err, "RedisSetAdd 添加失败:key=%s", key)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "添加成功"}
|
||||
}
|
||||
|
||||
// RedisSetRemove removes members from a set
|
||||
func (a *App) RedisSetRemove(config connection.ConnectionConfig, key string, members []string) connection.QueryResult {
|
||||
config.Type = "redis"
|
||||
client, err := a.getRedisClient(config)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
if err := client.SetRemove(key, members...); err != nil {
|
||||
logger.Error(err, "RedisSetRemove 删除失败:key=%s", key)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "删除成功"}
|
||||
}
|
||||
|
||||
// RedisZSetAdd adds members to a sorted set
|
||||
func (a *App) RedisZSetAdd(config connection.ConnectionConfig, key string, members []redis.ZSetMember) connection.QueryResult {
|
||||
config.Type = "redis"
|
||||
client, err := a.getRedisClient(config)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
if err := client.ZSetAdd(key, members...); err != nil {
|
||||
logger.Error(err, "RedisZSetAdd 添加失败:key=%s", key)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "添加成功"}
|
||||
}
|
||||
|
||||
// RedisZSetRemove removes members from a sorted set
|
||||
func (a *App) RedisZSetRemove(config connection.ConnectionConfig, key string, members []string) connection.QueryResult {
|
||||
config.Type = "redis"
|
||||
client, err := a.getRedisClient(config)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
if err := client.ZSetRemove(key, members...); err != nil {
|
||||
logger.Error(err, "RedisZSetRemove 删除失败:key=%s", key)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "删除成功"}
|
||||
}
|
||||
|
||||
// RedisFlushDB flushes the current database
|
||||
func (a *App) RedisFlushDB(config connection.ConnectionConfig) connection.QueryResult {
|
||||
config.Type = "redis"
|
||||
client, err := a.getRedisClient(config)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
if err := client.FlushDB(); err != nil {
|
||||
logger.Error(err, "RedisFlushDB 清空失败")
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "清空成功"}
|
||||
}
|
||||
|
||||
// CloseAllRedisClients closes all cached Redis clients (called on shutdown)
|
||||
func CloseAllRedisClients() {
|
||||
redisCacheMu.Lock()
|
||||
defer redisCacheMu.Unlock()
|
||||
|
||||
for key, client := range redisCache {
|
||||
if client != nil {
|
||||
client.Close()
|
||||
logger.Infof("已关闭 Redis 连接:%s", key[:12])
|
||||
}
|
||||
}
|
||||
redisCache = make(map[string]redis.RedisClient)
|
||||
}
|
||||
236
internal/app/sql_sanitize.go
Normal file
236
internal/app/sql_sanitize.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
func sanitizeSQLForPgLike(dbType string, query string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(dbType)) {
|
||||
case "postgres", "kingbase":
|
||||
// 有些情况下会出现多层重复引用(例如 """"schema"""" 或 ""schema"""),单次修复不一定收敛。
|
||||
// 这里做有限次数的迭代,直到输出不再变化。
|
||||
out := query
|
||||
for i := 0; i < 3; i++ {
|
||||
fixed := fixBrokenDoubleDoubleQuotedIdent(out)
|
||||
if fixed == out {
|
||||
break
|
||||
}
|
||||
out = fixed
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return query
|
||||
}
|
||||
}
|
||||
|
||||
// fixBrokenDoubleDoubleQuotedIdent fixes accidental identifiers like:
|
||||
//
|
||||
// SELECT * FROM ""schema"".""table""
|
||||
//
|
||||
// which can be produced when a quoted identifier gets wrapped by quotes again.
|
||||
//
|
||||
// It is intentionally conservative:
|
||||
// - only runs outside strings/comments/dollar-quoted blocks
|
||||
// - does not touch valid escaped-quote sequences inside quoted identifiers (e.g. "a""b")
|
||||
func fixBrokenDoubleDoubleQuotedIdent(query string) string {
|
||||
if !strings.Contains(query, `""`) {
|
||||
return query
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(query))
|
||||
|
||||
inSingle := false
|
||||
inDoubleIdent := false
|
||||
inLineComment := false
|
||||
inBlockComment := false
|
||||
dollarTag := ""
|
||||
|
||||
for i := 0; i < len(query); i++ {
|
||||
ch := query[i]
|
||||
next := byte(0)
|
||||
if i+1 < len(query) {
|
||||
next = query[i+1]
|
||||
}
|
||||
|
||||
if inLineComment {
|
||||
b.WriteByte(ch)
|
||||
if ch == '\n' {
|
||||
inLineComment = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
if inBlockComment {
|
||||
b.WriteByte(ch)
|
||||
if ch == '*' && next == '/' {
|
||||
b.WriteByte('/')
|
||||
i++
|
||||
inBlockComment = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
if dollarTag != "" {
|
||||
if strings.HasPrefix(query[i:], dollarTag) {
|
||||
b.WriteString(dollarTag)
|
||||
i += len(dollarTag) - 1
|
||||
dollarTag = ""
|
||||
continue
|
||||
}
|
||||
b.WriteByte(ch)
|
||||
continue
|
||||
}
|
||||
if inSingle {
|
||||
b.WriteByte(ch)
|
||||
if ch == '\'' {
|
||||
// escaped single quote
|
||||
if next == '\'' {
|
||||
b.WriteByte('\'')
|
||||
i++
|
||||
continue
|
||||
}
|
||||
inSingle = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
if inDoubleIdent {
|
||||
b.WriteByte(ch)
|
||||
if ch == '"' {
|
||||
// escaped quote inside identifier
|
||||
if next == '"' {
|
||||
b.WriteByte('"')
|
||||
i++
|
||||
continue
|
||||
}
|
||||
inDoubleIdent = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// --- Outside of all string/comment blocks ---
|
||||
if ch == '-' && next == '-' {
|
||||
b.WriteByte(ch)
|
||||
b.WriteByte('-')
|
||||
i++
|
||||
inLineComment = true
|
||||
continue
|
||||
}
|
||||
if ch == '/' && next == '*' {
|
||||
b.WriteByte(ch)
|
||||
b.WriteByte('*')
|
||||
i++
|
||||
inBlockComment = true
|
||||
continue
|
||||
}
|
||||
if ch == '\'' {
|
||||
b.WriteByte(ch)
|
||||
inSingle = true
|
||||
continue
|
||||
}
|
||||
if ch == '$' {
|
||||
if tag := parseDollarTag(query[i:]); tag != "" {
|
||||
b.WriteString(tag)
|
||||
i += len(tag) - 1
|
||||
dollarTag = tag
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if ch == '"' {
|
||||
// Fix: ""ident"" -> "ident" (only when it looks like a plain identifier)
|
||||
// Also handle variants like ""ident""" / """"ident"""" (extra quotes at either side).
|
||||
if next == '"' {
|
||||
if replacement, advance, ok := tryFixDoubleDoubleQuotedIdent(query, i); ok {
|
||||
b.WriteString(replacement)
|
||||
i = advance - 1
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
b.WriteByte(ch)
|
||||
inDoubleIdent = true
|
||||
continue
|
||||
}
|
||||
|
||||
b.WriteByte(ch)
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func tryFixDoubleDoubleQuotedIdent(query string, start int) (replacement string, advance int, ok bool) {
|
||||
// start points at the first quote of a broken identifier, usually like:
|
||||
// ""ident"" / ""ident""" / """"ident""""
|
||||
if start < 0 || start+1 >= len(query) {
|
||||
return "", 0, false
|
||||
}
|
||||
if query[start] != '"' || query[start+1] != '"' {
|
||||
return "", 0, false
|
||||
}
|
||||
if start > 0 && query[start-1] == '"' {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
runLen := 0
|
||||
for start+runLen < len(query) && query[start+runLen] == '"' {
|
||||
runLen++
|
||||
}
|
||||
if runLen < 2 || runLen%2 == 1 {
|
||||
// Odd run (e.g. """...) can be a valid quoted identifier with escaped quotes.
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
contentStart := start + runLen
|
||||
j := contentStart
|
||||
for j < len(query) {
|
||||
if query[j] == '"' {
|
||||
endRunLen := 0
|
||||
for j+endRunLen < len(query) && query[j+endRunLen] == '"' {
|
||||
endRunLen++
|
||||
}
|
||||
if endRunLen >= 2 {
|
||||
content := strings.TrimSpace(query[contentStart:j])
|
||||
if looksLikeIdentifierContent(content) {
|
||||
return `"` + content + `"`, j + endRunLen, true
|
||||
}
|
||||
return "", 0, false
|
||||
}
|
||||
}
|
||||
// Fast abort: identifier-like content should not span lines.
|
||||
if query[j] == '\n' || query[j] == '\r' {
|
||||
break
|
||||
}
|
||||
j++
|
||||
}
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func looksLikeIdentifierContent(s string) bool {
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return false
|
||||
}
|
||||
for _, r := range s {
|
||||
if r == '_' || r == '$' || r == '-' || unicode.IsLetter(r) || unicode.IsDigit(r) {
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func parseDollarTag(s string) string {
|
||||
// Match: $tag$ where tag is [A-Za-z0-9_]* (can be empty => $$)
|
||||
if len(s) < 2 || s[0] != '$' {
|
||||
return ""
|
||||
}
|
||||
for i := 1; i < len(s); i++ {
|
||||
c := s[i]
|
||||
if c == '$' {
|
||||
return s[:i+1]
|
||||
}
|
||||
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
55
internal/app/sql_sanitize_test.go
Normal file
55
internal/app/sql_sanitize_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package app
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSanitizeSQLForPgLike_FixesBrokenDoubleDoubleQuotes(t *testing.T) {
|
||||
in := `SELECT * FROM ""ldf_server"".""t_user"" LIMIT 1`
|
||||
out := sanitizeSQLForPgLike("kingbase", in)
|
||||
want := `SELECT * FROM "ldf_server"."t_user" LIMIT 1`
|
||||
if out != want {
|
||||
t.Fatalf("unexpected sanitize output:\nIN: %s\nOUT: %s\nWANT: %s", in, out, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeSQLForPgLike_FixesBrokenDoubleDoubleQuotes_WithExtraQuotes(t *testing.T) {
|
||||
in := `SELECT * FROM ""ldf_server""".""t_user"" LIMIT 1`
|
||||
out := sanitizeSQLForPgLike("kingbase", in)
|
||||
want := `SELECT * FROM "ldf_server"."t_user" LIMIT 1`
|
||||
if out != want {
|
||||
t.Fatalf("unexpected sanitize output:\nIN: %s\nOUT: %s\nWANT: %s", in, out, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeSQLForPgLike_FixesBrokenDoubleDoubleQuotes_WithQuadQuotes(t *testing.T) {
|
||||
in := `SELECT * FROM """"ldf_server"""".""t_user"" LIMIT 1`
|
||||
out := sanitizeSQLForPgLike("kingbase", in)
|
||||
want := `SELECT * FROM "ldf_server"."t_user" LIMIT 1`
|
||||
if out != want {
|
||||
t.Fatalf("unexpected sanitize output:\nIN: %s\nOUT: %s\nWANT: %s", in, out, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeSQLForPgLike_DoesNotTouchEscapedQuotesInsideIdentifier(t *testing.T) {
|
||||
in := `SELECT "a""b" FROM "t""x"`
|
||||
out := sanitizeSQLForPgLike("postgres", in)
|
||||
if out != in {
|
||||
t.Fatalf("should keep valid escaped quotes inside identifier:\nIN: %s\nOUT: %s", in, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeSQLForPgLike_DoesNotTouchDollarQuotedStrings(t *testing.T) {
|
||||
in := "SELECT $$\"\"ldf_server\"\"$$, \"\"ldf_server\"\""
|
||||
out := sanitizeSQLForPgLike("postgres", in)
|
||||
want := "SELECT $$\"\"ldf_server\"\"$$, \"ldf_server\""
|
||||
if out != want {
|
||||
t.Fatalf("unexpected sanitize output for dollar quoted string:\nIN: %s\nOUT: %s\nWANT: %s", in, out, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeSQLForPgLike_DoesNotModifyOtherDBTypes(t *testing.T) {
|
||||
in := `SELECT * FROM ""ldf_server""`
|
||||
out := sanitizeSQLForPgLike("mysql", in)
|
||||
if out != in {
|
||||
t.Fatalf("non-PG-like db should not be sanitized:\nIN: %s\nOUT: %s", in, out)
|
||||
}
|
||||
}
|
||||
@@ -19,9 +19,10 @@ type ConnectionConfig struct {
|
||||
Database string `json:"database"`
|
||||
UseSSH bool `json:"useSSH"`
|
||||
SSH SSHConfig `json:"ssh"`
|
||||
Driver string `json:"driver,omitempty"` // For custom connection
|
||||
DSN string `json:"dsn,omitempty"` // For custom connection
|
||||
Driver string `json:"driver,omitempty"` // For custom connection
|
||||
DSN string `json:"dsn,omitempty"` // For custom connection
|
||||
Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds (default: 30)
|
||||
RedisDB int `json:"redisDB,omitempty"` // Redis database index (0-15)
|
||||
}
|
||||
|
||||
// QueryResult is the standard response format for Wails methods
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
@@ -57,6 +58,20 @@ func (c *CustomDB) Ping() error {
|
||||
return c.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (c *CustomDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if c.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := c.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (c *CustomDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if c.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
@@ -67,33 +82,18 @@ func (c *CustomDB) Query(query string) ([]map[string]interface{}, []string, erro
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
columns, err := rows.Columns()
|
||||
func (c *CustomDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if c.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
}
|
||||
res, err := c.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var resultData []map[string]interface{}
|
||||
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range columns {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
entry := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
entry[col] = normalizeQueryValue(values[i])
|
||||
}
|
||||
resultData = append(resultData, entry)
|
||||
}
|
||||
|
||||
return resultData, columns, nil
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (c *CustomDB) Exec(query string) (int64, error) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
|
||||
@@ -19,6 +21,7 @@ import (
|
||||
type DamengDB struct {
|
||||
conn *sql.DB
|
||||
pingTimeout time.Duration
|
||||
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
|
||||
}
|
||||
|
||||
func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
|
||||
@@ -26,16 +29,6 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
|
||||
// or dm://user:password@host:port
|
||||
|
||||
address := net.JoinHostPort(config.Host, strconv.Itoa(config.Port))
|
||||
if config.UseSSH {
|
||||
// SSH logic similar to others, assumes port forwarding
|
||||
_, err := ssh.RegisterSSHNetwork(config.SSH)
|
||||
if err == nil {
|
||||
// DM driver likely uses standard net.Dial, so we might need a local listener
|
||||
// or assume port forwarding is handled externally or implicitly via "tcp" override if driver allows.
|
||||
// Similar to Oracle, we skip complex custom dialer injection for now.
|
||||
}
|
||||
}
|
||||
|
||||
escapedPassword := url.PathEscape(config.Password)
|
||||
q := url.Values{}
|
||||
if config.Database != "" {
|
||||
@@ -55,7 +48,42 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
|
||||
func (d *DamengDB) Connect(config connection.ConnectionConfig) error {
|
||||
dsn := d.getDSN(config)
|
||||
var dsn string
|
||||
var err error
|
||||
|
||||
if config.UseSSH {
|
||||
// Create SSH tunnel with local port forwarding
|
||||
logger.Infof("达梦数据库使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
|
||||
|
||||
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||||
}
|
||||
d.forwarder = forwarder
|
||||
|
||||
// Parse local address
|
||||
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地转发地址失败:%w", err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地端口失败:%w", err)
|
||||
}
|
||||
|
||||
// Create a modified config pointing to local forwarder
|
||||
localConfig := config
|
||||
localConfig.Host = host
|
||||
localConfig.Port = port
|
||||
localConfig.UseSSH = false
|
||||
|
||||
dsn = d.getDSN(localConfig)
|
||||
logger.Infof("达梦数据库通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||||
} else {
|
||||
dsn = d.getDSN(config)
|
||||
}
|
||||
|
||||
db, err := sql.Open("dm", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
@@ -69,6 +97,15 @@ func (d *DamengDB) Connect(config connection.ConnectionConfig) error {
|
||||
}
|
||||
|
||||
func (d *DamengDB) Close() error {
|
||||
// Close SSH forwarder first if exists
|
||||
if d.forwarder != nil {
|
||||
if err := d.forwarder.Close(); err != nil {
|
||||
logger.Warnf("关闭达梦数据库 SSH 端口转发失败:%v", err)
|
||||
}
|
||||
d.forwarder = nil
|
||||
}
|
||||
|
||||
// Then close database connection
|
||||
if d.conn != nil {
|
||||
return d.conn.Close()
|
||||
}
|
||||
@@ -88,6 +125,20 @@ func (d *DamengDB) Ping() error {
|
||||
return d.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (d *DamengDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if d.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := d.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (d *DamengDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if d.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
@@ -98,33 +149,18 @@ func (d *DamengDB) Query(query string) ([]map[string]interface{}, []string, erro
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
columns, err := rows.Columns()
|
||||
func (d *DamengDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if d.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
}
|
||||
res, err := d.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var resultData []map[string]interface{}
|
||||
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range columns {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
entry := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
entry[col] = normalizeQueryValue(values[i])
|
||||
}
|
||||
resultData = append(resultData, entry)
|
||||
}
|
||||
|
||||
return resultData, columns, nil
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (d *DamengDB) Exec(query string) (int64, error) {
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
|
||||
@@ -16,6 +20,7 @@ import (
|
||||
type KingbaseDB struct {
|
||||
conn *sql.DB
|
||||
pingTimeout time.Duration
|
||||
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
|
||||
}
|
||||
|
||||
func quoteConnValue(v string) string {
|
||||
@@ -57,20 +62,6 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
|
||||
address := config.Host
|
||||
port := config.Port
|
||||
|
||||
if config.UseSSH {
|
||||
netName, err := ssh.RegisterSSHNetwork(config.SSH)
|
||||
if err == nil {
|
||||
// Kingbase/Postgres lib/pq allows custom dialer via "host" if using unix socket,
|
||||
// but for custom network it's harder.
|
||||
// Ideally we use a local forwarder.
|
||||
// For now, we assume standard TCP or handle SSH externally.
|
||||
// If we implement the net.Dial override for "kingbase" driver (which might use lib/pq internally),
|
||||
// we might need to check if it supports "cloudsql" style or similar custom dialers.
|
||||
// Similar to others, skipping SSH deep integration here for now.
|
||||
_ = netName
|
||||
}
|
||||
}
|
||||
|
||||
// Construct DSN
|
||||
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable connect_timeout=%d",
|
||||
quoteConnValue(address),
|
||||
@@ -85,7 +76,42 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
|
||||
dsn := k.getDSN(config)
|
||||
var dsn string
|
||||
var err error
|
||||
|
||||
if config.UseSSH {
|
||||
// Create SSH tunnel with local port forwarding
|
||||
logger.Infof("人大金仓使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
|
||||
|
||||
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||||
}
|
||||
k.forwarder = forwarder
|
||||
|
||||
// Parse local address
|
||||
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地转发地址失败:%w", err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地端口失败:%w", err)
|
||||
}
|
||||
|
||||
// Create a modified config pointing to local forwarder
|
||||
localConfig := config
|
||||
localConfig.Host = host
|
||||
localConfig.Port = port
|
||||
localConfig.UseSSH = false
|
||||
|
||||
dsn = k.getDSN(localConfig)
|
||||
logger.Infof("人大金仓通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||||
} else {
|
||||
dsn = k.getDSN(config)
|
||||
}
|
||||
|
||||
// Open using "kingbase" driver
|
||||
db, err := sql.Open("kingbase", dsn)
|
||||
if err != nil {
|
||||
@@ -100,6 +126,15 @@ func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) Close() error {
|
||||
// Close SSH forwarder first if exists
|
||||
if k.forwarder != nil {
|
||||
if err := k.forwarder.Close(); err != nil {
|
||||
logger.Warnf("关闭人大金仓 SSH 端口转发失败:%v", err)
|
||||
}
|
||||
k.forwarder = nil
|
||||
}
|
||||
|
||||
// Then close database connection
|
||||
if k.conn != nil {
|
||||
return k.conn.Close()
|
||||
}
|
||||
@@ -119,6 +154,20 @@ func (k *KingbaseDB) Ping() error {
|
||||
return k.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if k.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := k.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if k.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
@@ -129,33 +178,18 @@ func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, er
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
columns, err := rows.Columns()
|
||||
func (k *KingbaseDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if k.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
}
|
||||
res, err := k.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var resultData []map[string]interface{}
|
||||
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range columns {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
entry := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
entry[col] = normalizeQueryValue(values[i])
|
||||
}
|
||||
resultData = append(resultData, entry)
|
||||
}
|
||||
|
||||
return resultData, columns, nil
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) Exec(query string) (int64, error) {
|
||||
@@ -223,15 +257,84 @@ func (k *KingbaseDB) GetCreateStatement(dbName, tableName string) (string, error
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
schema := "public"
|
||||
if dbName != "" {
|
||||
schema = dbName
|
||||
// 解析 schema.table 格式
|
||||
schema := strings.TrimSpace(dbName)
|
||||
table := strings.TrimSpace(tableName)
|
||||
|
||||
// 如果 tableName 包含 schema (格式: schema.table)
|
||||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||||
parsedSchema := strings.TrimSpace(parts[0])
|
||||
parsedTable := strings.TrimSpace(parts[1])
|
||||
if parsedSchema != "" && parsedTable != "" {
|
||||
schema = parsedSchema
|
||||
table = parsedTable
|
||||
}
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = '%s' AND table_name = '%s'
|
||||
ORDER BY ordinal_position`, schema, tableName)
|
||||
// 如果仍然没有 schema,使用 current_schema()
|
||||
// 这样可以自动匹配当前连接的 search_path
|
||||
if schema == "" {
|
||||
return k.getColumnsWithCurrentSchema(table)
|
||||
}
|
||||
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
// 转义函数:处理单引号,移除双引号
|
||||
esc := func(s string) string {
|
||||
// 移除前后的双引号(如果存在)
|
||||
s = strings.Trim(s, "\"")
|
||||
// 转义单引号
|
||||
return strings.ReplaceAll(s, "'", "''")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = '%s' AND table_name = '%s'
|
||||
ORDER BY ordinal_position`, esc(schema), esc(table))
|
||||
|
||||
data, _, err := k.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var columns []connection.ColumnDefinition
|
||||
for _, row := range data {
|
||||
col := connection.ColumnDefinition{
|
||||
Name: fmt.Sprintf("%v", row["column_name"]),
|
||||
Type: fmt.Sprintf("%v", row["data_type"]),
|
||||
Nullable: fmt.Sprintf("%v", row["is_nullable"]),
|
||||
}
|
||||
|
||||
if row["column_default"] != nil {
|
||||
def := fmt.Sprintf("%v", row["column_default"])
|
||||
col.Default = &def
|
||||
}
|
||||
|
||||
columns = append(columns, col)
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
// getColumnsWithCurrentSchema 使用 current_schema() 查询当前schema的表
|
||||
func (k *KingbaseDB) getColumnsWithCurrentSchema(tableName string) ([]connection.ColumnDefinition, error) {
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
// 转义函数
|
||||
esc := func(s string) string {
|
||||
s = strings.Trim(s, "\"")
|
||||
return strings.ReplaceAll(s, "'", "''")
|
||||
}
|
||||
|
||||
// 使用 current_schema() 获取当前schema
|
||||
query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = current_schema() AND table_name = '%s'
|
||||
ORDER BY ordinal_position`, esc(table))
|
||||
|
||||
data, _, err := k.Query(query)
|
||||
if err != nil {
|
||||
@@ -257,32 +360,76 @@ func (k *KingbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDe
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||||
// Postgres/Kingbase index query
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
i.relname as index_name,
|
||||
a.attname as column_name,
|
||||
ix.indisunique as is_unique
|
||||
FROM
|
||||
pg_class t,
|
||||
pg_class i,
|
||||
pg_index ix,
|
||||
pg_attribute a,
|
||||
pg_namespace n
|
||||
WHERE
|
||||
t.oid = ix.indrelid
|
||||
AND i.oid = ix.indexrelid
|
||||
AND a.attrelid = t.oid
|
||||
AND a.attnum = ANY(ix.indkey)
|
||||
AND t.relkind = 'r'
|
||||
AND t.relname = '%s'
|
||||
AND n.oid = t.relnamespace
|
||||
AND n.nspname = '%s'
|
||||
`, tableName, "public") // Default to public if dbName (schema) not clear.
|
||||
// 解析 schema.table 格式
|
||||
schema := strings.TrimSpace(dbName)
|
||||
table := strings.TrimSpace(tableName)
|
||||
|
||||
if dbName != "" {
|
||||
// Update query to use dbName as schema
|
||||
query = strings.Replace(query, "'public'", fmt.Sprintf("'%s'", dbName), 1)
|
||||
// 如果 tableName 包含 schema (格式: schema.table)
|
||||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||||
parsedSchema := strings.TrimSpace(parts[0])
|
||||
parsedTable := strings.TrimSpace(parts[1])
|
||||
if parsedSchema != "" && parsedTable != "" {
|
||||
schema = parsedSchema
|
||||
table = parsedTable
|
||||
}
|
||||
}
|
||||
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
// 转义函数:处理单引号,移除双引号
|
||||
esc := func(s string) string {
|
||||
s = strings.Trim(s, "\"")
|
||||
return strings.ReplaceAll(s, "'", "''")
|
||||
}
|
||||
|
||||
// 构建查询:如果没有指定schema,使用current_schema()
|
||||
var query string
|
||||
if schema != "" {
|
||||
query = fmt.Sprintf(`
|
||||
SELECT
|
||||
i.relname as index_name,
|
||||
a.attname as column_name,
|
||||
ix.indisunique as is_unique
|
||||
FROM
|
||||
pg_class t,
|
||||
pg_class i,
|
||||
pg_index ix,
|
||||
pg_attribute a,
|
||||
pg_namespace n
|
||||
WHERE
|
||||
t.oid = ix.indrelid
|
||||
AND i.oid = ix.indexrelid
|
||||
AND a.attrelid = t.oid
|
||||
AND a.attnum = ANY(ix.indkey)
|
||||
AND t.relkind = 'r'
|
||||
AND t.relname = '%s'
|
||||
AND n.oid = t.relnamespace
|
||||
AND n.nspname = '%s'
|
||||
`, esc(table), esc(schema))
|
||||
} else {
|
||||
query = fmt.Sprintf(`
|
||||
SELECT
|
||||
i.relname as index_name,
|
||||
a.attname as column_name,
|
||||
ix.indisunique as is_unique
|
||||
FROM
|
||||
pg_class t,
|
||||
pg_class i,
|
||||
pg_index ix,
|
||||
pg_attribute a,
|
||||
pg_namespace n
|
||||
WHERE
|
||||
t.oid = ix.indrelid
|
||||
AND i.oid = ix.indexrelid
|
||||
AND a.attrelid = t.oid
|
||||
AND a.attnum = ANY(ix.indkey)
|
||||
AND t.relkind = 'r'
|
||||
AND t.relname = '%s'
|
||||
AND n.oid = t.relnamespace
|
||||
AND n.nspname = current_schema()
|
||||
`, esc(table))
|
||||
}
|
||||
|
||||
data, _, err := k.Query(query)
|
||||
@@ -311,27 +458,67 @@ func (k *KingbaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDef
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
|
||||
schema := "public"
|
||||
if dbName != "" {
|
||||
schema = dbName
|
||||
// 解析 schema.table 格式
|
||||
schema := strings.TrimSpace(dbName)
|
||||
table := strings.TrimSpace(tableName)
|
||||
|
||||
// 如果 tableName 包含 schema (格式: schema.table)
|
||||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||||
parsedSchema := strings.TrimSpace(parts[0])
|
||||
parsedTable := strings.TrimSpace(parts[1])
|
||||
if parsedSchema != "" && parsedTable != "" {
|
||||
schema = parsedSchema
|
||||
table = parsedTable
|
||||
}
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
tc.constraint_name,
|
||||
kcu.column_name,
|
||||
ccu.table_name AS foreign_table_name,
|
||||
ccu.column_name AS foreign_column_name
|
||||
FROM
|
||||
information_schema.table_constraints AS tc
|
||||
JOIN information_schema.key_column_usage AS kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
JOIN information_schema.constraint_column_usage AS ccu
|
||||
ON ccu.constraint_name = tc.constraint_name
|
||||
AND ccu.table_schema = tc.table_schema
|
||||
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema='%s'`,
|
||||
tableName, schema)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
// 转义函数:处理单引号,移除双引号
|
||||
esc := func(s string) string {
|
||||
s = strings.Trim(s, "\"")
|
||||
return strings.ReplaceAll(s, "'", "''")
|
||||
}
|
||||
|
||||
// 构建查询:如果没有指定schema,使用current_schema()
|
||||
var query string
|
||||
if schema != "" {
|
||||
query = fmt.Sprintf(`
|
||||
SELECT
|
||||
tc.constraint_name,
|
||||
kcu.column_name,
|
||||
ccu.table_name AS foreign_table_name,
|
||||
ccu.column_name AS foreign_column_name
|
||||
FROM
|
||||
information_schema.table_constraints AS tc
|
||||
JOIN information_schema.key_column_usage AS kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
JOIN information_schema.constraint_column_usage AS ccu
|
||||
ON ccu.constraint_name = tc.constraint_name
|
||||
AND ccu.table_schema = tc.table_schema
|
||||
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema='%s'`,
|
||||
esc(table), esc(schema))
|
||||
} else {
|
||||
query = fmt.Sprintf(`
|
||||
SELECT
|
||||
tc.constraint_name,
|
||||
kcu.column_name,
|
||||
ccu.table_name AS foreign_table_name,
|
||||
ccu.column_name AS foreign_column_name
|
||||
FROM
|
||||
information_schema.table_constraints AS tc
|
||||
JOIN information_schema.key_column_usage AS kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
JOIN information_schema.constraint_column_usage AS ccu
|
||||
ON ccu.constraint_name = tc.constraint_name
|
||||
AND ccu.table_schema = tc.table_schema
|
||||
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema=current_schema()`,
|
||||
esc(table))
|
||||
}
|
||||
|
||||
data, _, err := k.Query(query)
|
||||
if err != nil {
|
||||
@@ -353,9 +540,43 @@ func (k *KingbaseDB) GetForeignKeys(dbName, tableName string) ([]connection.Fore
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||||
query := fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation
|
||||
FROM information_schema.triggers
|
||||
WHERE event_object_table = '%s'`, tableName)
|
||||
// 解析 schema.table 格式
|
||||
schema := strings.TrimSpace(dbName)
|
||||
table := strings.TrimSpace(tableName)
|
||||
|
||||
// 如果 tableName 包含 schema (格式: schema.table)
|
||||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||||
parsedSchema := strings.TrimSpace(parts[0])
|
||||
parsedTable := strings.TrimSpace(parts[1])
|
||||
if parsedSchema != "" && parsedTable != "" {
|
||||
schema = parsedSchema
|
||||
table = parsedTable
|
||||
}
|
||||
}
|
||||
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
// 转义函数:处理单引号,移除双引号
|
||||
esc := func(s string) string {
|
||||
s = strings.Trim(s, "\"")
|
||||
return strings.ReplaceAll(s, "'", "''")
|
||||
}
|
||||
|
||||
// 构建查询:如果指定了schema,也加上schema条件
|
||||
var query string
|
||||
if schema != "" {
|
||||
query = fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation
|
||||
FROM information_schema.triggers
|
||||
WHERE event_object_table = '%s' AND event_object_schema = '%s'`,
|
||||
esc(table), esc(schema))
|
||||
} else {
|
||||
query = fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation
|
||||
FROM information_schema.triggers
|
||||
WHERE event_object_table = '%s' AND event_object_schema = current_schema()`,
|
||||
esc(table))
|
||||
}
|
||||
|
||||
data, _, err := k.Query(query)
|
||||
if err != nil {
|
||||
@@ -380,14 +601,13 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||||
schema := "public"
|
||||
if dbName != "" {
|
||||
schema = dbName
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`SELECT table_name, column_name, data_type
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = '%s'`, schema)
|
||||
// dbName 在本项目语义里是“数据库”,schema 由 table_schema 决定;这里返回全部用户 schema 的列用于查询提示。
|
||||
query := `
|
||||
SELECT table_schema, table_name, column_name, data_type
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
|
||||
AND table_schema NOT LIKE 'pg_%'
|
||||
ORDER BY table_schema, table_name, ordinal_position`
|
||||
|
||||
data, _, err := k.Query(query)
|
||||
if err != nil {
|
||||
@@ -396,8 +616,14 @@ func (k *KingbaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinition
|
||||
|
||||
var cols []connection.ColumnDefinitionWithTable
|
||||
for _, row := range data {
|
||||
schema := fmt.Sprintf("%v", row["table_schema"])
|
||||
table := fmt.Sprintf("%v", row["table_name"])
|
||||
tableName := table
|
||||
if strings.TrimSpace(schema) != "" {
|
||||
tableName = fmt.Sprintf("%s.%s", schema, table)
|
||||
}
|
||||
col := connection.ColumnDefinitionWithTable{
|
||||
TableName: fmt.Sprintf("%v", row["table_name"]),
|
||||
TableName: tableName,
|
||||
Name: fmt.Sprintf("%v", row["column_name"]),
|
||||
Type: fmt.Sprintf("%v", row["data_type"]),
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
@@ -76,6 +77,20 @@ func (m *MySQLDB) Ping() error {
|
||||
return m.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (m *MySQLDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if m.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := m.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (m *MySQLDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if m.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
@@ -86,33 +101,18 @@ func (m *MySQLDB) Query(query string) ([]map[string]interface{}, []string, error
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
columns, err := rows.Columns()
|
||||
func (m *MySQLDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if m.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
}
|
||||
res, err := m.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var resultData []map[string]interface{}
|
||||
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range columns {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
entry := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
entry[col] = normalizeQueryValue(values[i])
|
||||
}
|
||||
resultData = append(resultData, entry)
|
||||
}
|
||||
|
||||
return resultData, columns, nil
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (m *MySQLDB) Exec(query string) (int64, error) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
|
||||
@@ -19,6 +21,7 @@ import (
|
||||
type OracleDB struct {
|
||||
conn *sql.DB
|
||||
pingTimeout time.Duration
|
||||
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
|
||||
}
|
||||
|
||||
func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
|
||||
@@ -28,28 +31,6 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
|
||||
database = config.User // Default to user service/schema if empty?
|
||||
}
|
||||
|
||||
if config.UseSSH {
|
||||
_, err := ssh.RegisterSSHNetwork(config.SSH)
|
||||
if err == nil {
|
||||
// Oracle driver might not support custom dialer via DSN easily without extra config
|
||||
// But go-ora v2 supports some advanced options.
|
||||
// For simplicity, we assume standard TCP or we might need a workaround for SSH.
|
||||
// go-ora v2 is pure Go, so we can potentially use a custom dialer if we manually open.
|
||||
// But for now, let's just use the address.
|
||||
// SSH tunneling via net.Dialer override is complex in sql.Open("oracle", ...).
|
||||
// We might need to forward a local port if using SSH.
|
||||
// Since ssh.RegisterSSHNetwork creates a custom network "ssh-via-...",
|
||||
// we need to see if go-ora supports custom networks.
|
||||
// Checking go-ora docs (simulated): It supports "unix" and "tcp".
|
||||
// We might need to map the custom network to a local proxy.
|
||||
// For now, we will assume direct connection or handle SSH separately later.
|
||||
// We'll leave the protocol implementation as is in MySQL for now, hoping go-ora uses standard net.Dial.
|
||||
// Note: go-ora connection string: oracle://user:pass@host:port/service
|
||||
// It parses host/port. It doesn't easily take a custom "network" parameter in URL.
|
||||
// We will proceed with standard TCP string.
|
||||
}
|
||||
}
|
||||
|
||||
u := &url.URL{
|
||||
Scheme: "oracle",
|
||||
Host: net.JoinHostPort(config.Host, strconv.Itoa(config.Port)),
|
||||
@@ -61,7 +42,42 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
|
||||
func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
|
||||
dsn := o.getDSN(config)
|
||||
var dsn string
|
||||
var err error
|
||||
|
||||
if config.UseSSH {
|
||||
// Create SSH tunnel with local port forwarding
|
||||
logger.Infof("Oracle 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
|
||||
|
||||
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||||
}
|
||||
o.forwarder = forwarder
|
||||
|
||||
// Parse local address
|
||||
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地转发地址失败:%w", err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地端口失败:%w", err)
|
||||
}
|
||||
|
||||
// Create a modified config pointing to local forwarder
|
||||
localConfig := config
|
||||
localConfig.Host = host
|
||||
localConfig.Port = port
|
||||
localConfig.UseSSH = false
|
||||
|
||||
dsn = o.getDSN(localConfig)
|
||||
logger.Infof("Oracle 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||||
} else {
|
||||
dsn = o.getDSN(config)
|
||||
}
|
||||
|
||||
db, err := sql.Open("oracle", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
@@ -75,6 +91,15 @@ func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
|
||||
}
|
||||
|
||||
func (o *OracleDB) Close() error {
|
||||
// Close SSH forwarder first if exists
|
||||
if o.forwarder != nil {
|
||||
if err := o.forwarder.Close(); err != nil {
|
||||
logger.Warnf("关闭 Oracle SSH 端口转发失败:%v", err)
|
||||
}
|
||||
o.forwarder = nil
|
||||
}
|
||||
|
||||
// Then close database connection
|
||||
if o.conn != nil {
|
||||
return o.conn.Close()
|
||||
}
|
||||
@@ -94,6 +119,20 @@ func (o *OracleDB) Ping() error {
|
||||
return o.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (o *OracleDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if o.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := o.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (o *OracleDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if o.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
@@ -104,33 +143,18 @@ func (o *OracleDB) Query(query string) ([]map[string]interface{}, []string, erro
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
columns, err := rows.Columns()
|
||||
func (o *OracleDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if o.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
}
|
||||
res, err := o.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var resultData []map[string]interface{}
|
||||
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range columns {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
entry := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
entry[col] = normalizeQueryValue(values[i])
|
||||
}
|
||||
resultData = append(resultData, entry)
|
||||
}
|
||||
|
||||
return resultData, columns, nil
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (o *OracleDB) Exec(query string) (int64, error) {
|
||||
|
||||
@@ -1,24 +1,31 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
|
||||
type PostgresDB struct {
|
||||
conn *sql.DB
|
||||
pingTimeout time.Duration
|
||||
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
|
||||
}
|
||||
|
||||
|
||||
func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
|
||||
// postgres://user:password@host:port/dbname?sslmode=disable
|
||||
dbname := config.Database
|
||||
@@ -41,7 +48,42 @@ func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
|
||||
func (p *PostgresDB) Connect(config connection.ConnectionConfig) error {
|
||||
dsn := p.getDSN(config)
|
||||
var dsn string
|
||||
var err error
|
||||
|
||||
if config.UseSSH {
|
||||
// Create SSH tunnel with local port forwarding
|
||||
logger.Infof("PostgreSQL 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
|
||||
|
||||
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||||
}
|
||||
p.forwarder = forwarder
|
||||
|
||||
// Parse local address
|
||||
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地转发地址失败:%w", err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地端口失败:%w", err)
|
||||
}
|
||||
|
||||
// Create a modified config pointing to local forwarder
|
||||
localConfig := config
|
||||
localConfig.Host = host
|
||||
localConfig.Port = port
|
||||
localConfig.UseSSH = false // Disable SSH flag for DSN generation
|
||||
|
||||
dsn = p.getDSN(localConfig)
|
||||
logger.Infof("PostgreSQL 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||||
} else {
|
||||
dsn = p.getDSN(config)
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
@@ -56,7 +98,17 @@ func (p *PostgresDB) Connect(config connection.ConnectionConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
func (p *PostgresDB) Close() error {
|
||||
// Close SSH forwarder first if exists
|
||||
if p.forwarder != nil {
|
||||
if err := p.forwarder.Close(); err != nil {
|
||||
logger.Warnf("关闭 PostgreSQL SSH 端口转发失败:%v", err)
|
||||
}
|
||||
p.forwarder = nil
|
||||
}
|
||||
|
||||
// Then close database connection
|
||||
if p.conn != nil {
|
||||
return p.conn.Close()
|
||||
}
|
||||
@@ -76,6 +128,20 @@ func (p *PostgresDB) Ping() error {
|
||||
return p.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (p *PostgresDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if p.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := p.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (p *PostgresDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if p.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
@@ -86,33 +152,18 @@ func (p *PostgresDB) Query(query string) ([]map[string]interface{}, []string, er
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
columns, err := rows.Columns()
|
||||
func (p *PostgresDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if p.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
}
|
||||
res, err := p.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var resultData []map[string]interface{}
|
||||
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range columns {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
entry := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
entry[col] = normalizeQueryValue(values[i])
|
||||
}
|
||||
resultData = append(resultData, entry)
|
||||
}
|
||||
|
||||
return resultData, columns, nil
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (p *PostgresDB) Exec(query string) (int64, error) {
|
||||
@@ -167,21 +218,306 @@ func (p *PostgresDB) GetCreateStatement(dbName, tableName string) (string, error
|
||||
}
|
||||
|
||||
func (p *PostgresDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
return []connection.ColumnDefinition{}, nil
|
||||
schema := strings.TrimSpace(dbName)
|
||||
if schema == "" {
|
||||
schema = "public"
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
a.attname AS column_name,
|
||||
pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
|
||||
CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
|
||||
pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
|
||||
col_description(a.attrelid, a.attnum) AS comment,
|
||||
CASE WHEN pk.attname IS NOT NULL THEN 'PRI' ELSE '' END AS column_key
|
||||
FROM pg_class c
|
||||
JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||
JOIN pg_attribute a ON a.attrelid = c.oid
|
||||
LEFT JOIN pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum
|
||||
LEFT JOIN (
|
||||
SELECT i.indrelid, a3.attname
|
||||
FROM pg_index i
|
||||
JOIN pg_attribute a3 ON a3.attrelid = i.indrelid AND a3.attnum = ANY(i.indkey)
|
||||
WHERE i.indisprimary
|
||||
) pk ON pk.indrelid = c.oid AND pk.attname = a.attname
|
||||
WHERE c.relkind IN ('r', 'p')
|
||||
AND n.nspname = '%s'
|
||||
AND c.relname = '%s'
|
||||
AND a.attnum > 0
|
||||
AND NOT a.attisdropped
|
||||
ORDER BY a.attnum`, esc(schema), esc(table))
|
||||
|
||||
data, _, err := p.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var columns []connection.ColumnDefinition
|
||||
for _, row := range data {
|
||||
col := connection.ColumnDefinition{
|
||||
Name: fmt.Sprintf("%v", row["column_name"]),
|
||||
Type: fmt.Sprintf("%v", row["data_type"]),
|
||||
Nullable: fmt.Sprintf("%v", row["is_nullable"]),
|
||||
Key: fmt.Sprintf("%v", row["column_key"]),
|
||||
Extra: "",
|
||||
Comment: "",
|
||||
}
|
||||
|
||||
if v, ok := row["comment"]; ok && v != nil {
|
||||
col.Comment = fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
if v, ok := row["column_default"]; ok && v != nil {
|
||||
def := fmt.Sprintf("%v", v)
|
||||
col.Default = &def
|
||||
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(def)), "nextval(") {
|
||||
col.Extra = "auto_increment"
|
||||
}
|
||||
}
|
||||
|
||||
columns = append(columns, col)
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (p *PostgresDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||||
return []connection.IndexDefinition{}, nil
|
||||
schema := strings.TrimSpace(dbName)
|
||||
if schema == "" {
|
||||
schema = "public"
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
i.relname AS index_name,
|
||||
a.attname AS column_name,
|
||||
ix.indisunique AS is_unique,
|
||||
x.ordinality AS seq_in_index,
|
||||
am.amname AS index_type
|
||||
FROM pg_class t
|
||||
JOIN pg_namespace n ON n.oid = t.relnamespace
|
||||
JOIN pg_index ix ON t.oid = ix.indrelid
|
||||
JOIN pg_class i ON i.oid = ix.indexrelid
|
||||
JOIN pg_am am ON i.relam = am.oid
|
||||
JOIN unnest(ix.indkey) WITH ORDINALITY AS x(attnum, ordinality) ON TRUE
|
||||
JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = x.attnum
|
||||
WHERE t.relkind IN ('r', 'p')
|
||||
AND t.relname = '%s'
|
||||
AND n.nspname = '%s'
|
||||
ORDER BY i.relname, x.ordinality`, esc(table), esc(schema))
|
||||
|
||||
data, _, err := p.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parseBool := func(v interface{}) bool {
|
||||
switch val := v.(type) {
|
||||
case bool:
|
||||
return val
|
||||
case string:
|
||||
s := strings.ToLower(strings.TrimSpace(val))
|
||||
return s == "t" || s == "true" || s == "1" || s == "y" || s == "yes"
|
||||
default:
|
||||
s := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
|
||||
return s == "t" || s == "true" || s == "1" || s == "y" || s == "yes"
|
||||
}
|
||||
}
|
||||
|
||||
parseInt := func(v interface{}) int {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val
|
||||
case int64:
|
||||
return int(val)
|
||||
case float64:
|
||||
return int(val)
|
||||
case string:
|
||||
// best effort
|
||||
var n int
|
||||
_, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n)
|
||||
return n
|
||||
default:
|
||||
var n int
|
||||
_, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n)
|
||||
return n
|
||||
}
|
||||
}
|
||||
|
||||
var indexes []connection.IndexDefinition
|
||||
for _, row := range data {
|
||||
isUnique := false
|
||||
if v, ok := row["is_unique"]; ok && v != nil {
|
||||
isUnique = parseBool(v)
|
||||
}
|
||||
|
||||
nonUnique := 1
|
||||
if isUnique {
|
||||
nonUnique = 0
|
||||
}
|
||||
|
||||
seq := 0
|
||||
if v, ok := row["seq_in_index"]; ok && v != nil {
|
||||
seq = parseInt(v)
|
||||
}
|
||||
|
||||
indexType := ""
|
||||
if v, ok := row["index_type"]; ok && v != nil {
|
||||
indexType = strings.ToUpper(fmt.Sprintf("%v", v))
|
||||
}
|
||||
if indexType == "" {
|
||||
indexType = "BTREE"
|
||||
}
|
||||
|
||||
idx := connection.IndexDefinition{
|
||||
Name: fmt.Sprintf("%v", row["index_name"]),
|
||||
ColumnName: fmt.Sprintf("%v", row["column_name"]),
|
||||
NonUnique: nonUnique,
|
||||
SeqInIndex: seq,
|
||||
IndexType: indexType,
|
||||
}
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
return indexes, nil
|
||||
}
|
||||
|
||||
func (p *PostgresDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
|
||||
return []connection.ForeignKeyDefinition{}, nil
|
||||
schema := strings.TrimSpace(dbName)
|
||||
if schema == "" {
|
||||
schema = "public"
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
tc.constraint_name AS constraint_name,
|
||||
kcu.column_name AS column_name,
|
||||
ccu.table_schema AS foreign_table_schema,
|
||||
ccu.table_name AS foreign_table_name,
|
||||
ccu.column_name AS foreign_column_name
|
||||
FROM information_schema.table_constraints AS tc
|
||||
JOIN information_schema.key_column_usage AS kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
JOIN information_schema.constraint_column_usage AS ccu
|
||||
ON ccu.constraint_name = tc.constraint_name
|
||||
AND ccu.table_schema = tc.table_schema
|
||||
WHERE tc.constraint_type = 'FOREIGN KEY'
|
||||
AND tc.table_name = '%s'
|
||||
AND tc.table_schema = '%s'
|
||||
ORDER BY tc.constraint_name, kcu.ordinal_position`, esc(table), esc(schema))
|
||||
|
||||
data, _, err := p.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var fks []connection.ForeignKeyDefinition
|
||||
for _, row := range data {
|
||||
refSchema := ""
|
||||
if v, ok := row["foreign_table_schema"]; ok && v != nil {
|
||||
refSchema = fmt.Sprintf("%v", v)
|
||||
}
|
||||
refTable := fmt.Sprintf("%v", row["foreign_table_name"])
|
||||
refTableName := refTable
|
||||
if strings.TrimSpace(refSchema) != "" {
|
||||
refTableName = fmt.Sprintf("%s.%s", refSchema, refTable)
|
||||
}
|
||||
|
||||
fk := connection.ForeignKeyDefinition{
|
||||
Name: fmt.Sprintf("%v", row["constraint_name"]),
|
||||
ColumnName: fmt.Sprintf("%v", row["column_name"]),
|
||||
RefTableName: refTableName,
|
||||
RefColumnName: fmt.Sprintf("%v", row["foreign_column_name"]),
|
||||
ConstraintName: fmt.Sprintf("%v", row["constraint_name"]),
|
||||
}
|
||||
fks = append(fks, fk)
|
||||
}
|
||||
return fks, nil
|
||||
}
|
||||
|
||||
func (p *PostgresDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||||
return []connection.TriggerDefinition{}, nil
|
||||
schema := strings.TrimSpace(dbName)
|
||||
if schema == "" {
|
||||
schema = "public"
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT trigger_name, action_timing, event_manipulation, action_statement
|
||||
FROM information_schema.triggers
|
||||
WHERE event_object_table = '%s'
|
||||
AND event_object_schema = '%s'
|
||||
ORDER BY trigger_name, event_manipulation`, esc(table), esc(schema))
|
||||
|
||||
data, _, err := p.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var triggers []connection.TriggerDefinition
|
||||
for _, row := range data {
|
||||
trig := connection.TriggerDefinition{
|
||||
Name: fmt.Sprintf("%v", row["trigger_name"]),
|
||||
Timing: fmt.Sprintf("%v", row["action_timing"]),
|
||||
Event: fmt.Sprintf("%v", row["event_manipulation"]),
|
||||
Statement: fmt.Sprintf("%v", row["action_statement"]),
|
||||
}
|
||||
triggers = append(triggers, trig)
|
||||
}
|
||||
return triggers, nil
|
||||
}
|
||||
|
||||
func (p *PostgresDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||||
return []connection.ColumnDefinitionWithTable{}, nil
|
||||
query := `
|
||||
SELECT table_schema, table_name, column_name, data_type
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
|
||||
AND table_schema NOT LIKE 'pg_%'
|
||||
ORDER BY table_schema, table_name, ordinal_position`
|
||||
|
||||
data, _, err := p.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cols []connection.ColumnDefinitionWithTable
|
||||
for _, row := range data {
|
||||
schema := fmt.Sprintf("%v", row["table_schema"])
|
||||
table := fmt.Sprintf("%v", row["table_name"])
|
||||
tableName := table
|
||||
if strings.TrimSpace(schema) != "" {
|
||||
tableName = fmt.Sprintf("%s.%s", schema, table)
|
||||
}
|
||||
|
||||
col := connection.ColumnDefinitionWithTable{
|
||||
TableName: tableName,
|
||||
Name: fmt.Sprintf("%v", row["column_name"]),
|
||||
Type: fmt.Sprintf("%v", row["data_type"]),
|
||||
}
|
||||
cols = append(cols, col)
|
||||
}
|
||||
return cols, nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package db
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
@@ -9,13 +11,17 @@ import (
|
||||
// normalizeQueryValue normalizes driver-returned values for UI/JSON transport.
|
||||
// 当前主要处理 []byte:如果是可读文本则转为 string,否则转为十六进制字符串,避免前端出现“空白值”。
|
||||
func normalizeQueryValue(v interface{}) interface{} {
|
||||
return normalizeQueryValueWithDBType(v, "")
|
||||
}
|
||||
|
||||
func normalizeQueryValueWithDBType(v interface{}, databaseTypeName string) interface{} {
|
||||
if b, ok := v.([]byte); ok {
|
||||
return bytesToReadableString(b)
|
||||
return bytesToDisplayValue(b, databaseTypeName)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func bytesToReadableString(b []byte) interface{} {
|
||||
func bytesToDisplayValue(b []byte, databaseTypeName string) interface{} {
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -23,6 +29,18 @@ func bytesToReadableString(b []byte) interface{} {
|
||||
return ""
|
||||
}
|
||||
|
||||
dbType := strings.ToUpper(strings.TrimSpace(databaseTypeName))
|
||||
if isBitLikeDBType(dbType) {
|
||||
if u, ok := bytesToUint64(b); ok {
|
||||
// JS number precision is limited; keep large bitmasks as string.
|
||||
const maxSafeInteger = 9007199254740991 // 2^53 - 1
|
||||
if u <= maxSafeInteger {
|
||||
return int64(u)
|
||||
}
|
||||
return fmt.Sprintf("%d", u)
|
||||
}
|
||||
}
|
||||
|
||||
if utf8.Valid(b) {
|
||||
s := string(b)
|
||||
if isMostlyPrintable(s) {
|
||||
@@ -30,9 +48,47 @@ func bytesToReadableString(b []byte) interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: some drivers return BIT(1) as []byte{0} / []byte{1} without type info.
|
||||
if dbType == "" && len(b) == 1 && (b[0] == 0 || b[0] == 1) {
|
||||
return int64(b[0])
|
||||
}
|
||||
|
||||
return bytesToReadableString(b)
|
||||
}
|
||||
|
||||
func bytesToReadableString(b []byte) interface{} {
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
if len(b) == 0 {
|
||||
return ""
|
||||
}
|
||||
return "0x" + hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func isBitLikeDBType(typeName string) bool {
|
||||
if typeName == "" {
|
||||
return false
|
||||
}
|
||||
switch typeName {
|
||||
case "BIT", "VARBIT":
|
||||
return true
|
||||
default:
|
||||
}
|
||||
return strings.HasPrefix(typeName, "BIT")
|
||||
}
|
||||
|
||||
func bytesToUint64(b []byte) (uint64, bool) {
|
||||
if len(b) == 0 || len(b) > 8 {
|
||||
return 0, false
|
||||
}
|
||||
var u uint64
|
||||
for _, v := range b {
|
||||
u = (u << 8) | uint64(v)
|
||||
}
|
||||
return u, true
|
||||
}
|
||||
|
||||
func isMostlyPrintable(s string) bool {
|
||||
if s == "" {
|
||||
return true
|
||||
|
||||
44
internal/db/query_value_test.go
Normal file
44
internal/db/query_value_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package db
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeQueryValueWithDBType_BitBytes(t *testing.T) {
|
||||
v := normalizeQueryValueWithDBType([]byte{0x00}, "BIT")
|
||||
if v != int64(0) {
|
||||
t.Fatalf("BIT 0x00 期望为 0,实际=%v(%T)", v, v)
|
||||
}
|
||||
|
||||
v = normalizeQueryValueWithDBType([]byte{0x01}, "bit")
|
||||
if v != int64(1) {
|
||||
t.Fatalf("BIT 0x01 期望为 1,实际=%v(%T)", v, v)
|
||||
}
|
||||
|
||||
v = normalizeQueryValueWithDBType([]byte{0x01, 0x02}, "BIT VARYING")
|
||||
if v != int64(258) {
|
||||
t.Fatalf("BIT 0x0102 期望为 258,实际=%v(%T)", v, v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBType_BitLargeAsString(t *testing.T) {
|
||||
v := normalizeQueryValueWithDBType([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, "BIT")
|
||||
if s, ok := v.(string); !ok || s != "18446744073709551615" {
|
||||
t.Fatalf("BIT 0xffffffffffffffff 期望为 string(18446744073709551615),实际=%v(%T)", v, v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBType_ByteFallbacks(t *testing.T) {
|
||||
v := normalizeQueryValueWithDBType([]byte("abc"), "")
|
||||
if v != "abc" {
|
||||
t.Fatalf("文本 []byte 期望返回 string,实际=%v(%T)", v, v)
|
||||
}
|
||||
|
||||
v = normalizeQueryValueWithDBType([]byte{0x00}, "")
|
||||
if v != int64(0) {
|
||||
t.Fatalf("未知类型 0x00 期望返回 0,实际=%v(%T)", v, v)
|
||||
}
|
||||
|
||||
v = normalizeQueryValueWithDBType([]byte{0xff}, "")
|
||||
if v != "0xff" {
|
||||
t.Fatalf("未知类型 0xff 期望返回 0xff,实际=%v(%T)", v, v)
|
||||
}
|
||||
}
|
||||
46
internal/db/scan_rows.go
Normal file
46
internal/db/scan_rows.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) {
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
colTypes, err := rows.ColumnTypes()
|
||||
if err != nil || len(colTypes) != len(columns) {
|
||||
colTypes = nil
|
||||
}
|
||||
|
||||
resultData := make([]map[string]interface{}, 0)
|
||||
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range columns {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
entry := make(map[string]interface{}, len(columns))
|
||||
for i, col := range columns {
|
||||
dbTypeName := ""
|
||||
if colTypes != nil && i < len(colTypes) && colTypes[i] != nil {
|
||||
dbTypeName = colTypes[i].DatabaseTypeName()
|
||||
}
|
||||
entry[col] = normalizeQueryValueWithDBType(values[i], dbTypeName)
|
||||
}
|
||||
resultData = append(resultData, entry)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return resultData, columns, err
|
||||
}
|
||||
return resultData, columns, nil
|
||||
}
|
||||
@@ -1,8 +1,10 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
@@ -52,6 +54,20 @@ func (s *SQLiteDB) Ping() error {
|
||||
return s.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (s *SQLiteDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if s.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := s.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (s *SQLiteDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if s.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
@@ -62,33 +78,18 @@ func (s *SQLiteDB) Query(query string) ([]map[string]interface{}, []string, erro
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
columns, err := rows.Columns()
|
||||
func (s *SQLiteDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if s.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
}
|
||||
res, err := s.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var resultData []map[string]interface{}
|
||||
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range columns {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
entry := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
entry[col] = normalizeQueryValue(values[i])
|
||||
}
|
||||
resultData = append(resultData, entry)
|
||||
}
|
||||
|
||||
return resultData, columns, nil
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (s *SQLiteDB) Exec(query string) (int64, error) {
|
||||
@@ -137,21 +138,336 @@ func (s *SQLiteDB) GetCreateStatement(dbName, tableName string) (string, error)
|
||||
}
|
||||
|
||||
func (s *SQLiteDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
return []connection.ColumnDefinition{}, nil
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") }
|
||||
|
||||
// cid, name, type, notnull, dflt_value, pk
|
||||
data, _, err := s.Query(fmt.Sprintf("PRAGMA table_info('%s')", esc(table)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parseInt := func(v interface{}) int {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val
|
||||
case int64:
|
||||
return int(val)
|
||||
case float64:
|
||||
return int(val)
|
||||
case string:
|
||||
var n int
|
||||
_, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n)
|
||||
return n
|
||||
default:
|
||||
var n int
|
||||
_, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n)
|
||||
return n
|
||||
}
|
||||
}
|
||||
|
||||
getStr := func(row map[string]interface{}, key string) string {
|
||||
if v, ok := row[key]; ok && v != nil {
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
if v, ok := row[strings.ToUpper(key)]; ok && v != nil {
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var columns []connection.ColumnDefinition
|
||||
for _, row := range data {
|
||||
notnull := 0
|
||||
if v, ok := row["notnull"]; ok && v != nil {
|
||||
notnull = parseInt(v)
|
||||
} else if v, ok := row["NOTNULL"]; ok && v != nil {
|
||||
notnull = parseInt(v)
|
||||
}
|
||||
|
||||
pk := 0
|
||||
if v, ok := row["pk"]; ok && v != nil {
|
||||
pk = parseInt(v)
|
||||
} else if v, ok := row["PK"]; ok && v != nil {
|
||||
pk = parseInt(v)
|
||||
}
|
||||
|
||||
nullable := "YES"
|
||||
if notnull == 1 {
|
||||
nullable = "NO"
|
||||
}
|
||||
|
||||
key := ""
|
||||
if pk == 1 {
|
||||
key = "PRI"
|
||||
}
|
||||
|
||||
col := connection.ColumnDefinition{
|
||||
Name: getStr(row, "name"),
|
||||
Type: getStr(row, "type"),
|
||||
Nullable: nullable,
|
||||
Key: key,
|
||||
Extra: "",
|
||||
Comment: "",
|
||||
}
|
||||
|
||||
if v, ok := row["dflt_value"]; ok && v != nil {
|
||||
def := fmt.Sprintf("%v", v)
|
||||
col.Default = &def
|
||||
} else if v, ok := row["DFLT_VALUE"]; ok && v != nil {
|
||||
def := fmt.Sprintf("%v", v)
|
||||
col.Default = &def
|
||||
}
|
||||
|
||||
columns = append(columns, col)
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||||
return []connection.IndexDefinition{}, nil
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") }
|
||||
parseInt := func(v interface{}) int {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val
|
||||
case int64:
|
||||
return int(val)
|
||||
case float64:
|
||||
return int(val)
|
||||
case string:
|
||||
var n int
|
||||
_, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n)
|
||||
return n
|
||||
default:
|
||||
var n int
|
||||
_, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n)
|
||||
return n
|
||||
}
|
||||
}
|
||||
|
||||
data, _, err := s.Query(fmt.Sprintf("PRAGMA index_list('%s')", esc(table)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var indexes []connection.IndexDefinition
|
||||
for _, row := range data {
|
||||
indexName := ""
|
||||
if v, ok := row["name"]; ok && v != nil {
|
||||
indexName = fmt.Sprintf("%v", v)
|
||||
} else if v, ok := row["NAME"]; ok && v != nil {
|
||||
indexName = fmt.Sprintf("%v", v)
|
||||
}
|
||||
if strings.TrimSpace(indexName) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
unique := 0
|
||||
if v, ok := row["unique"]; ok && v != nil {
|
||||
unique = parseInt(v)
|
||||
} else if v, ok := row["UNIQUE"]; ok && v != nil {
|
||||
unique = parseInt(v)
|
||||
}
|
||||
nonUnique := 1
|
||||
if unique == 1 {
|
||||
nonUnique = 0
|
||||
}
|
||||
|
||||
cols, _, err := s.Query(fmt.Sprintf("PRAGMA index_info('%s')", esc(indexName)))
|
||||
if err != nil {
|
||||
// skip broken index
|
||||
continue
|
||||
}
|
||||
|
||||
for _, c := range cols {
|
||||
colName := ""
|
||||
if v, ok := c["name"]; ok && v != nil {
|
||||
colName = fmt.Sprintf("%v", v)
|
||||
} else if v, ok := c["NAME"]; ok && v != nil {
|
||||
colName = fmt.Sprintf("%v", v)
|
||||
}
|
||||
if strings.TrimSpace(colName) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
seq := 0
|
||||
if v, ok := c["seqno"]; ok && v != nil {
|
||||
seq = parseInt(v) + 1
|
||||
} else if v, ok := c["SEQNO"]; ok && v != nil {
|
||||
seq = parseInt(v) + 1
|
||||
}
|
||||
|
||||
indexes = append(indexes, connection.IndexDefinition{
|
||||
Name: indexName,
|
||||
ColumnName: colName,
|
||||
NonUnique: nonUnique,
|
||||
SeqInIndex: seq,
|
||||
IndexType: "BTREE",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return indexes, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
|
||||
return []connection.ForeignKeyDefinition{}, nil
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") }
|
||||
|
||||
data, _, err := s.Query(fmt.Sprintf("PRAGMA foreign_key_list('%s')", esc(table)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parseInt := func(v interface{}) int {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val
|
||||
case int64:
|
||||
return int(val)
|
||||
case float64:
|
||||
return int(val)
|
||||
case string:
|
||||
var n int
|
||||
_, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n)
|
||||
return n
|
||||
default:
|
||||
var n int
|
||||
_, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n)
|
||||
return n
|
||||
}
|
||||
}
|
||||
|
||||
var fks []connection.ForeignKeyDefinition
|
||||
for _, row := range data {
|
||||
id := 0
|
||||
if v, ok := row["id"]; ok && v != nil {
|
||||
id = parseInt(v)
|
||||
} else if v, ok := row["ID"]; ok && v != nil {
|
||||
id = parseInt(v)
|
||||
}
|
||||
|
||||
refTable := ""
|
||||
if v, ok := row["table"]; ok && v != nil {
|
||||
refTable = fmt.Sprintf("%v", v)
|
||||
} else if v, ok := row["TABLE"]; ok && v != nil {
|
||||
refTable = fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
fromCol := ""
|
||||
if v, ok := row["from"]; ok && v != nil {
|
||||
fromCol = fmt.Sprintf("%v", v)
|
||||
} else if v, ok := row["FROM"]; ok && v != nil {
|
||||
fromCol = fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
toCol := ""
|
||||
if v, ok := row["to"]; ok && v != nil {
|
||||
toCol = fmt.Sprintf("%v", v)
|
||||
} else if v, ok := row["TO"]; ok && v != nil {
|
||||
toCol = fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("fk_%s_%d", table, id)
|
||||
fks = append(fks, connection.ForeignKeyDefinition{
|
||||
Name: name,
|
||||
ColumnName: fromCol,
|
||||
RefTableName: refTable,
|
||||
RefColumnName: toCol,
|
||||
ConstraintName: name,
|
||||
})
|
||||
}
|
||||
return fks, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||||
return []connection.TriggerDefinition{}, nil
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") }
|
||||
|
||||
data, _, err := s.Query(fmt.Sprintf("SELECT name AS trigger_name, sql AS statement FROM sqlite_master WHERE type='trigger' AND tbl_name='%s' ORDER BY name", esc(table)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var triggers []connection.TriggerDefinition
|
||||
for _, row := range data {
|
||||
name := fmt.Sprintf("%v", row["trigger_name"])
|
||||
stmt := ""
|
||||
if v, ok := row["statement"]; ok && v != nil {
|
||||
stmt = fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
upper := strings.ToUpper(stmt)
|
||||
timing := ""
|
||||
switch {
|
||||
case strings.Contains(upper, " BEFORE "):
|
||||
timing = "BEFORE"
|
||||
case strings.Contains(upper, " AFTER "):
|
||||
timing = "AFTER"
|
||||
case strings.Contains(upper, " INSTEAD OF "):
|
||||
timing = "INSTEAD OF"
|
||||
}
|
||||
|
||||
event := ""
|
||||
switch {
|
||||
case strings.Contains(upper, " INSERT "):
|
||||
event = "INSERT"
|
||||
case strings.Contains(upper, " UPDATE "):
|
||||
event = "UPDATE"
|
||||
case strings.Contains(upper, " DELETE "):
|
||||
event = "DELETE"
|
||||
}
|
||||
|
||||
triggers = append(triggers, connection.TriggerDefinition{
|
||||
Name: name,
|
||||
Timing: timing,
|
||||
Event: event,
|
||||
Statement: stmt,
|
||||
})
|
||||
}
|
||||
return triggers, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||||
return []connection.ColumnDefinitionWithTable{}, nil
|
||||
tables, err := s.GetTables(dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cols []connection.ColumnDefinitionWithTable
|
||||
for _, table := range tables {
|
||||
// Skip internal tables
|
||||
if strings.HasPrefix(strings.ToLower(table), "sqlite_") {
|
||||
continue
|
||||
}
|
||||
columns, err := s.GetColumns("", table)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, col := range columns {
|
||||
cols = append(cols, connection.ColumnDefinitionWithTable{
|
||||
TableName: table,
|
||||
Name: col.Name,
|
||||
Type: col.Type,
|
||||
})
|
||||
}
|
||||
}
|
||||
return cols, nil
|
||||
}
|
||||
|
||||
90
internal/redis/redis.go
Normal file
90
internal/redis/redis.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package redis
|
||||
|
||||
import "GoNavi-Wails/internal/connection"
|
||||
|
||||
// RedisValue represents a Redis value with its type and metadata
|
||||
type RedisValue struct {
|
||||
Type string `json:"type"` // string, hash, list, set, zset
|
||||
TTL int64 `json:"ttl"` // TTL in seconds, -1 means no expiry, -2 means key doesn't exist
|
||||
Value interface{} `json:"value"` // The actual value
|
||||
Length int64 `json:"length"` // Length/size of the value
|
||||
}
|
||||
|
||||
// RedisDBInfo represents information about a Redis database
|
||||
type RedisDBInfo struct {
|
||||
Index int `json:"index"` // Database index (0-15)
|
||||
Keys int64 `json:"keys"` // Number of keys in this database
|
||||
}
|
||||
|
||||
// RedisKeyInfo represents information about a Redis key
|
||||
type RedisKeyInfo struct {
|
||||
Key string `json:"key"`
|
||||
Type string `json:"type"`
|
||||
TTL int64 `json:"ttl"`
|
||||
}
|
||||
|
||||
// RedisScanResult represents the result of a SCAN operation
|
||||
type RedisScanResult struct {
|
||||
Keys []RedisKeyInfo `json:"keys"`
|
||||
Cursor uint64 `json:"cursor"`
|
||||
}
|
||||
|
||||
// RedisClient defines the interface for Redis operations
|
||||
type RedisClient interface {
|
||||
// Connection management
|
||||
Connect(config connection.ConnectionConfig) error
|
||||
Close() error
|
||||
Ping() error
|
||||
|
||||
// Key operations
|
||||
ScanKeys(pattern string, cursor uint64, count int64) (*RedisScanResult, error)
|
||||
GetKeyType(key string) (string, error)
|
||||
GetTTL(key string) (int64, error)
|
||||
SetTTL(key string, ttl int64) error
|
||||
DeleteKeys(keys []string) (int64, error)
|
||||
RenameKey(oldKey, newKey string) error
|
||||
KeyExists(key string) (bool, error)
|
||||
|
||||
// Value operations
|
||||
GetValue(key string) (*RedisValue, error)
|
||||
|
||||
// String operations
|
||||
GetString(key string) (string, error)
|
||||
SetString(key, value string, ttl int64) error
|
||||
|
||||
// Hash operations
|
||||
GetHash(key string) (map[string]string, error)
|
||||
SetHashField(key, field, value string) error
|
||||
DeleteHashField(key string, fields ...string) error
|
||||
|
||||
// List operations
|
||||
GetList(key string, start, stop int64) ([]string, error)
|
||||
ListPush(key string, values ...string) error
|
||||
ListSet(key string, index int64, value string) error
|
||||
|
||||
// Set operations
|
||||
GetSet(key string) ([]string, error)
|
||||
SetAdd(key string, members ...string) error
|
||||
SetRemove(key string, members ...string) error
|
||||
|
||||
// Sorted Set operations
|
||||
GetZSet(key string, start, stop int64) ([]ZSetMember, error)
|
||||
ZSetAdd(key string, members ...ZSetMember) error
|
||||
ZSetRemove(key string, members ...string) error
|
||||
|
||||
// Command execution
|
||||
ExecuteCommand(args []string) (interface{}, error)
|
||||
|
||||
// Server information
|
||||
GetServerInfo() (map[string]string, error)
|
||||
GetDatabases() ([]RedisDBInfo, error)
|
||||
SelectDB(index int) error
|
||||
GetCurrentDB() int
|
||||
FlushDB() error
|
||||
}
|
||||
|
||||
// ZSetMember represents a member in a sorted set
|
||||
type ZSetMember struct {
|
||||
Member string `json:"member"`
|
||||
Score float64 `json:"score"`
|
||||
}
|
||||
711
internal/redis/redis_impl.go
Normal file
711
internal/redis/redis_impl.go
Normal file
@@ -0,0 +1,711 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RedisClientImpl implements RedisClient using go-redis
|
||||
type RedisClientImpl struct {
|
||||
client *redis.Client
|
||||
config connection.ConnectionConfig
|
||||
currentDB int
|
||||
forwarder *ssh.LocalForwarder
|
||||
}
|
||||
|
||||
// NewRedisClient creates a new Redis client instance
|
||||
func NewRedisClient() RedisClient {
|
||||
return &RedisClientImpl{}
|
||||
}
|
||||
|
||||
// Connect establishes a connection to Redis
|
||||
func (r *RedisClientImpl) Connect(config connection.ConnectionConfig) error {
|
||||
r.config = config
|
||||
r.currentDB = config.RedisDB
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
|
||||
// Handle SSH tunnel if enabled
|
||||
if config.UseSSH {
|
||||
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建 SSH 隧道失败: %w", err)
|
||||
}
|
||||
r.forwarder = forwarder
|
||||
addr = forwarder.LocalAddr
|
||||
logger.Infof("Redis 通过 SSH 隧道连接: %s -> %s:%d", addr, config.Host, config.Port)
|
||||
}
|
||||
|
||||
opts := &redis.Options{
|
||||
Addr: addr,
|
||||
Password: config.Password,
|
||||
DB: config.RedisDB,
|
||||
DialTimeout: time.Duration(config.Timeout) * time.Second,
|
||||
ReadTimeout: time.Duration(config.Timeout) * time.Second,
|
||||
WriteTimeout: time.Duration(config.Timeout) * time.Second,
|
||||
}
|
||||
|
||||
if opts.DialTimeout == 0 {
|
||||
opts.DialTimeout = 30 * time.Second
|
||||
opts.ReadTimeout = 30 * time.Second
|
||||
opts.WriteTimeout = 30 * time.Second
|
||||
}
|
||||
|
||||
r.client = redis.NewClient(opts)
|
||||
|
||||
// Test connection
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opts.DialTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := r.client.Ping(ctx).Err(); err != nil {
|
||||
r.client.Close()
|
||||
r.client = nil
|
||||
return fmt.Errorf("Redis 连接失败: %w", err)
|
||||
}
|
||||
|
||||
logger.Infof("Redis 连接成功: %s DB=%d", addr, config.RedisDB)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the Redis connection
|
||||
func (r *RedisClientImpl) Close() error {
|
||||
if r.client != nil {
|
||||
err := r.client.Close()
|
||||
r.client = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ping tests the connection
|
||||
func (r *RedisClientImpl) Ping() error {
|
||||
if r.client == nil {
|
||||
return fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
return r.client.Ping(ctx).Err()
|
||||
}
|
||||
|
||||
// ScanKeys scans keys matching a pattern
|
||||
func (r *RedisClientImpl) ScanKeys(pattern string, cursor uint64, count int64) (*RedisScanResult, error) {
|
||||
if r.client == nil {
|
||||
return nil, fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if pattern == "" {
|
||||
pattern = "*"
|
||||
}
|
||||
if count <= 0 {
|
||||
count = 100
|
||||
}
|
||||
|
||||
keys, nextCursor, err := r.client.Scan(ctx, cursor, pattern, count).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &RedisScanResult{
|
||||
Keys: make([]RedisKeyInfo, 0, len(keys)),
|
||||
Cursor: nextCursor,
|
||||
}
|
||||
|
||||
// Get type and TTL for each key
|
||||
pipe := r.client.Pipeline()
|
||||
typeResults := make([]*redis.StatusCmd, len(keys))
|
||||
ttlResults := make([]*redis.DurationCmd, len(keys))
|
||||
|
||||
for i, key := range keys {
|
||||
typeResults[i] = pipe.Type(ctx, key)
|
||||
ttlResults[i] = pipe.TTL(ctx, key)
|
||||
}
|
||||
|
||||
_, err = pipe.Exec(ctx)
|
||||
if err != nil && err != redis.Nil {
|
||||
// Fallback: get info one by one
|
||||
for _, key := range keys {
|
||||
keyType, _ := r.GetKeyType(key)
|
||||
ttl, _ := r.GetTTL(key)
|
||||
result.Keys = append(result.Keys, RedisKeyInfo{
|
||||
Key: key,
|
||||
Type: keyType,
|
||||
TTL: ttl,
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
for i, key := range keys {
|
||||
keyType := typeResults[i].Val()
|
||||
ttl := int64(ttlResults[i].Val().Seconds())
|
||||
if ttlResults[i].Val() == -1 {
|
||||
ttl = -1
|
||||
} else if ttlResults[i].Val() == -2 {
|
||||
ttl = -2
|
||||
}
|
||||
result.Keys = append(result.Keys, RedisKeyInfo{
|
||||
Key: key,
|
||||
Type: keyType,
|
||||
TTL: ttl,
|
||||
})
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetKeyType returns the type of a key
|
||||
func (r *RedisClientImpl) GetKeyType(key string) (string, error) {
|
||||
if r.client == nil {
|
||||
return "", fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
return r.client.Type(ctx, key).Result()
|
||||
}
|
||||
|
||||
// GetTTL returns the TTL of a key in seconds
|
||||
func (r *RedisClientImpl) GetTTL(key string) (int64, error) {
|
||||
if r.client == nil {
|
||||
return 0, fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ttl, err := r.client.TTL(ctx, key).Result()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if ttl == -1 {
|
||||
return -1, nil // No expiry
|
||||
} else if ttl == -2 {
|
||||
return -2, nil // Key doesn't exist
|
||||
}
|
||||
return int64(ttl.Seconds()), nil
|
||||
}
|
||||
|
||||
// SetTTL sets the TTL of a key
|
||||
func (r *RedisClientImpl) SetTTL(key string, ttl int64) error {
|
||||
if r.client == nil {
|
||||
return fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if ttl < 0 {
|
||||
// Remove expiry
|
||||
return r.client.Persist(ctx, key).Err()
|
||||
}
|
||||
return r.client.Expire(ctx, key, time.Duration(ttl)*time.Second).Err()
|
||||
}
|
||||
|
||||
// DeleteKeys deletes one or more keys
|
||||
func (r *RedisClientImpl) DeleteKeys(keys []string) (int64, error) {
|
||||
if r.client == nil {
|
||||
return 0, fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
return r.client.Del(ctx, keys...).Result()
|
||||
}
|
||||
|
||||
// RenameKey renames a key
|
||||
func (r *RedisClientImpl) RenameKey(oldKey, newKey string) error {
|
||||
if r.client == nil {
|
||||
return fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
return r.client.Rename(ctx, oldKey, newKey).Err()
|
||||
}
|
||||
|
||||
// KeyExists checks if a key exists
|
||||
func (r *RedisClientImpl) KeyExists(key string) (bool, error) {
|
||||
if r.client == nil {
|
||||
return false, fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
n, err := r.client.Exists(ctx, key).Result()
|
||||
return n > 0, err
|
||||
}
|
||||
|
||||
// GetValue gets the value of a key with automatic type detection
|
||||
func (r *RedisClientImpl) GetValue(key string) (*RedisValue, error) {
|
||||
if r.client == nil {
|
||||
return nil, fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
|
||||
keyType, err := r.GetKeyType(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ttl, _ := r.GetTTL(key)
|
||||
|
||||
result := &RedisValue{
|
||||
Type: keyType,
|
||||
TTL: ttl,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
switch keyType {
|
||||
case "string":
|
||||
val, err := r.client.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.Value = val
|
||||
result.Length = int64(len(val))
|
||||
|
||||
case "hash":
|
||||
val, err := r.client.HGetAll(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.Value = val
|
||||
result.Length = int64(len(val))
|
||||
|
||||
case "list":
|
||||
length, err := r.client.LLen(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Get first 1000 items
|
||||
limit := int64(1000)
|
||||
if length < limit {
|
||||
limit = length
|
||||
}
|
||||
val, err := r.client.LRange(ctx, key, 0, limit-1).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.Value = val
|
||||
result.Length = length
|
||||
|
||||
case "set":
|
||||
length, err := r.client.SCard(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Get members using SMembers (limited by Redis server)
|
||||
members, err := r.client.SMembers(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.Value = members
|
||||
result.Length = length
|
||||
|
||||
case "zset":
|
||||
length, err := r.client.ZCard(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Get first 1000 members with scores
|
||||
limit := int64(1000)
|
||||
if length < limit {
|
||||
limit = length
|
||||
}
|
||||
val, err := r.client.ZRangeWithScores(ctx, key, 0, limit-1).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
members := make([]ZSetMember, len(val))
|
||||
for i, z := range val {
|
||||
members[i] = ZSetMember{
|
||||
Member: z.Member.(string),
|
||||
Score: z.Score,
|
||||
}
|
||||
}
|
||||
result.Value = members
|
||||
result.Length = length
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的 Redis 数据类型: %s", keyType)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetString gets a string value
|
||||
func (r *RedisClientImpl) GetString(key string) (string, error) {
|
||||
if r.client == nil {
|
||||
return "", fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
return r.client.Get(ctx, key).Result()
|
||||
}
|
||||
|
||||
// SetString sets a string value with optional TTL
|
||||
func (r *RedisClientImpl) SetString(key, value string, ttl int64) error {
|
||||
if r.client == nil {
|
||||
return fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var expiration time.Duration
|
||||
if ttl > 0 {
|
||||
expiration = time.Duration(ttl) * time.Second
|
||||
}
|
||||
return r.client.Set(ctx, key, value, expiration).Err()
|
||||
}
|
||||
|
||||
// GetHash gets all fields of a hash
|
||||
func (r *RedisClientImpl) GetHash(key string) (map[string]string, error) {
|
||||
if r.client == nil {
|
||||
return nil, fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
return r.client.HGetAll(ctx, key).Result()
|
||||
}
|
||||
|
||||
// SetHashField sets a field in a hash
|
||||
func (r *RedisClientImpl) SetHashField(key, field, value string) error {
|
||||
if r.client == nil {
|
||||
return fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
return r.client.HSet(ctx, key, field, value).Err()
|
||||
}
|
||||
|
||||
// DeleteHashField deletes fields from a hash
|
||||
func (r *RedisClientImpl) DeleteHashField(key string, fields ...string) error {
|
||||
if r.client == nil {
|
||||
return fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
return r.client.HDel(ctx, key, fields...).Err()
|
||||
}
|
||||
|
||||
// GetList gets a range of elements from a list
|
||||
func (r *RedisClientImpl) GetList(key string, start, stop int64) ([]string, error) {
|
||||
if r.client == nil {
|
||||
return nil, fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
return r.client.LRange(ctx, key, start, stop).Result()
|
||||
}
|
||||
|
||||
// ListPush pushes values to the end of a list
|
||||
func (r *RedisClientImpl) ListPush(key string, values ...string) error {
|
||||
if r.client == nil {
|
||||
return fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
args := make([]interface{}, len(values))
|
||||
for i, v := range values {
|
||||
args[i] = v
|
||||
}
|
||||
return r.client.RPush(ctx, key, args...).Err()
|
||||
}
|
||||
|
||||
// ListSet sets the value at an index in a list
|
||||
func (r *RedisClientImpl) ListSet(key string, index int64, value string) error {
|
||||
if r.client == nil {
|
||||
return fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
return r.client.LSet(ctx, key, index, value).Err()
|
||||
}
|
||||
|
||||
// GetSet gets all members of a set
|
||||
func (r *RedisClientImpl) GetSet(key string) ([]string, error) {
|
||||
if r.client == nil {
|
||||
return nil, fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
return r.client.SMembers(ctx, key).Result()
|
||||
}
|
||||
|
||||
// SetAdd adds members to a set
|
||||
func (r *RedisClientImpl) SetAdd(key string, members ...string) error {
|
||||
if r.client == nil {
|
||||
return fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
args := make([]interface{}, len(members))
|
||||
for i, m := range members {
|
||||
args[i] = m
|
||||
}
|
||||
return r.client.SAdd(ctx, key, args...).Err()
|
||||
}
|
||||
|
||||
// SetRemove removes members from a set
|
||||
func (r *RedisClientImpl) SetRemove(key string, members ...string) error {
|
||||
if r.client == nil {
|
||||
return fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
args := make([]interface{}, len(members))
|
||||
for i, m := range members {
|
||||
args[i] = m
|
||||
}
|
||||
return r.client.SRem(ctx, key, args...).Err()
|
||||
}
|
||||
|
||||
// GetZSet gets members with scores from a sorted set
|
||||
func (r *RedisClientImpl) GetZSet(key string, start, stop int64) ([]ZSetMember, error) {
|
||||
if r.client == nil {
|
||||
return nil, fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
val, err := r.client.ZRangeWithScores(ctx, key, start, stop).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
members := make([]ZSetMember, len(val))
|
||||
for i, z := range val {
|
||||
members[i] = ZSetMember{
|
||||
Member: z.Member.(string),
|
||||
Score: z.Score,
|
||||
}
|
||||
}
|
||||
return members, nil
|
||||
}
|
||||
|
||||
// ZSetAdd adds members to a sorted set
|
||||
func (r *RedisClientImpl) ZSetAdd(key string, members ...ZSetMember) error {
|
||||
if r.client == nil {
|
||||
return fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
zMembers := make([]redis.Z, len(members))
|
||||
for i, m := range members {
|
||||
zMembers[i] = redis.Z{
|
||||
Score: m.Score,
|
||||
Member: m.Member,
|
||||
}
|
||||
}
|
||||
return r.client.ZAdd(ctx, key, zMembers...).Err()
|
||||
}
|
||||
|
||||
// ZSetRemove removes members from a sorted set
|
||||
func (r *RedisClientImpl) ZSetRemove(key string, members ...string) error {
|
||||
if r.client == nil {
|
||||
return fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
args := make([]interface{}, len(members))
|
||||
for i, m := range members {
|
||||
args[i] = m
|
||||
}
|
||||
return r.client.ZRem(ctx, key, args...).Err()
|
||||
}
|
||||
|
||||
// ExecuteCommand executes a raw Redis command
|
||||
func (r *RedisClientImpl) ExecuteCommand(args []string) (interface{}, error) {
|
||||
if r.client == nil {
|
||||
return nil, fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
if len(args) == 0 {
|
||||
return nil, fmt.Errorf("命令不能为空")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Convert to []interface{}
|
||||
cmdArgs := make([]interface{}, len(args))
|
||||
for i, arg := range args {
|
||||
cmdArgs[i] = arg
|
||||
}
|
||||
|
||||
result, err := r.client.Do(ctx, cmdArgs...).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return formatCommandResult(result), nil
|
||||
}
|
||||
|
||||
// formatCommandResult formats the command result for display
|
||||
func formatCommandResult(result interface{}) interface{} {
|
||||
switch v := result.(type) {
|
||||
case []interface{}:
|
||||
formatted := make([]interface{}, len(v))
|
||||
for i, item := range v {
|
||||
formatted[i] = formatCommandResult(item)
|
||||
}
|
||||
return formatted
|
||||
case []byte:
|
||||
return string(v)
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// GetServerInfo returns server information
|
||||
func (r *RedisClientImpl) GetServerInfo() (map[string]string, error) {
|
||||
if r.client == nil {
|
||||
return nil, fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
info, err := r.client.Info(ctx).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make(map[string]string)
|
||||
lines := strings.Split(info, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
result[parts[0]] = parts[1]
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetDatabases returns information about all databases
|
||||
func (r *RedisClientImpl) GetDatabases() ([]RedisDBInfo, error) {
|
||||
if r.client == nil {
|
||||
return nil, fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Get keyspace info
|
||||
info, err := r.client.Info(ctx, "keyspace").Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse keyspace info
|
||||
dbMap := make(map[int]int64)
|
||||
lines := strings.Split(info, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, "db") {
|
||||
// Format: db0:keys=123,expires=0,avg_ttl=0
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
dbIndex, err := strconv.Atoi(strings.TrimPrefix(parts[0], "db"))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
// Parse keys count
|
||||
kvPairs := strings.Split(parts[1], ",")
|
||||
for _, kv := range kvPairs {
|
||||
if strings.HasPrefix(kv, "keys=") {
|
||||
keys, _ := strconv.ParseInt(strings.TrimPrefix(kv, "keys="), 10, 64)
|
||||
dbMap[dbIndex] = keys
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return all 16 databases (0-15)
|
||||
result := make([]RedisDBInfo, 16)
|
||||
for i := 0; i < 16; i++ {
|
||||
result[i] = RedisDBInfo{
|
||||
Index: i,
|
||||
Keys: dbMap[i], // Will be 0 if not in map
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SelectDB selects a database
|
||||
func (r *RedisClientImpl) SelectDB(index int) error {
|
||||
if r.client == nil {
|
||||
return fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
if index < 0 || index > 15 {
|
||||
return fmt.Errorf("数据库索引必须在 0-15 之间")
|
||||
}
|
||||
|
||||
// Create new client with different DB
|
||||
addr := fmt.Sprintf("%s:%d", r.config.Host, r.config.Port)
|
||||
if r.forwarder != nil {
|
||||
addr = r.forwarder.LocalAddr
|
||||
}
|
||||
|
||||
opts := &redis.Options{
|
||||
Addr: addr,
|
||||
Password: r.config.Password,
|
||||
DB: index,
|
||||
DialTimeout: time.Duration(r.config.Timeout) * time.Second,
|
||||
ReadTimeout: time.Duration(r.config.Timeout) * time.Second,
|
||||
WriteTimeout: time.Duration(r.config.Timeout) * time.Second,
|
||||
}
|
||||
|
||||
if opts.DialTimeout == 0 {
|
||||
opts.DialTimeout = 30 * time.Second
|
||||
opts.ReadTimeout = 30 * time.Second
|
||||
opts.WriteTimeout = 30 * time.Second
|
||||
}
|
||||
|
||||
newClient := redis.NewClient(opts)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opts.DialTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := newClient.Ping(ctx).Err(); err != nil {
|
||||
newClient.Close()
|
||||
return fmt.Errorf("切换数据库失败: %w", err)
|
||||
}
|
||||
|
||||
// Close old client and replace
|
||||
r.client.Close()
|
||||
r.client = newClient
|
||||
r.currentDB = index
|
||||
|
||||
logger.Infof("Redis 切换到数据库: db%d", index)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentDB returns the current database index
|
||||
func (r *RedisClientImpl) GetCurrentDB() int {
|
||||
return r.currentDB
|
||||
}
|
||||
|
||||
// FlushDB flushes the current database
|
||||
func (r *RedisClientImpl) FlushDB() error {
|
||||
if r.client == nil {
|
||||
return fmt.Errorf("Redis 客户端未连接")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
return r.client.FlushDB(ctx).Err()
|
||||
}
|
||||
@@ -3,8 +3,10 @@ package ssh
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
@@ -110,3 +112,264 @@ func RegisterSSHNetwork(sshConfig connection.SSHConfig) (string, error) {
|
||||
|
||||
return netName, nil
|
||||
}
|
||||
|
||||
// sshClientCache stores SSH clients to avoid creating multiple connections
|
||||
var (
|
||||
sshClientCache = make(map[string]*ssh.Client)
|
||||
sshClientCacheMu sync.RWMutex
|
||||
localForwarders = make(map[string]*LocalForwarder)
|
||||
forwarderMu sync.RWMutex
|
||||
)
|
||||
|
||||
// LocalForwarder represents a local port forwarder through SSH
|
||||
type LocalForwarder struct {
|
||||
LocalAddr string
|
||||
RemoteAddr string
|
||||
SSHClient *ssh.Client
|
||||
listener net.Listener
|
||||
closeChan chan struct{}
|
||||
closeOnce sync.Once // 防止重复关闭
|
||||
closed bool // 关闭状态标记
|
||||
closedMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewLocalForwarder creates a new local port forwarder
|
||||
// It listens on a random local port and forwards all connections through SSH tunnel
|
||||
func NewLocalForwarder(sshConfig connection.SSHConfig, remoteHost string, remotePort int) (*LocalForwarder, error) {
|
||||
client, err := GetOrCreateSSHClient(sshConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("建立 SSH 连接失败:%w", err)
|
||||
}
|
||||
|
||||
// Listen on localhost with a random port
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建本地监听器失败:%w", err)
|
||||
}
|
||||
|
||||
localAddr := listener.Addr().String()
|
||||
remoteAddr := fmt.Sprintf("%s:%d", remoteHost, remotePort)
|
||||
|
||||
forwarder := &LocalForwarder{
|
||||
LocalAddr: localAddr,
|
||||
RemoteAddr: remoteAddr,
|
||||
SSHClient: client,
|
||||
listener: listener,
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start forwarding in background
|
||||
go forwarder.forward()
|
||||
|
||||
logger.Infof("已创建 SSH 端口转发:本地 %s -> 远程 %s", localAddr, remoteAddr)
|
||||
return forwarder, nil
|
||||
}
|
||||
|
||||
// forward handles the port forwarding
|
||||
func (f *LocalForwarder) forward() {
|
||||
for {
|
||||
localConn, err := f.listener.Accept()
|
||||
if err != nil {
|
||||
// Check if we're shutting down
|
||||
select {
|
||||
case <-f.closeChan:
|
||||
return
|
||||
default:
|
||||
logger.Warnf("接受本地连接失败:%v", err)
|
||||
// listener可能已关闭,退出循环
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
go f.handleConnection(localConn)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnection handles a single connection
|
||||
func (f *LocalForwarder) handleConnection(localConn net.Conn) {
|
||||
defer localConn.Close()
|
||||
|
||||
// Connect to remote through SSH with timeout
|
||||
remoteConn, err := f.SSHClient.Dial("tcp", f.RemoteAddr)
|
||||
if err != nil {
|
||||
logger.Warnf("通过 SSH 连接到远程 %s 失败:%v", f.RemoteAddr, err)
|
||||
return
|
||||
}
|
||||
defer remoteConn.Close()
|
||||
|
||||
// Bidirectional copy with error channel
|
||||
errc := make(chan error, 2)
|
||||
|
||||
// Copy from local to remote
|
||||
go func() {
|
||||
_, err := io.Copy(remoteConn, localConn)
|
||||
if err != nil {
|
||||
logger.Warnf("本地->远程数据复制错误:%v", err)
|
||||
}
|
||||
errc <- err
|
||||
}()
|
||||
|
||||
// Copy from remote to local
|
||||
go func() {
|
||||
_, err := io.Copy(localConn, remoteConn)
|
||||
if err != nil {
|
||||
logger.Warnf("远程->本地数据复制错误:%v", err)
|
||||
}
|
||||
errc <- err
|
||||
}()
|
||||
|
||||
// Wait for BOTH goroutines to complete
|
||||
<-errc
|
||||
<-errc
|
||||
}
|
||||
|
||||
// Close closes the forwarder (thread-safe, can be called multiple times)
|
||||
func (f *LocalForwarder) Close() error {
|
||||
var err error
|
||||
f.closeOnce.Do(func() {
|
||||
f.closedMu.Lock()
|
||||
f.closed = true
|
||||
f.closedMu.Unlock()
|
||||
|
||||
close(f.closeChan)
|
||||
err = f.listener.Close()
|
||||
if err != nil {
|
||||
logger.Warnf("关闭端口转发监听器失败:%v", err)
|
||||
}
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// IsClosed returns whether the forwarder is closed
|
||||
func (f *LocalForwarder) IsClosed() bool {
|
||||
f.closedMu.RLock()
|
||||
defer f.closedMu.RUnlock()
|
||||
return f.closed
|
||||
}
|
||||
|
||||
// GetOrCreateLocalForwarder returns a cached forwarder or creates a new one
|
||||
func GetOrCreateLocalForwarder(sshConfig connection.SSHConfig, remoteHost string, remotePort int) (*LocalForwarder, error) {
|
||||
key := fmt.Sprintf("%s:%d:%s->%s:%d",
|
||||
sshConfig.Host, sshConfig.Port, sshConfig.User,
|
||||
remoteHost, remotePort)
|
||||
|
||||
forwarderMu.RLock()
|
||||
forwarder, exists := localForwarders[key]
|
||||
forwarderMu.RUnlock()
|
||||
|
||||
// Check if exists and is still valid
|
||||
if exists && forwarder != nil && !forwarder.IsClosed() {
|
||||
logger.Infof("复用已有端口转发:%s", key)
|
||||
return forwarder, nil
|
||||
}
|
||||
|
||||
// Remove stale forwarder from cache
|
||||
if exists {
|
||||
forwarderMu.Lock()
|
||||
delete(localForwarders, key)
|
||||
forwarderMu.Unlock()
|
||||
}
|
||||
|
||||
forwarder, err := NewLocalForwarder(sshConfig, remoteHost, remotePort)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
forwarderMu.Lock()
|
||||
localForwarders[key] = forwarder
|
||||
forwarderMu.Unlock()
|
||||
|
||||
return forwarder, nil
|
||||
}
|
||||
|
||||
// CloseAllForwarders closes all local forwarders
|
||||
func CloseAllForwarders() {
|
||||
forwarderMu.Lock()
|
||||
defer forwarderMu.Unlock()
|
||||
|
||||
for key, forwarder := range localForwarders {
|
||||
if forwarder != nil {
|
||||
_ = forwarder.Close()
|
||||
logger.Infof("已关闭端口转发:%s", key)
|
||||
}
|
||||
}
|
||||
localForwarders = make(map[string]*LocalForwarder)
|
||||
}
|
||||
|
||||
|
||||
// getSSHClientCacheKey generates a unique cache key for SSH config
|
||||
func getSSHClientCacheKey(config connection.SSHConfig) string {
|
||||
return fmt.Sprintf("%s:%d:%s", config.Host, config.Port, config.User)
|
||||
}
|
||||
|
||||
// GetOrCreateSSHClient returns a cached SSH client or creates a new one
|
||||
func GetOrCreateSSHClient(config connection.SSHConfig) (*ssh.Client, error) {
|
||||
key := getSSHClientCacheKey(config)
|
||||
|
||||
sshClientCacheMu.RLock()
|
||||
client, exists := sshClientCache[key]
|
||||
sshClientCacheMu.RUnlock()
|
||||
|
||||
if exists && client != nil {
|
||||
// Test if connection is still alive by creating a test session
|
||||
session, err := client.NewSession()
|
||||
if err == nil {
|
||||
session.Close()
|
||||
logger.Infof("复用已有 SSH 连接:%s", key)
|
||||
return client, nil
|
||||
}
|
||||
// Connection is dead, remove from cache
|
||||
logger.Warnf("SSH 连接已断开,重新建立:%s (错误: %v)", key, err)
|
||||
sshClientCacheMu.Lock()
|
||||
delete(sshClientCache, key)
|
||||
sshClientCacheMu.Unlock()
|
||||
// Try to close the dead client
|
||||
_ = client.Close()
|
||||
}
|
||||
|
||||
// Create new SSH client
|
||||
client, err := connectSSH(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Cache the client
|
||||
sshClientCacheMu.Lock()
|
||||
sshClientCache[key] = client
|
||||
sshClientCacheMu.Unlock()
|
||||
|
||||
logger.Infof("已缓存 SSH 连接:%s", key)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// DialThroughSSH creates a connection through SSH tunnel
|
||||
// This is a generic dialer that can be used by any database driver
|
||||
func DialThroughSSH(config connection.SSHConfig, network, address string) (net.Conn, error) {
|
||||
client, err := GetOrCreateSSHClient(config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("建立 SSH 连接失败:%w", err)
|
||||
}
|
||||
|
||||
conn, err := client.Dial(network, address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("通过 SSH 隧道连接到 %s 失败:%w", address, err)
|
||||
}
|
||||
|
||||
logger.Infof("已通过 SSH 隧道连接到:%s", address)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// CloseAllSSHClients closes all cached SSH clients
|
||||
func CloseAllSSHClients() {
|
||||
sshClientCacheMu.Lock()
|
||||
defer sshClientCacheMu.Unlock()
|
||||
|
||||
for key, client := range sshClientCache {
|
||||
if client != nil {
|
||||
_ = client.Close()
|
||||
logger.Infof("已关闭 SSH 连接:%s", key)
|
||||
}
|
||||
}
|
||||
sshClientCache = make(map[string]*ssh.Client)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user