Compare commits

...

45 Commits

Author SHA1 Message Date
Syngnat
13ba78103c feat(frontend/backend): 批量操作与表格编辑增强并完善事务支持
- 批量导出/备份:表与数据库支持全选/反选/智能上下文
  - 右键菜单:单元格菜单支持设置 NULL
  - 编辑优化:大字段弹窗、仅值变化标记、提交只发送差异字段
  - 事务支持:PostgreSQL/SQLite/Oracle/DaMeng/KingBase ApplyChanges
  - MySQL 修复:提交前归一化 datetime,避免写入失败
  - 性能优化:移除 activeCell 重渲染、useRef 存储选中节点、防重加载
  - Redis 优化:二进制智能解码与视图模式切换
  - 资源更新:替换前端 favicon/logo
2026-02-05 14:30:05 +08:00
Syngnat
538e4a1506 Merge pull request #70 from bengbengbalabalabeng/feat-issues-55
ci: add publish-to-winget action
2026-02-05 08:41:48 +08:00
Syngnat
934581c796 chore(ci): 调整 WinGet 发布配置
## 修改内容
- 修正 WinGet workflow 中 installers-regex,使其匹配实际 Release 产物名称

## 修改原因
- 原匹配规则无法匹配 GoNavi-windows-amd64.exe / GoNavi-windows-arm64.exe
- 避免 WinGet 发布流程找不到安装包导致失败

## 影响范围
- CI / WinGet 发布流程
2026-02-05 08:41:18 +08:00
baicaixiaozhan
1486b98d27 ci: add publish-to-winget action 2026-02-04 20:02:43 +08:00
Syngnat
6cda430f03 🔧 chore(ci/build): 移除Linux ARM64构建支持以简化发布流程
- 从构建矩阵中移除linux/arm64平台
  - 移除ARM64交叉编译工具链安装逻辑
  - 简化Linux依赖安装流程,移除条件判断
  - 保留macOS和Windows的ARM64支持(原生构建)
  - 当前支持平台:macOS(AMD64/ARM64)、Windows(AMD64/ARM64)、Linux(AMD64)
  - 技术原因:Wails CGO交叉编译在x86_64 runner上存在头文件冲突问题
2026-02-04 17:50:13 +08:00
Syngnat
f56c3d5f6e 🐛 fix(workflows): 移除了 dpkg --add-architecture arm64,这会导致 apt 尝试从不存在的 ARM64 仓库获取包 2026-02-04 17:43:31 +08:00
Syngnat
74c9143c95 🐛 fix(workflows): 添加 wget 重试机制(3次重试,超时控制) 2026-02-04 17:36:59 +08:00
Syngnat
0e4a833ffa 🐛 fix(workflows): 修复artifact_name 冲突 2026-02-04 17:30:26 +08:00
Syngnat
37ad9885b7 Merge pull request #69 from Syngnat/release/0.3.0
🐛 fix(workflows): 修复actions语法错误
2026-02-04 17:19:46 +08:00
Syngnat
5cef9a4032 Merge pull request #68 from Syngnat/dev
🐛 fix(workflows): 修复actions语法错误
2026-02-04 17:18:54 +08:00
Syngnat
f49767c38b 🐛 fix(workflows): 修复actions语法错误 2026-02-04 17:17:02 +08:00
Syngnat
7e8699ba02 Merge pull request #67 from Syngnat/release/0.3.0
 feat(redis): 新增Redis数据源完整支持
2026-02-04 17:05:11 +08:00
Syngnat
5f0ce5ed7a Merge pull request #66 from Syngnat/feature/support-redis-20260204-ygf
 feat(redis): 新增Redis数据源完整支持
2026-02-04 17:03:40 +08:00
Syngnat
49c7620bdd 🐛 fix(redis/kingbase): Redis数据库选择优化与金仓标识符引号修复
- Redis配置优化:移除固定数据库输入框,改为测试连接后多选数据库
  - 数据库筛选:支持选择显示的Redis数据库(0-15),留空显示全部
  - 类型扩展:SavedConnection新增includeRedisDatabases字段存储用户选择
  - 侧边栏过滤:根据配置过滤显示的Redis数据库列表
  - 金仓修复:KingBase/PostgreSQL标识符仅在必要时加双引号
  - 保留字检测:新增needsQuote函数识别特殊字符和SQL保留字
2026-02-04 17:00:51 +08:00
Syngnat
80fa7a1acd feat(redis): 新增Redis数据源完整支持
- 后端实现:新增Redis客户端接口与go-redis实现,支持SSH隧道连接
  - API方法:新增21个Redis操作API(连接/Key/Value/命令执行等)
  - 连接配置:ConnectionModal支持Redis类型,自动识别端口与认证方式
  - 数据浏览:RedisViewer组件支持Key列表展示、类型识别与分页加载
  - 值编辑器:支持String/Hash/List/Set/ZSet五种数据类型的查看与编辑
  - 二进制处理:自动检测二进制数据并以十六进制格式展示
  - 命令终端:RedisCommandEditor支持多行命令执行与结果展示
  - 交互优化:JSON语法高亮编辑、一键复制值、面板宽度可调整
2026-02-04 16:45:51 +08:00
Syngnat
68770a42e2 Merge pull request #65 from Syngnat/feature/support-linux-windosw-arm-amd-20260204-ygf
 feat(ci/build): 新增Linux和Windows ARM64多平台构建支持
2026-02-04 15:15:18 +08:00
Syngnat
06aebf716e feat(ci/build): 新增Linux和Windows ARM64多平台构建支持
- CI矩阵扩展:新增Linux amd64/arm64和Windows arm64构建任务
  - AppImage支持:Linux平台生成通用AppImage包,兼容所有主流发行版
  - 依赖安装:自动安装GTK3/WebKit2GTK及ARM64交叉编译工具链
  - 本地构建:build-release.sh支持Linux/Windows多架构本地构建
  - 交叉编译:macOS/Linux可交叉编译其他平台,自动检测工具链
  - 打包优化:Linux输出tar.gz和AppImage两种格式
2026-02-04 15:02:42 +08:00
Syngnat
f551b19f40 Merge pull request #64 from Syngnat/release/0.2.6
♻️ refactor(database/ssh): SSH隧道架构重构与多数据源适配
2026-02-04 14:41:43 +08:00
Syngnat
6674ad69e1 Merge pull request #63 from Syngnat/dev
♻️ refactor(database/ssh): SSH隧道架构重构与多数据源适配
2026-02-04 14:40:34 +08:00
Syngnat
37d35684f1 Merge pull request #62 from Syngnat/feature/table-and-database-export-20260203-ygf
♻️ refactor(database/ssh): SSH隧道架构重构与多数据源适配
2026-02-04 14:37:11 +08:00
Syngnat
71e5de0cdc ♻️ refactor(database/ssh): SSH隧道架构重构与多数据源适配
- 架构升级:从driver专属拨号器改为通用本地端口转发模式
  - 并发安全:sync.Once保护Close操作,RWMutex保护状态访问,双向errc等待
  - 连接池化:GetOrCreateLocalForwarder/GetOrCreateSSHClient实现缓存复用
  - SQL安全:kingbase_impl.go引入esc函数,防止双引号注入(""ldf_server""问题)
  - Schema动态化:三级fallback(schema.table解析→dbName参数→current_schema())
  - 代码复用:scanRows统一行扫描逻辑,normalizeQueryValueWithDBType增强类型处理
  Close #40
2026-02-04 14:35:31 +08:00
Syngnat
d8656c6c9c 🐛 fix(query-editor): 修复别名字段不联想与启动编译报错
- a.<field> 场景根据 alias->table 提供字段补全
  - 修复 currentDbRef 重复声明(TS2451)
  - 保持原关键字/表名/字段补全行为不变
2026-02-04 12:37:30 +08:00
Syngnat
443b487a02 Merge pull request #60 from Syngnat/feature/0.2.5
Feature/0.2.5
2026-02-04 12:31:50 +08:00
Syngnat
bac57ebdf0 Merge pull request #59 from Syngnat/dev
🐛 fix(table): 修复虚拟表全选丢失并完善导出/筛选能力

- 表头自定义组件保留 width,virtual 模式下选择列正常显示
- 新增后端 ExportQuery,导出当前页/选中行避免长字段 IPC 截断
- 筛选支持更多操作符并统一 WHERE 生成逻辑
Close #57
Close #56

 feat(table-edit): 增加整行编辑面板,提升多字段/长文本编辑效率

- 支持选中行后一键打开编辑面板
- 全字段可编辑,长文本/JSON 友好输入与弹窗编辑
- 应用后写入本地变更,提交事务后落库

️ perf(table): 表数据打开加速,主键/统计等耗时操作异步化

- DataViewer 主键列元数据异步拉取,首屏数据优先渲染
- 查询页增加结果集最大行数限制,减少大表全量返回
- DBQuery 引入 Context 超时,降低长查询对 UI 的阻塞风险
- 查询行数设置持久化保存
Closes #48 

 feat(db-ui): 修复金仓打开表报错并增强结果页编辑体验

- postgres/kingbase 查询前自动清洗 ""ident"" 形式的非法标识符
- 结果表支持单元格弹窗编辑,提升 JSON/长文本可编辑性
- 修复查询结果表头与数据列宽度不对齐问题
Closes #49
2026-02-04 12:30:42 +08:00
Syngnat
213a33e4f3 Merge pull request #58 from Syngnat/feature/table-and-database-export-20260203-ygf
Feature/table and database export 20260203 ygf
2026-02-04 12:29:33 +08:00
Syngnat
a00f87582d 🐛 fix(table): 修复虚拟表全选丢失并完善导出/筛选能力
- 表头自定义组件保留 width,virtual 模式下选择列正常显示
  - 新增后端 ExportQuery,导出当前页/选中行避免长字段 IPC 截断
  - 筛选支持更多操作符并统一 WHERE 生成逻辑
  Close #57
  Close #56
2026-02-04 12:23:41 +08:00
Syngnat
f129623000 feat(table-edit): 增加整行编辑面板,提升多字段/长文本编辑效率
- 支持选中行后一键打开编辑面板
  - 全字段可编辑,长文本/JSON 友好输入与弹窗编辑
  - 应用后写入本地变更,提交事务后落库
2026-02-04 11:43:47 +08:00
Syngnat
8dbc97e466 ️ perf(table): 表数据打开加速,主键/统计等耗时操作异步化
- DataViewer 主键列元数据异步拉取,首屏数据优先渲染
  - 查询页增加结果集最大行数限制,减少大表全量返回
  - DBQuery 引入 Context 超时,降低长查询对 UI 的阻塞风险
  - 查询行数设置持久化保存
  Closes #48
  Closes #49
2026-02-04 11:01:28 +08:00
Syngnat
4a0db185c0 feat(db-ui): 修复金仓打开表报错并增强结果页编辑体验
- postgres/kingbase 查询前自动清洗 ""ident"" 形式的非法标识符
  - 结果表支持单元格弹窗编辑,提升 JSON/长文本可编辑性
  - 修复查询结果表头与数据列宽度不对齐问题
2026-02-04 10:13:02 +08:00
Syngnat
5793f63ac8 ️ optimize(core): 查询多语句多结果与大表交互/元数据体验优化
- 支持分号多语句拆分(含引号/注释/PG dollar-quote),多结果集 Tab 展示;
- 支持选中运行;结果 Tab 支持关闭
- 修复结果区高度自动收缩/最后一行裁剪;切换结果更顺滑(关闭 ink-bar 动画、修复隐藏面板叠加显示)
- 补齐 PostgreSQL/SQLite 设计表元数据接口;
- 修复 Kingbase schema/标识符引用导致打开表失败
- 标签页右键支持关闭其他/关闭左侧/关闭右侧/关闭所有
2026-02-03 22:48:24 +08:00
Syngnat
8aabc67634 Merge pull request #46 from Syngnat/feature/table-and-database-export-20260203-ygf
- 支持分号多语句拆分(含引号/注释/PG dollar-quote),多结果集 Tab 展示;
- 支持选中运行;结果 Tab 支持关闭
- 修复结果区高度自动收缩/最后一行裁剪;切换结果更顺滑(关闭 ink-bar 动画、修复隐藏面板叠加显示)
- 补齐 PostgreSQL/SQLite 设计表元数据接口;
- 修复 Kingbase schema/标识符引用导致打开表失败
- 标签页右键支持关闭其他/关闭左侧/关闭右侧/关闭所有
2026-02-03 22:46:25 +08:00
杨国锋
34c494ce51 ️ optimize(core): 查询多语句多结果与大表交互/元数据体验优化
- 支持分号多语句拆分(含引号/注释/PG dollar-quote),多结果集 Tab 展示;
  - 支持选中运行;结果 Tab 支持关闭
  - 修复结果区高度自动收缩/最后一行裁剪;切换结果更顺滑(关闭 ink-bar 动画、修复隐藏面板叠加显示)
  - 补齐 PostgreSQL/SQLite 设计表元数据接口;
  - 修复 Kingbase schema/标识符引用导致打开表失败
  - 标签页右键支持关闭其他/关闭左侧/关闭右侧/关闭所有
2026-02-03 22:44:48 +08:00
Syngnat
178de02783 Merge pull request #45 from bengbengbalabalabeng/chore-add-issues-templates
- 新增 issues template 以统一 issue 类型
2026-02-03 22:39:39 +08:00
baicaixiaozhan
94e5b8d2c6 chore: add Github issues templates 2026-02-03 21:49:43 +08:00
杨国锋
89e2247c05 feat(database): 增强库/表级导出与备份能力,优化侧边栏交互
- 数据库节点新增导出全部表结构/结构+数据 SQL(ExportDatabaseSQL)
  - 表节点支持多选/单选右键导出与备份(ExportTablesSQL)
  - ExportTable 支持导出 SQL(结构+数据)
  - 双击表仅打开表数据,不再触发展开/折叠
2026-02-03 19:49:04 +08:00
Syngnat
b2ede61b79 Merge pull request #43 from Syngnat/feature/0.2.3
️ perf(frontend): 大数据表格拖拽与打开加载性能、增加数据同步差异对比、行级选择
2026-02-03 19:23:49 +08:00
Syngnat
db381ae9d1 Merge pull request #42 from Syngnat/dev
️ perf(frontend): 大数据表格拖拽与打开加载性能、增加数据同步差异对比、行级选择
2026-02-03 19:23:15 +08:00
Syngnat
f946cfd647 Merge pull request #41 from Syngnat/feature/data-sync-optimization-20260203-ygf
️ perf(frontend): 大数据表格拖拽与打开加载性能、增加数据同步差异对比、行级选择
2026-02-03 19:21:29 +08:00
杨国锋
46c48c5ea8 ️ perf(frontend): 大数据表格拖拽与打开加载性能
- 列宽拖拽改为 rAF + transform 更新幽灵线,降低 mousemove 负载
- 大结果集自动启用 antd Table virtual 渲染,减少 DOM 压力
- 打开表改为先查数据,COUNT(*) 后台统计并回填分页总数,避免长时间 loading
- 统一内部 rowKey 字段 __gonavi_row_key__,避免与业务字段 key 冲突
2026-02-03 19:16:10 +08:00
杨国锋
e3bf160072 feat(sync): 数据同步支持差异对比、行级选择与实时进度日志
- 新增差异分析/预览接口与前端预览抽屉(插入/更新/删除)
  - 支持按表勾选插入/更新/删除(删除默认不勾选)
  - 支持按主键选择行级同步;无主键/复合主键表跳过并提示
  - 同步过程实时输出中文日志与进度条,便于定位失败步骤
2026-02-03 17:37:41 +08:00
Syngnat
791425a5a8 🐛 fix(db): 适配 schema/owner 限定名,修复 PG/金仓表不存在,修复表格数据显示异常
- 覆盖 mysql/postgres/kingbase/oracle/dameng/sqlite/custom 的 Query 返回值转换
- 修正可编辑表格保存范围,避免状态残留影响显示
- 表列表返回 schema.table/owner.table,避免 search_path 不一致导致 relation does not exist
- 元数据/导入导出/提交变更统一解析限定名并正确引用
- 前端查询与数据浏览支持限定名 quote
- 单元格编辑态时间字段统一显示为 YYYY-MM-DD HH:mm:ss
close #36
2026-02-03 14:39:05 +08:00
Syngnat
d7acfd1af9 🐛 fix(db): 适配 schema/owner 限定名,修复 PG/金仓表不存在,修复表格数据显示异常
- 覆盖 mysql/postgres/kingbase/oracle/dameng/sqlite/custom 的 Query 返回值转换
- 修正可编辑表格保存范围,避免状态残留影响显示
- 表列表返回 schema.table/owner.table,避免 search_path 不一致导致 relation does not exist
- 元数据/导入导出/提交变更统一解析限定名并正确引用
- 前端查询与数据浏览支持限定名 quote
- 单元格编辑态时间字段统一显示为 YYYY-MM-DD HH:mm:ss
close #36
2026-02-03 14:38:05 +08:00
Syngnat
80fbfd6365 Merge pull request #37 from Syngnat/feature/extend-datasource-and-sync-20250202-ygf
🐛 fix(db): 适配 schema/owner 限定名,修复 PG/金仓表不存在,修复表格数据显示异常
2026-02-03 14:35:13 +08:00
杨国锋
2ca27ebfb0 🐛 fix(query): 统一处理 []byte(nil) 为 NULL,修复表格数据显示异常
- 覆盖 mysql/postgres/kingbase/oracle/dameng/sqlite/custom 的 Query 返回值转换
  - 修正可编辑表格保存范围,避免状态残留影响显示
2026-02-03 14:27:10 +08:00
杨国锋
aa7651d95c 🐛 fix(db): 适配 schema/owner 限定名,修复 PG/金仓表不存在
- 表列表返回 schema.table/owner.table,避免 search_path 不一致导致 relation does not exist
  - 元数据/导入导出/提交变更统一解析限定名并正确引用
  - 前端查询与数据浏览支持限定名 quote
  - 单元格编辑态时间字段统一显示为 YYYY-MM-DD HH:mm:ss
  close #36
2026-02-03 14:26:37 +08:00
61 changed files with 12459 additions and 1190 deletions

View File

@@ -0,0 +1,58 @@
name: 问题反馈
description: 软件问题反馈
title: "[Bug] "
labels: ["bug"]
body:
- type: checkboxes
id: searched
attributes:
label: 已经搜索过 Issues未发现重复问题*
options:
- label: 我已经搜索过 Issues没有发现重复问题
validations:
required: true
- type: input
id: system
attributes:
label: 操作系统及版本
placeholder: Windows 10 22H2 / macOS Mojave / Linux
validations:
required: true
- type: input
id: version
attributes:
label: 软件安装版本
placeholder: v0.2.3
validations:
required: true
- type: textarea
id: description
attributes:
label: 问题简述及复现流程
description: 请详细描述你遇到的问题,并提供复现步骤
placeholder: |
1. 打开软件
2. 点击 xxx
3. 预期结果是 ...
4. 实际结果是 ...
5. 截图 ...
validations:
required: true
- type: textarea
id: extra
attributes:
label: 其他补充
description: 如果你有额外信息,请在此填写
placeholder: 可选
- type: checkboxes
id: pr
attributes:
label: 是否愿意提交 PR 修复当前 Issue
options:
- label: 我愿意尝试提交 PR

View File

@@ -0,0 +1,37 @@
name: 功能建议
description: 添加全新功能或改进现有功能
title: "[Enhancement] "
labels: ["enhancement"]
body:
- type: checkboxes
id: searched
attributes:
label: 已经搜索过 Issues未发现重复问题*
options:
- label: 我已经搜索过 Issues没有发现重复问题
validations:
required: true
- type: textarea
id: feature
attributes:
label: 功能描述
description: 请详细描述你希望添加或改进的功能
placeholder: 请描述你想要的功能
validations:
required: true
- type: textarea
id: extra
attributes:
label: 其他补充
description: 如果你有额外信息,请在此填写
placeholder: 可选
- type: checkboxes
id: pr
attributes:
label: 是否愿意提交 PR 实现当前 Issue
options:
- label: 我愿意尝试提交 PR

30
.github/ISSUE_TEMPLATE/03-generic.yml vendored Normal file
View File

@@ -0,0 +1,30 @@
name: 其他反馈
description: 其他类型反馈、建议或讨论
title: "[Question] "
labels: ["question"]
body:
- type: checkboxes
id: searched
attributes:
label: 已经搜索过 Issues未发现重复问题*
options:
- label: 我已经搜索过 Issues没有发现重复问题
validations:
required: true
- type: textarea
id: content
attributes:
label: 内容
description: 请填写你的反馈、建议或讨论内容
placeholder: 请描述你的问题或想法
validations:
required: true
- type: textarea
id: extra
attributes:
label: 其他补充
description: 如果你有额外信息,请在此填写
placeholder: 可选

1
.github/ISSUE_TEMPLATE/config.yml vendored Normal file
View File

@@ -0,0 +1 @@
blank_issues_enabled: false

22
.github/workflows/release-winget.yml vendored Normal file
View File

@@ -0,0 +1,22 @@
name: Publish to WinGet
on:
push:
tags:
- 'v*'
workflow_dispatch:
inputs:
release_tag:
required: true
description: 'Tag of release you want to publish'
type: string
jobs:
publish:
runs-on: windows-latest
steps:
- uses: vedantmgoyal9/winget-releaser@v2
with:
identifier: Syngnat.GoNavi
installers-regex: 'GoNavi-windows-(amd64|arm64)\.exe$'
release-tag: ${{ inputs.release_tag || github.ref_name }}
token: ${{ secrets.WINGET_TOKEN }}

View File

@@ -29,6 +29,13 @@ jobs:
platform: windows/amd64
artifact_name: GoNavi-windows-amd64
asset_ext: .exe
- os: windows-latest
platform: windows/arm64
artifact_name: GoNavi-windows-arm64
asset_ext: .exe
- os: ubuntu-22.04
platform: linux/amd64
artifact_name: GoNavi-linux-amd64
steps:
- name: Checkout code
@@ -45,6 +52,36 @@ jobs:
with:
node-version: '20'
# Linux Dependencies (GTK3, WebKit2GTK required by Wails)
- name: Install Linux Dependencies
if: contains(matrix.platform, 'linux')
run: |
sudo apt-get update
sudo apt-get install -y libgtk-3-dev libwebkit2gtk-4.0-dev libfuse2
# Download linuxdeploy tools for AppImage packaging
LINUXDEPLOY_URL="https://github.com/linuxdeploy/linuxdeploy/releases/download/continuous/linuxdeploy-x86_64.AppImage"
PLUGIN_URL="https://github.com/linuxdeploy/linuxdeploy-plugin-gtk/releases/download/continuous/linuxdeploy-plugin-gtk-x86_64.AppImage"
echo "📥 下载 linuxdeploy..."
wget --retry-connrefused --waitretry=1 --read-timeout=20 --timeout=15 --tries=3 \
-O /tmp/linuxdeploy "$LINUXDEPLOY_URL" || {
echo "⚠️ linuxdeploy 下载失败AppImage 打包将跳过"
touch /tmp/skip-appimage
}
echo "📥 下载 linuxdeploy-plugin-gtk..."
wget --retry-connrefused --waitretry=1 --read-timeout=20 --timeout=15 --tries=3 \
-O /tmp/linuxdeploy-plugin-gtk "$PLUGIN_URL" || {
echo "⚠️ linuxdeploy-plugin-gtk 下载失败AppImage 打包将跳过"
touch /tmp/skip-appimage
}
if [ ! -f /tmp/skip-appimage ]; then
chmod +x /tmp/linuxdeploy /tmp/linuxdeploy-plugin-gtk
echo "✅ AppImage 工具准备完成"
fi
- name: Install Wails
run: go install -v github.com/wailsapp/wails/v2/cmd/wails@latest
@@ -107,12 +144,93 @@ jobs:
echo "📦 正在移动 $FINAL_EXE 到根目录..."
mv "$FINAL_EXE" "../../$FINAL_EXE"
# Linux Packaging (tar.gz and AppImage)
- name: Package Linux
if: contains(matrix.platform, 'linux')
run: |
cd build/bin
TARGET="${{ matrix.artifact_name }}"
if [ ! -f "$TARGET" ]; then
echo "❌ 未找到构建产物 '$TARGET'!"
exit 1
fi
chmod +x "$TARGET"
# 1. Create tar.gz
echo "📦 正在打包 $TARGET.tar.gz..."
tar -czvf "$TARGET.tar.gz" "$TARGET"
mv "$TARGET.tar.gz" ../../
# 2. Create AppImage (skip for ARM64 or if tools unavailable)
if [ -f /tmp/skip-appimage ]; then
echo "⚠️ 跳过 AppImage 打包"
exit 0
fi
echo "📦 正在生成 AppImage..."
# Create AppDir structure
mkdir -p AppDir/usr/bin
mkdir -p AppDir/usr/share/applications
mkdir -p AppDir/usr/share/icons/hicolor/256x256/apps
cp "$TARGET" AppDir/usr/bin/gonavi
# Create desktop file
printf '%s\n' \
'[Desktop Entry]' \
'Name=GoNavi' \
'Exec=gonavi' \
'Icon=gonavi' \
'Type=Application' \
'Categories=Development;Database;' \
'Comment=Database Management Tool' \
> AppDir/usr/share/applications/gonavi.desktop
cp AppDir/usr/share/applications/gonavi.desktop AppDir/gonavi.desktop
# Create a simple icon (or use existing if available)
if [ -f "../../build/appicon.png" ]; then
cp "../../build/appicon.png" AppDir/usr/share/icons/hicolor/256x256/apps/gonavi.png
cp "../../build/appicon.png" AppDir/gonavi.png
else
# Create a placeholder icon
convert -size 256x256 xc:#336791 -fill white -gravity center -pointsize 48 -annotate 0 "GoNavi" AppDir/gonavi.png || \
wget -q "https://via.placeholder.com/256/336791/FFFFFF?text=GoNavi" -O AppDir/gonavi.png || \
touch AppDir/gonavi.png
cp AppDir/gonavi.png AppDir/usr/share/icons/hicolor/256x256/apps/gonavi.png
fi
# Build AppImage
export DEPLOY_GTK_VERSION=3
/tmp/linuxdeploy --appdir AppDir --plugin gtk --output appimage || {
echo "⚠️ AppImage 生成失败,但 tar.gz 已成功生成"
exit 0
}
# Rename output
mv GoNavi*.AppImage "$TARGET.AppImage" 2>/dev/null || {
echo "⚠️ AppImage 重命名失败"
exit 0
}
if [ -f "$TARGET.AppImage" ]; then
mv "$TARGET.AppImage" ../../
echo "✅ AppImage 生成成功"
fi
# Upload to Actions Artifacts (Temporary Storage)
- name: Upload Artifact
uses: actions/upload-artifact@v4
with:
name: build-artifacts-${{ strategy.job-index }} # Unique name per job
path: GoNavi-*${{ matrix.asset_ext }}
path: |
GoNavi-*.dmg
GoNavi-*.exe
GoNavi-*.tar.gz
GoNavi-*.AppImage
retention-days: 1
# Phase 2: Collect all artifacts and Publish Release (Single Job)

View File

@@ -136,14 +136,121 @@ if command -v x86_64-w64-mingw32-gcc &> /dev/null; then
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}.exe" "$DIST_DIR/${APP_NAME}-${VERSION}-windows-amd64.exe"
echo " ✅ 已生成 ${APP_NAME}-${VERSION}-windows-amd64.exe"
else
echo -e "${RED} ❌ Windows 构建失败。${NC}"
echo -e "${RED} ❌ Windows amd64 构建失败。${NC}"
fi
else
echo -e "${YELLOW} ⚠️ 未找到 MinGW 工具 (x86_64-w64-mingw32-gcc),跳过 Windows 构建。${NC}"
echo -e "${YELLOW} ⚠️ 未找到 MinGW 工具 (x86_64-w64-mingw32-gcc),跳过 Windows amd64 构建。${NC}"
fi
# --- Windows ARM64 构建 ---
echo -e "${GREEN}🪟 正在构建 Windows (arm64)...${NC}"
if command -v aarch64-w64-mingw32-gcc &> /dev/null; then
wails build -platform windows/arm64 -clean
if [ $? -eq 0 ]; then
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}.exe" "$DIST_DIR/${APP_NAME}-${VERSION}-windows-arm64.exe"
echo " ✅ 已生成 ${APP_NAME}-${VERSION}-windows-arm64.exe"
else
echo -e "${RED} ❌ Windows arm64 构建失败。${NC}"
fi
else
echo -e "${YELLOW} ⚠️ 未找到 MinGW ARM64 工具 (aarch64-w64-mingw32-gcc),跳过 Windows arm64 构建。${NC}"
echo " 安装命令: brew install mingw-w64 (需要支持 ARM64 的版本)"
fi
# --- Linux AMD64 构建 ---
echo -e "${GREEN}🐧 正在构建 Linux (amd64)...${NC}"
# 检测当前系统
CURRENT_OS=$(uname -s)
CURRENT_ARCH=$(uname -m)
if [ "$CURRENT_OS" = "Linux" ] && [ "$CURRENT_ARCH" = "x86_64" ]; then
# 本机 Linux amd64直接构建
wails build -platform linux/amd64 -clean
if [ $? -eq 0 ]; then
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64"
chmod +x "$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64"
# 打包为 tar.gz
cd "$DIST_DIR"
tar -czvf "${APP_NAME}-${VERSION}-linux-amd64.tar.gz" "${APP_NAME}-${VERSION}-linux-amd64"
rm "${APP_NAME}-${VERSION}-linux-amd64"
cd ..
echo " ✅ 已生成 ${APP_NAME}-${VERSION}-linux-amd64.tar.gz"
else
echo -e "${RED} ❌ Linux amd64 构建失败。${NC}"
fi
elif command -v x86_64-linux-gnu-gcc &> /dev/null; then
# macOS 或其他系统,尝试交叉编译
export CC=x86_64-linux-gnu-gcc
export CXX=x86_64-linux-gnu-g++
export CGO_ENABLED=1
wails build -platform linux/amd64 -clean
if [ $? -eq 0 ]; then
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64"
chmod +x "$DIST_DIR/${APP_NAME}-${VERSION}-linux-amd64"
cd "$DIST_DIR"
tar -czvf "${APP_NAME}-${VERSION}-linux-amd64.tar.gz" "${APP_NAME}-${VERSION}-linux-amd64"
rm "${APP_NAME}-${VERSION}-linux-amd64"
cd ..
echo " ✅ 已生成 ${APP_NAME}-${VERSION}-linux-amd64.tar.gz"
else
echo -e "${RED} ❌ Linux amd64 交叉编译失败。${NC}"
fi
unset CC CXX CGO_ENABLED
else
echo -e "${YELLOW} ⚠️ 非 Linux 系统且未找到交叉编译工具,跳过 Linux amd64 构建。${NC}"
echo " 在 Linux 上运行此脚本可直接构建,或安装交叉编译工具链。"
fi
# --- Linux ARM64 构建 ---
echo -e "${GREEN}🐧 正在构建 Linux (arm64)...${NC}"
if [ "$CURRENT_OS" = "Linux" ] && [ "$CURRENT_ARCH" = "aarch64" ]; then
# 本机 Linux arm64直接构建
wails build -platform linux/arm64 -clean
if [ $? -eq 0 ]; then
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64"
chmod +x "$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64"
cd "$DIST_DIR"
tar -czvf "${APP_NAME}-${VERSION}-linux-arm64.tar.gz" "${APP_NAME}-${VERSION}-linux-arm64"
rm "${APP_NAME}-${VERSION}-linux-arm64"
cd ..
echo " ✅ 已生成 ${APP_NAME}-${VERSION}-linux-arm64.tar.gz"
else
echo -e "${RED} ❌ Linux arm64 构建失败。${NC}"
fi
elif command -v aarch64-linux-gnu-gcc &> /dev/null; then
# 交叉编译
export CC=aarch64-linux-gnu-gcc
export CXX=aarch64-linux-gnu-g++
export CGO_ENABLED=1
wails build -platform linux/arm64 -clean
if [ $? -eq 0 ]; then
mv "$BUILD_BIN_DIR/${DEFAULT_BINARY_NAME}" "$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64"
chmod +x "$DIST_DIR/${APP_NAME}-${VERSION}-linux-arm64"
cd "$DIST_DIR"
tar -czvf "${APP_NAME}-${VERSION}-linux-arm64.tar.gz" "${APP_NAME}-${VERSION}-linux-arm64"
rm "${APP_NAME}-${VERSION}-linux-arm64"
cd ..
echo " ✅ 已生成 ${APP_NAME}-${VERSION}-linux-arm64.tar.gz"
else
echo -e "${RED} ❌ Linux arm64 交叉编译失败。${NC}"
fi
unset CC CXX CGO_ENABLED
else
echo -e "${YELLOW} ⚠️ 非 Linux ARM64 系统且未找到交叉编译工具,跳过 Linux arm64 构建。${NC}"
echo " 安装命令 (Ubuntu): sudo apt install gcc-aarch64-linux-gnu g++-aarch64-linux-gnu"
echo " 安装命令 (macOS): brew install aarch64-linux-gnu-gcc (需要第三方 tap)"
fi
# 清理中间构建目录
rm -rf "build/bin"
echo ""
echo -e "${GREEN}🎉 所有任务完成!构建产物在 'dist/' 目录下:${NC}"
ls -1 "$DIST_DIR"
ls -lh "$DIST_DIR"
echo ""
echo -e "${GREEN}📋 支持的平台:${NC}"
echo " • macOS (Intel/Apple Silicon): .dmg"
echo " • Windows (x64/ARM64): .exe"
echo " • Linux (x64/ARM64): .tar.gz"
echo ""
echo -e "${YELLOW}💡 提示Linux AppImage 包请使用 GitHub Actions CI/CD 构建。${NC}"

View File

@@ -2,7 +2,7 @@
<html lang="en">
<head>
<meta charset="UTF-8" />
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
<link rel="icon" type="image/svg+xml" href="/logo.svg" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>GoNavi</title>
</head>
@@ -10,4 +10,4 @@
<div id="root"></div>
<script type="module" src="/src/main.tsx"></script>
</body>
</html>
</html>

52
frontend/public/logo.svg Normal file
View File

@@ -0,0 +1,52 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 512 512">
<defs>
<!-- Background: Soft Light Grey -->
<linearGradient id="bgSoft" x1="0%" y1="0%" x2="0%" y2="100%">
<stop offset="0%" style="stop-color:#f5f7fa;stop-opacity:1" />
<stop offset="100%" style="stop-color:#c3cfe2;stop-opacity:1" />
</linearGradient>
<!-- Hexagon: Solid Tech Pink -->
<linearGradient id="solidPink" x1="0%" y1="0%" x2="100%" y2="100%">
<stop offset="0%" style="stop-color:#FF5F6D;stop-opacity:1" />
<stop offset="100%" style="stop-color:#FFC371;stop-opacity:1" />
</linearGradient>
<!-- N: Solid Tech Blue/Cyan -->
<linearGradient id="solidCyan" x1="0%" y1="0%" x2="100%" y2="100%">
<stop offset="0%" style="stop-color:#00c6ff;stop-opacity:1" />
<stop offset="100%" style="stop-color:#0072ff;stop-opacity:1" />
</linearGradient>
<filter id="hardShadow" x="-20%" y="-20%" width="140%" height="140%">
<feGaussianBlur in="SourceAlpha" stdDeviation="4"/>
<feOffset dx="4" dy="4" result="offsetblur"/>
<feComponentTransfer>
<feFuncA type="linear" slope="0.2"/>
</feComponentTransfer>
<feMerge>
<feMergeNode/>
<feMergeNode in="SourceGraphic"/>
</feMerge>
</filter>
</defs>
<!-- Background -->
<rect x="32" y="32" width="448" height="448" rx="100" fill="url(#bgSoft)" />
<!-- Main Content Centered -->
<g transform="translate(106, 106) scale(0.6)" filter="url(#hardShadow)">
<!-- Hex G -->
<path d="M 250 0 L 466 125 L 466 375 L 250 500 L 34 375 L 34 125 Z"
fill="none" stroke="url(#solidPink)" stroke-width="45" stroke-linejoin="round"/>
<!-- G Crossbar -->
<path d="M 466 300 L 330 300" stroke="url(#solidPink)" stroke-width="45" stroke-linecap="round"/>
<!-- Inner N -->
<path d="M 160 350 L 160 150 L 340 350 L 340 150"
fill="none" stroke="url(#solidCyan)" stroke-width="50" stroke-linecap="round" stroke-linejoin="round"/>
</g>
</svg>

After

Width:  |  Height:  |  Size: 2.0 KiB

View File

@@ -17,7 +17,14 @@ function App() {
const [isModalOpen, setIsModalOpen] = useState(false);
const [isSyncModalOpen, setIsSyncModalOpen] = useState(false);
const [editingConnection, setEditingConnection] = useState<SavedConnection | null>(null);
const { darkMode, toggleDarkMode, addTab, activeContext, connections, addConnection, tabs, activeTabId } = useStore();
const darkMode = useStore(state => state.darkMode);
const toggleDarkMode = useStore(state => state.toggleDarkMode);
const addTab = useStore(state => state.addTab);
const activeContext = useStore(state => state.activeContext);
const connections = useStore(state => state.connections);
const addConnection = useStore(state => state.addConnection);
const tabs = useStore(state => state.tabs);
const activeTabId = useStore(state => state.activeTabId);
const handleNewQuery = () => {
let connId = activeContext?.connectionId || '';
@@ -285,12 +292,12 @@ function App() {
title="拖动调整宽度"
/>
</Sider>
<Content style={{ background: darkMode ? '#141414' : '#fff', overflow: 'hidden', display: 'flex', flexDirection: 'column' }}>
<div style={{ flex: 1, overflow: 'hidden' }}>
<TabManager />
</div>
{isLogPanelOpen && (
<LogPanel
<Content style={{ background: darkMode ? '#141414' : '#fff', overflow: 'hidden', display: 'flex', flexDirection: 'column' }}>
<div style={{ flex: 1, minHeight: 0, overflow: 'hidden', display: 'flex', flexDirection: 'column' }}>
<TabManager />
</div>
{isLogPanelOpen && (
<LogPanel
height={logPanelHeight}
onClose={() => setIsLogPanelOpen(false)}
onResizeStart={handleLogResizeStart}
@@ -343,4 +350,4 @@ function App() {
);
}
export default App;
export default App;

View File

@@ -1,8 +1,8 @@
import React, { useState, useEffect } from 'react';
import React, { useState, useEffect, useRef } from 'react';
import { Modal, Form, Input, InputNumber, Button, message, Checkbox, Divider, Select, Alert, Card, Row, Col, Typography, Collapse } from 'antd';
import { DatabaseOutlined, ConsoleSqlOutlined, FileTextOutlined, CloudServerOutlined, AppstoreAddOutlined } from '@ant-design/icons';
import { DatabaseOutlined, ConsoleSqlOutlined, FileTextOutlined, CloudServerOutlined, AppstoreAddOutlined, CloudOutlined } from '@ant-design/icons';
import { useStore } from '../store';
import { DBConnect, DBGetDatabases, TestConnection } from '../../wailsjs/go/app/App';
import { DBGetDatabases, TestConnection, RedisConnect } from '../../wailsjs/go/app/App';
import { SavedConnection } from '../types';
const { Meta } = Card;
@@ -16,6 +16,9 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
const [step, setStep] = useState(1); // 1: Select Type, 2: Configure
const [testResult, setTestResult] = useState<{ type: 'success' | 'error', message: string } | null>(null);
const [dbList, setDbList] = useState<string[]>([]);
const [redisDbList, setRedisDbList] = useState<number[]>([]); // Redis databases 0-15
const testInFlightRef = useRef(false);
const testTimerRef = useRef<number | null>(null);
const addConnection = useStore((state) => state.addConnection);
const updateConnection = useStore((state) => state.updateConnection);
@@ -23,6 +26,7 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
if (open) {
setTestResult(null); // Reset test result
setDbList([]);
setRedisDbList([]);
if (initialValues) {
// Edit mode: Go directly to step 2
setStep(2);
@@ -35,6 +39,7 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
password: initialValues.config.password,
database: initialValues.config.database,
includeDatabases: initialValues.includeDatabases,
includeRedisDatabases: initialValues.includeRedisDatabases,
useSSH: initialValues.config.useSSH,
sshHost: initialValues.config.ssh?.host,
sshPort: initialValues.config.ssh?.port,
@@ -47,6 +52,10 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
});
setUseSSH(initialValues.config.useSSH || false);
setDbType(initialValues.config.type);
// 如果是 Redis 编辑模式,设置已保存的 Redis 数据库列表
if (initialValues.config.type === 'redis') {
setRedisDbList(Array.from({ length: 16 }, (_, i) => i));
}
} else {
// Create mode: Start at step 1
setStep(1);
@@ -57,64 +66,94 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
}
}, [open, initialValues]);
useEffect(() => {
return () => {
if (testTimerRef.current !== null) {
window.clearTimeout(testTimerRef.current);
testTimerRef.current = null;
}
};
}, []);
const handleOk = async () => {
try {
const values = await form.validateFields();
setLoading(true);
const config = await buildConfig(values);
const res = await DBConnect(config as any);
setLoading(false);
if (res.success) {
const newConn = {
id: initialValues ? initialValues.id : Date.now().toString(),
name: values.name || (values.type === 'sqlite' ? 'SQLite DB' : values.host),
config: config,
includeDatabases: values.includeDatabases
};
if (initialValues) {
updateConnection(newConn);
message.success('连接已更新!');
} else {
addConnection(newConn);
message.success('连接已保存!');
}
form.resetFields();
setUseSSH(false);
setDbType('mysql');
setStep(1);
onClose();
const config = await buildConfig(values);
const isRedisType = values.type === 'redis';
const newConn = {
id: initialValues ? initialValues.id : Date.now().toString(),
name: values.name || (values.type === 'sqlite' ? 'SQLite DB' : (values.type === 'redis' ? `Redis ${values.host}` : values.host)),
config: config,
includeDatabases: values.includeDatabases,
includeRedisDatabases: isRedisType ? values.includeRedisDatabases : undefined
};
if (initialValues) {
updateConnection(newConn);
message.success('配置已更新(未连接)');
} else {
message.error('连接失败: ' + res.message);
addConnection(newConn);
message.success('配置已保存(未连接)');
}
setLoading(false);
form.resetFields();
setUseSSH(false);
setDbType('mysql');
setStep(1);
onClose();
} catch (e) {
setLoading(false);
}
};
const requestTest = () => {
if (loading) return;
if (testTimerRef.current !== null) return;
testTimerRef.current = window.setTimeout(() => {
testTimerRef.current = null;
handleTest();
}, 0);
};
const handleTest = async () => {
if (testInFlightRef.current) return;
testInFlightRef.current = true;
try {
const values = await form.validateFields();
setLoading(true);
setTestResult(null);
const config = await buildConfig(values);
const res = await TestConnection(config as any);
setLoading(false);
// Use different API for Redis
const isRedisType = values.type === 'redis';
const res = isRedisType
? await RedisConnect(config as any)
: await TestConnection(config as any);
if (res.success) {
setTestResult({ type: 'success', message: res.message });
const dbRes = await DBGetDatabases(config as any);
if (dbRes.success) {
const dbs = (dbRes.data as any[]).map((row: any) => row.Database || row.database);
setDbList(dbs);
if (isRedisType) {
// Redis: generate database list 0-15
setRedisDbList(Array.from({ length: 16 }, (_, i) => i));
} else {
// Other databases: fetch database list
const dbRes = await DBGetDatabases(config as any);
if (dbRes.success) {
const dbs = (dbRes.data as any[]).map((row: any) => row.Database || row.database);
setDbList(dbs);
}
}
} else {
setTestResult({ type: 'error', message: "测试失败: " + res.message });
}
} catch (e) {
// ignore
} finally {
testInFlightRef.current = false;
setLoading(false);
}
};
@@ -128,7 +167,7 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
keyPath: values.sshKeyPath || ""
} : { host: "", port: 22, user: "", password: "", keyPath: "" };
return {
return {
type: values.type,
host: values.host || "",
port: Number(values.port || 0),
@@ -146,12 +185,13 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
const handleTypeSelect = (type: string) => {
setDbType(type);
form.setFieldsValue({ type: type });
// Auto-fill default port
let defaultPort = 3306;
switch (type) {
case 'mysql': defaultPort = 3306; break;
case 'postgres': defaultPort = 5432; break;
case 'redis': defaultPort = 6379; break;
case 'oracle': defaultPort = 1521; break;
case 'dameng': defaultPort = 5236; break;
case 'kingbase': defaultPort = 54321; break;
@@ -166,10 +206,12 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
const isSqlite = dbType === 'sqlite';
const isCustom = dbType === 'custom';
const isRedis = dbType === 'redis';
const dbTypes = [
{ key: 'mysql', name: 'MySQL', icon: <ConsoleSqlOutlined style={{ fontSize: 24, color: '#00758F' }} /> },
{ key: 'postgres', name: 'PostgreSQL', icon: <DatabaseOutlined style={{ fontSize: 24, color: '#336791' }} /> },
{ key: 'redis', name: 'Redis', icon: <CloudOutlined style={{ fontSize: 24, color: '#DC382D' }} /> },
{ key: 'sqlite', name: 'SQLite', icon: <FileTextOutlined style={{ fontSize: 24, color: '#003B57' }} /> },
{ key: 'oracle', name: 'Oracle', icon: <DatabaseOutlined style={{ fontSize: 24, color: '#F80000' }} /> },
{ key: 'dameng', name: 'Dameng (达梦)', icon: <CloudServerOutlined style={{ fontSize: 24, color: '#1890ff' }} /> },
@@ -226,7 +268,10 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
<>
<div style={{ display: 'flex', gap: 16 }}>
<Form.Item name="host" label={isSqlite ? "文件路径 (绝对路径)" : "主机地址 (Host)"} rules={[{ required: true, message: '请输入地址/路径' }]} style={{ flex: 1 }}>
<Input placeholder={isSqlite ? "/path/to/db.sqlite" : "localhost"} />
<Input
placeholder={isSqlite ? "/path/to/db.sqlite" : "localhost"}
onDoubleClick={requestTest}
/>
</Form.Item>
{!isSqlite && (
<Form.Item name="port" label="端口 (Port)" rules={[{ required: true, message: '请输入端口号' }]} style={{ width: 100 }}>
@@ -235,7 +280,22 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
)}
</div>
{!isSqlite && (
{/* Redis specific: password only, no username */}
{isRedis && (
<>
<Form.Item name="password" label="密码 (可选)">
<Input.Password placeholder="Redis 密码(如果设置了 requirepass" />
</Form.Item>
<Form.Item name="includeRedisDatabases" label="显示数据库 (留空显示全部)" help="连接测试成功后可选择">
<Select mode="multiple" placeholder="选择显示的数据库 (0-15)" allowClear>
{redisDbList.map(db => <Select.Option key={db} value={db}>db{db}</Select.Option>)}
</Select>
</Form.Item>
</>
)}
{/* Non-Redis, non-SQLite: username and password */}
{!isSqlite && !isRedis && (
<div style={{ display: 'flex', gap: 16 }}>
<Form.Item name="user" label="用户名" rules={[{ required: true, message: '请输入用户名' }]} style={{ flex: 1 }}>
<Input />
@@ -245,8 +305,8 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
</Form.Item>
</div>
)}
{!isSqlite && (
{!isSqlite && !isRedis && (
<Form.Item name="includeDatabases" label="显示数据库 (留空显示全部)" help="连接测试成功后可选择">
<Select mode="multiple" placeholder="选择显示的数据库" allowClear>
{dbList.map(db => <Select.Option key={db} value={db}>{db}</Select.Option>)}
@@ -264,8 +324,8 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
{useSSH && (
<div style={{ padding: '12px', background: '#f5f5f5', borderRadius: 6, marginTop: 12 }}>
<div style={{ display: 'flex', gap: 16 }}>
<Form.Item name="sshHost" label="SSH 主机" rules={[{ required: useSSH, message: '请输入SSH主机' }]} style={{ flex: 1 }}>
<Input placeholder="ssh.example.com" />
<Form.Item name="sshHost" label="SSH 主机 (域名或IP)" rules={[{ required: useSSH, message: '请输入SSH主机' }]} style={{ flex: 1 }}>
<Input placeholder="例如: ssh.example.com 或 192.168.1.100" />
</Form.Item>
<Form.Item name="sshPort" label="端口" rules={[{ required: useSSH, message: '请输入SSH端口' }]} style={{ width: 100 }}>
<InputNumber style={{ width: '100%' }} />
@@ -328,7 +388,7 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
}
return [
!initialValues && <Button key="back" onClick={() => setStep(1)} style={{ float: 'left' }}></Button>,
<Button key="test" loading={loading} onClick={handleTest}></Button>,
<Button key="test" loading={loading} onClick={requestTest}></Button>,
<Button key="cancel" onClick={onClose}></Button>,
<Button key="submit" type="primary" loading={loading} onClick={handleOk}></Button>
];

File diff suppressed because it is too large Load Diff

View File

@@ -1,14 +1,36 @@
import React, { useState, useEffect } from 'react';
import { Modal, Form, Select, Button, message, Steps, Transfer, Card, Alert, Divider, Typography } from 'antd';
import React, { useState, useEffect, useRef } from 'react';
import { Modal, Form, Select, Button, message, Steps, Transfer, Card, Alert, Divider, Typography, Progress, Checkbox, Table, Drawer, Tabs } from 'antd';
import { useStore } from '../store';
import { DBGetDatabases, DBGetTables, DataSync } from '../../wailsjs/go/app/App';
import { DBGetDatabases, DBGetTables, DataSync, DataSyncAnalyze, DataSyncPreview } from '../../wailsjs/go/app/App';
import { SavedConnection } from '../types';
import { connection } from '../../wailsjs/go/models';
import { EventsOn } from '../../wailsjs/runtime/runtime';
const { Title, Text } = Typography;
const { Step } = Steps;
const { Option } = Select;
type SyncLogEvent = { jobId: string; level?: string; message?: string; ts?: number };
type SyncProgressEvent = { jobId: string; percent?: number; current?: number; total?: number; table?: string; stage?: string };
type SyncLogItem = { level: string; message: string; ts?: number };
type TableDiffSummary = {
table: string;
pkColumn?: string;
canSync?: boolean;
inserts?: number;
updates?: number;
deletes?: number;
same?: number;
message?: string;
};
type TableOps = {
insert: boolean;
update: boolean;
delete: boolean;
selectedInsertPks?: string[];
selectedUpdatePks?: string[];
selectedDeletePks?: string[];
};
const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open, onClose }) => {
const connections = useStore((state) => state.connections);
const [currentStep, setCurrentStep] = useState(0);
@@ -27,8 +49,76 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
const [allTables, setAllTables] = useState<string[]>([]);
const [selectedTables, setSelectedTables] = useState<string[]>([]);
// Options
const [syncContent, setSyncContent] = useState<'data' | 'schema' | 'both'>('data');
const [syncMode, setSyncMode] = useState<string>('insert_update');
const [autoAddColumns, setAutoAddColumns] = useState<boolean>(true);
const [showSameTables, setShowSameTables] = useState<boolean>(false);
const [analyzing, setAnalyzing] = useState<boolean>(false);
const [diffTables, setDiffTables] = useState<TableDiffSummary[]>([]);
const [tableOptions, setTableOptions] = useState<Record<string, TableOps>>({});
const [previewOpen, setPreviewOpen] = useState(false);
const [previewTable, setPreviewTable] = useState<string>('');
const [previewLoading, setPreviewLoading] = useState(false);
const [previewData, setPreviewData] = useState<any>(null);
// Step 3: Result
const [syncResult, setSyncResult] = useState<any>(null);
const [syncing, setSyncing] = useState(false);
const [syncLogs, setSyncLogs] = useState<SyncLogItem[]>([]);
const [syncProgress, setSyncProgress] = useState<{ percent: number; current: number; total: number; table: string; stage: string }>({
percent: 0,
current: 0,
total: 0,
table: '',
stage: ''
});
const jobIdRef = useRef<string>('');
const logBoxRef = useRef<HTMLDivElement>(null);
const autoScrollRef = useRef(true);
const normalizeConnConfig = (conn: SavedConnection, database?: string) => ({
...conn.config,
port: Number((conn.config as any).port),
password: conn.config.password || "",
useSSH: conn.config.useSSH || false,
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" },
database: typeof database === 'string' ? database : (conn.config.database || ""),
});
useEffect(() => {
if (!open) return;
const offLog = EventsOn('sync:log', (event: SyncLogEvent) => {
if (!event || event.jobId !== jobIdRef.current) return;
const msg = String(event.message || '').trim();
if (!msg) return;
setSyncLogs(prev => [...prev, { level: String(event.level || 'info'), message: msg, ts: event.ts }]);
});
const offProgress = EventsOn('sync:progress', (event: SyncProgressEvent) => {
if (!event || event.jobId !== jobIdRef.current) return;
setSyncProgress(prev => ({
percent: typeof event.percent === 'number' ? event.percent : prev.percent,
current: typeof event.current === 'number' ? event.current : prev.current,
total: typeof event.total === 'number' ? event.total : prev.total,
table: typeof event.table === 'string' ? event.table : prev.table,
stage: typeof event.stage === 'string' ? event.stage : prev.stage,
}));
});
return () => {
offLog();
offProgress();
};
}, [open]);
useEffect(() => {
if (!logBoxRef.current) return;
if (!autoScrollRef.current) return;
logBoxRef.current.scrollTop = logBoxRef.current.scrollHeight;
}, [syncLogs]);
useEffect(() => {
if (open) {
@@ -38,7 +128,23 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
setSourceDb('');
setTargetDb('');
setSelectedTables([]);
setSyncContent('data');
setSyncMode('insert_update');
setAutoAddColumns(true);
setShowSameTables(false);
setAnalyzing(false);
setDiffTables([]);
setTableOptions({});
setPreviewOpen(false);
setPreviewTable('');
setPreviewLoading(false);
setPreviewData(null);
setSyncResult(null);
setSyncing(false);
setSyncLogs([]);
setSyncProgress({ percent: 0, current: 0, total: 0, table: '', stage: '' });
jobIdRef.current = '';
autoScrollRef.current = true;
}
}, [open]);
@@ -49,7 +155,7 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
if (conn) {
setLoading(true);
try {
const res = await DBGetDatabases(conn.config as any);
const res = await DBGetDatabases(normalizeConnConfig(conn) as any);
if (res.success) {
setSourceDbs((res.data as any[]).map((r: any) => r.Database || r.database || r.username));
}
@@ -65,7 +171,7 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
if (conn) {
setLoading(true);
try {
const res = await DBGetDatabases(conn.config as any);
const res = await DBGetDatabases(normalizeConnConfig(conn) as any);
if (res.success) {
setTargetDbs((res.data as any[]).map((r: any) => r.Database || r.database || r.username));
}
@@ -83,7 +189,7 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
try {
const conn = connections.find(c => c.id === sourceConnId);
if (conn) {
const config = { ...conn.config, database: sourceDb };
const config = normalizeConnConfig(conn, sourceDb);
const res = await DBGetTables(config as any, sourceDb);
if (res.success) {
// DBGetTables returns [{Table: "name"}, ...]
@@ -98,36 +204,221 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
setLoading(false);
};
const runSync = async () => {
const updateTableOption = (table: string, key: keyof TableOps, value: any) => {
setTableOptions(prev => ({
...prev,
[table]: { ...(prev[table] || { insert: true, update: true, delete: false }), [key]: value }
}));
};
const analyzeDiff = async () => {
if (selectedTables.length === 0) return;
if (!sourceConnId || !targetConnId) return message.error("Select connections first");
if (!sourceDb || !targetDb) return message.error("Select databases first");
setLoading(true);
setAnalyzing(true);
setDiffTables([]);
setTableOptions({});
setSyncLogs([]);
const sConn = connections.find(c => c.id === sourceConnId)!;
const tConn = connections.find(c => c.id === targetConnId)!;
const jobId = `analyze-${Date.now()}-${Math.random().toString(16).slice(2, 8)}`;
jobIdRef.current = jobId;
autoScrollRef.current = true;
setSyncProgress({ percent: 0, current: 0, total: selectedTables.length, table: '', stage: '差异分析' });
const config = {
sourceConfig: normalizeConnConfig(sConn, sourceDb),
targetConfig: normalizeConnConfig(tConn, targetDb),
tables: selectedTables,
content: syncContent,
mode: "insert_update",
autoAddColumns,
jobId,
};
try {
const res = await DataSyncAnalyze(config as any);
if (res.success) {
const tables = ((res.data as any)?.tables || []) as TableDiffSummary[];
setDiffTables(tables);
const init: Record<string, TableOps> = {};
tables.forEach(t => {
const can = !!t.canSync;
init[t.table] = {
insert: can,
update: can,
delete: false,
selectedInsertPks: [],
selectedUpdatePks: [],
selectedDeletePks: [],
};
});
setTableOptions(init);
message.success("差异分析完成");
} else {
message.error(res.message || "差异分析失败");
}
} catch (e: any) {
message.error("差异分析失败: " + (e?.message || ""));
}
setLoading(false);
setAnalyzing(false);
};
const openPreview = async (table: string) => {
if (!table) return;
const sConn = connections.find(c => c.id === sourceConnId)!;
const tConn = connections.find(c => c.id === targetConnId)!;
setPreviewOpen(true);
setPreviewTable(table);
setPreviewLoading(true);
setPreviewData(null);
const config = {
sourceConfig: normalizeConnConfig(sConn, sourceDb),
targetConfig: normalizeConnConfig(tConn, targetDb),
tables: selectedTables,
content: "data",
mode: "insert_update",
autoAddColumns,
};
try {
const res = await DataSyncPreview(config as any, table, 200);
if (res.success) {
setPreviewData(res.data);
} else {
message.error(res.message || "加载差异预览失败");
}
} catch (e: any) {
message.error("加载差异预览失败: " + (e?.message || ""));
}
setPreviewLoading(false);
};
const runSync = async () => {
if (syncContent !== 'schema' && diffTables.length === 0) {
message.error("请先对比差异,再开始同步");
return;
}
if (syncContent !== 'schema' && syncMode === 'full_overwrite') {
const ok = await new Promise<boolean>((resolve) => {
Modal.confirm({
title: '确认全量覆盖',
content: '全量覆盖会清空目标表数据后再插入,请确认已备份目标库。',
okText: '继续执行',
cancelText: '取消',
onOk: () => resolve(true),
onCancel: () => resolve(false),
});
});
if (!ok) return;
}
setLoading(true);
setSyncing(true);
setCurrentStep(2);
setSyncResult(null);
setSyncLogs([]);
const sConn = connections.find(c => c.id === sourceConnId)!;
const tConn = connections.find(c => c.id === targetConnId)!;
const jobId = `sync-${Date.now()}-${Math.random().toString(16).slice(2, 8)}`;
jobIdRef.current = jobId;
autoScrollRef.current = true;
setSyncProgress({
percent: 0,
current: 0,
total: selectedTables.length,
table: '',
stage: '准备开始',
});
const config = {
sourceConfig: { ...sConn.config, database: sourceDb },
targetConfig: { ...tConn.config, database: targetDb },
sourceConfig: {
...sConn.config,
port: Number((sConn.config as any).port),
password: sConn.config.password || "",
useSSH: sConn.config.useSSH || false,
ssh: sConn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" },
database: sourceDb,
},
targetConfig: {
...tConn.config,
port: Number((tConn.config as any).port),
password: tConn.config.password || "",
useSSH: tConn.config.useSSH || false,
ssh: tConn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" },
database: targetDb,
},
tables: selectedTables,
mode: "insert_update"
content: syncContent,
mode: syncMode,
autoAddColumns,
tableOptions,
jobId,
};
try {
const res = await DataSync(config as any);
setSyncResult(res);
setCurrentStep(2);
if (Array.isArray(res?.logs) && res.logs.length > 0) {
setSyncLogs(prev => {
if (prev.length > 0) return prev;
return (res.logs as string[]).map((log) => {
const msg = String(log || '').trim();
if (msg.includes('致命错误') || msg.includes('失败')) return { level: 'error', message: msg };
if (msg.includes('跳过') || msg.includes('警告')) return { level: 'warn', message: msg };
return { level: 'info', message: msg };
});
});
}
} catch (e) {
message.error("Sync execution failed");
setSyncResult({ success: false, message: "同步执行失败", logs: [] });
}
setLoading(false);
setSyncing(false);
};
const renderSyncLogItem = (item: SyncLogItem) => {
const level = String(item.level || 'info').toLowerCase();
const color = level === 'error' ? '#ff4d4f' : (level === 'warn' ? '#faad14' : '#595959');
const label = level === 'error' ? '错误' : (level === 'warn' ? '警告' : '信息');
const timeText = typeof item.ts === 'number' ? new Date(item.ts).toLocaleTimeString('zh-CN', { hour12: false }) : '';
return (
<div style={{ display: 'flex', gap: 8, alignItems: 'flex-start' }}>
<span style={{ color, flex: '0 0 auto' }}> {label}</span>
{timeText && <span style={{ color: '#8c8c8c', flex: '0 0 auto' }}>{timeText}</span>}
<span style={{ whiteSpace: 'pre-wrap', wordBreak: 'break-word' }}>{item.message}</span>
</div>
);
};
return (
<>
<Modal
title="数据同步"
open={open}
onCancel={onClose}
width={800}
footer={null}
destroyOnHidden
title="数据同步"
open={open}
onCancel={() => {
if (syncing) {
message.warning("同步执行中,暂不支持关闭");
return;
}
onClose();
}}
width={800}
footer={null}
destroyOnHidden
closable={!syncing}
maskClosable={!syncing}
>
<Steps current={currentStep} style={{ marginBottom: 24 }}>
<Step title="配置源与目标" />
@@ -137,34 +428,67 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
{/* STEP 1: CONFIG */}
{currentStep === 0 && (
<div style={{ display: 'flex', gap: 24, justifyContent: 'center' }}>
<Card title="源数据库" style={{ width: 350 }}>
<div>
<div style={{ display: 'flex', gap: 24, justifyContent: 'center' }}>
<Card title="源数据库" style={{ width: 350 }}>
<Form layout="vertical">
<Form.Item label="连接">
<Select value={sourceConnId} onChange={handleSourceConnChange}>
{connections.map(c => <Option key={c.id} value={c.id}>{c.name} ({c.config.type})</Option>)}
</Select>
</Form.Item>
<Form.Item label="数据库">
<Select value={sourceDb} onChange={setSourceDb} showSearch>
{sourceDbs.map(d => <Option key={d} value={d}>{d}</Option>)}
</Select>
</Form.Item>
</Form>
</Card>
<div style={{ display: 'flex', alignItems: 'center' }}></div>
<Card title="目标数据库" style={{ width: 350 }}>
<Form layout="vertical">
<Form.Item label="连接">
<Select value={targetConnId} onChange={handleTargetConnChange}>
{connections.map(c => <Option key={c.id} value={c.id}>{c.name} ({c.config.type})</Option>)}
</Select>
</Form.Item>
<Form.Item label="数据库">
<Select value={targetDb} onChange={setTargetDb} showSearch>
{targetDbs.map(d => <Option key={d} value={d}>{d}</Option>)}
</Select>
</Form.Item>
</Form>
</Card>
</div>
<Card title="同步选项" style={{ marginTop: 16 }}>
<Form layout="vertical">
<Form.Item label="连接">
<Select value={sourceConnId} onChange={handleSourceConnChange}>
{connections.map(c => <Option key={c.id} value={c.id}>{c.name} ({c.config.type})</Option>)}
<Form.Item label="同步内容">
<Select value={syncContent} onChange={setSyncContent}>
<Option value="data"></Option>
<Option value="schema"></Option>
<Option value="both"> + </Option>
</Select>
</Form.Item>
<Form.Item label="数据库">
<Select value={sourceDb} onChange={setSourceDb} showSearch>
{sourceDbs.map(d => <Option key={d} value={d}>{d}</Option>)}
<Form.Item label="同步模式">
<Select value={syncMode} onChange={setSyncMode} disabled={syncContent === 'schema'}>
<Option value="insert_update">//</Option>
<Option value="insert_only"></Option>
<Option value="full_overwrite"></Option>
</Select>
</Form.Item>
</Form>
</Card>
<div style={{ display: 'flex', alignItems: 'center' }}></div>
<Card title="目标数据库" style={{ width: 350 }}>
<Form layout="vertical">
<Form.Item label="连接">
<Select value={targetConnId} onChange={handleTargetConnChange}>
{connections.map(c => <Option key={c.id} value={c.id}>{c.name} ({c.config.type})</Option>)}
</Select>
</Form.Item>
<Form.Item label="数据库">
<Select value={targetDb} onChange={setTargetDb} showSearch>
{targetDbs.map(d => <Option key={d} value={d}>{d}</Option>)}
</Select>
<Form.Item>
<Checkbox checked={autoAddColumns} onChange={(e) => setAutoAddColumns(e.target.checked)}>
MySQL
</Checkbox>
</Form.Item>
{syncContent !== 'schema' && syncMode === 'full_overwrite' && (
<Alert
type="warning"
showIcon
message="全量覆盖会清空目标表数据,请谨慎使用。"
/>
)}
</Form>
</Card>
</div>
@@ -172,32 +496,155 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
{/* STEP 2: TABLES */}
{currentStep === 1 && (
<div style={{ height: 400 }}>
<Text type="secondary">:</Text>
<div style={{ display: 'flex', flexDirection: 'column', gap: 12 }}>
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center' }}>
<Text type="secondary">:</Text>
<Checkbox checked={showSameTables} onChange={(e) => setShowSameTables(e.target.checked)}>
</Checkbox>
</div>
<Transfer
dataSource={allTables.map(t => ({ key: t, title: t }))}
titles={['源表', '已选表']}
targetKeys={selectedTables}
onChange={(keys) => setSelectedTables(keys as string[])}
render={item => item.title}
listStyle={{ width: 350, height: 350, marginTop: 12 }}
listStyle={{ width: 350, height: 280, marginTop: 0 }}
locale={{ itemUnit: '项', itemsUnit: '项', searchPlaceholder: '搜索表', notFoundContent: '暂无数据' }}
/>
{diffTables.length > 0 && (
<div>
<Divider orientation="left"></Divider>
<Table
size="small"
pagination={false}
rowKey={(r: any) => r.table}
dataSource={diffTables.filter(t => {
const ins = Number(t.inserts || 0);
const upd = Number(t.updates || 0);
const del = Number(t.deletes || 0);
const same = Number(t.same || 0);
const msg = String(t.message || '').trim();
const can = !!t.canSync;
if (showSameTables) return true;
if (!can) return true;
if (msg) return true;
return ins > 0 || upd > 0 || del > 0 || same === 0;
})}
columns={[
{ title: '表名', dataIndex: 'table', key: 'table', ellipsis: true },
{
title: '插入',
key: 'inserts',
width: 90,
render: (_: any, r: any) => {
const ops = tableOptions[r.table] || { insert: true, update: true, delete: false };
const disabled = !r.canSync || analyzing || Number(r.inserts || 0) === 0;
return (
<Checkbox
checked={!!ops.insert}
disabled={disabled}
onChange={(e) => updateTableOption(r.table, 'insert', e.target.checked)}
>
{Number(r.inserts || 0)}
</Checkbox>
);
}
},
{
title: '更新',
key: 'updates',
width: 90,
render: (_: any, r: any) => {
const ops = tableOptions[r.table] || { insert: true, update: true, delete: false };
const disabled = !r.canSync || analyzing || Number(r.updates || 0) === 0;
return (
<Checkbox
checked={!!ops.update}
disabled={disabled}
onChange={(e) => updateTableOption(r.table, 'update', e.target.checked)}
>
{Number(r.updates || 0)}
</Checkbox>
);
}
},
{
title: '删除',
key: 'deletes',
width: 90,
render: (_: any, r: any) => {
const ops = tableOptions[r.table] || { insert: true, update: true, delete: false };
const disabled = !r.canSync || analyzing || Number(r.deletes || 0) === 0;
return (
<Checkbox
checked={!!ops.delete}
disabled={disabled}
onChange={(e) => updateTableOption(r.table, 'delete', e.target.checked)}
>
{Number(r.deletes || 0)}
</Checkbox>
);
}
},
{ title: '相同', dataIndex: 'same', key: 'same', width: 70, render: (v: any) => Number(v || 0) },
{ title: '消息', dataIndex: 'message', key: 'message', ellipsis: true, render: (v: any) => (v ? String(v) : '') },
{
title: '预览',
key: 'preview',
width: 80,
render: (_: any, r: any) => {
const can = !!r.canSync;
const hasDiff = Number(r.inserts || 0) + Number(r.updates || 0) + Number(r.deletes || 0) > 0;
return (
<Button size="small" disabled={!can || !hasDiff || analyzing} onClick={() => openPreview(r.table)}>
</Button>
);
}
}
]}
/>
</div>
)}
</div>
)}
{/* STEP 3: RESULT */}
{currentStep === 2 && syncResult && (
{currentStep === 2 && (
<div>
<Alert
message={syncResult.success ? "同步完成" : "同步失败"}
description={syncResult.message || `成功同步 ${syncResult.tablesSynced} 张表. 插入: ${syncResult.rowsInserted}, 更新: ${syncResult.rowsUpdated}`}
type={syncResult.success ? "success" : "error"}
showIcon
<Alert
message={syncing ? "正在同步" : (syncResult?.success ? "同步完成" : "同步失败")}
description={
syncing
? `当前阶段:${syncProgress.stage || '执行中'}${syncProgress.table ? `,表:${syncProgress.table}` : ''}`
: (syncResult?.message || `成功同步 ${syncResult?.tablesSynced || 0} 张表. 插入: ${syncResult?.rowsInserted || 0}, 更新: ${syncResult?.rowsUpdated || 0}`)
}
type={syncing ? "info" : (syncResult?.success ? "success" : "error")}
showIcon
/>
<div style={{ marginTop: 12 }}>
<Progress
percent={syncProgress.percent}
status={syncing ? "active" : (syncResult?.success ? "success" : "exception")}
format={() => `${syncProgress.current}/${syncProgress.total}`}
/>
</div>
<Divider orientation="left"></Divider>
<div style={{ background: '#f5f5f5', padding: 12, height: 300, overflowY: 'auto', fontFamily: 'monospace' }}>
{syncResult.logs.map((log: string, i: number) => <div key={i}>{log}</div>)}
<div
ref={logBoxRef}
onScroll={() => {
const el = logBoxRef.current;
if (!el) return;
const nearBottom = el.scrollHeight - el.scrollTop - el.clientHeight < 40;
autoScrollRef.current = nearBottom;
}}
style={{ background: '#f5f5f5', padding: 12, height: 300, overflowY: 'auto', fontFamily: 'monospace' }}
>
{syncLogs.map((item, i: number) => <div key={i}>{renderSyncLogItem(item)}</div>)}
</div>
</div>
)}
@@ -206,20 +653,154 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
{currentStep === 0 && (
<Button type="primary" onClick={nextToTables} loading={loading}></Button>
)}
{currentStep === 1 && (
<>
<Button onClick={() => setCurrentStep(0)} style={{ marginRight: 8 }}></Button>
<Button type="primary" onClick={runSync} loading={loading} disabled={selectedTables.length === 0}></Button>
{currentStep === 1 && (
<>
<Button onClick={() => setCurrentStep(0)} style={{ marginRight: 8 }}></Button>
<Button onClick={analyzeDiff} loading={loading} disabled={syncContent === 'schema' || selectedTables.length === 0 || analyzing} style={{ marginRight: 8 }}>
</Button>
<Button
type="primary"
onClick={runSync}
loading={loading}
disabled={selectedTables.length === 0 || (syncContent !== 'schema' && diffTables.length === 0)}
>
</Button>
</>
)}
{currentStep === 2 && (
<>
<Button onClick={() => setCurrentStep(1)} style={{ marginRight: 8 }}></Button>
<Button type="primary" onClick={onClose}></Button>
<Button disabled={syncing} onClick={() => setCurrentStep(1)} style={{ marginRight: 8 }}></Button>
<Button type="primary" disabled={syncing} onClick={onClose}></Button>
</>
)}
</div>
</Modal>
<Drawer
title={`差异预览:${previewTable}`}
open={previewOpen}
onClose={() => { setPreviewOpen(false); setPreviewTable(''); setPreviewData(null); }}
width={900}
>
{previewLoading && <Alert type="info" showIcon message="正在加载差异预览..." />}
{!previewLoading && previewData && (
<div>
<Alert
type="info"
showIcon
message={`插入 ${previewData.totalInserts || 0},更新 ${previewData.totalUpdates || 0},删除 ${previewData.totalDeletes || 0}(预览最多展示 200 条/类型)`}
/>
<Divider />
<Tabs
items={[
{
key: 'insert',
label: `插入(${previewData.totalInserts || 0})`,
children: (
<div>
<Text type="secondary"></Text>
<Table
size="small"
style={{ marginTop: 8 }}
rowKey={(r: any) => r.pk}
dataSource={(previewData.inserts || []).map((r: any) => ({ ...r, key: r.pk }))}
pagination={false}
rowSelection={{
selectedRowKeys: (tableOptions[previewTable]?.selectedInsertPks || []) as any,
onChange: (keys) => updateTableOption(previewTable, 'selectedInsertPks', keys as string[]),
getCheckboxProps: () => ({ disabled: !tableOptions[previewTable]?.insert }),
}}
columns={[
{ title: previewData.pkColumn || '主键', dataIndex: 'pk', key: 'pk', width: 200, ellipsis: true },
{ title: '数据', dataIndex: 'row', key: 'row', render: (v: any) => <pre style={{ margin: 0, maxHeight: 140, overflow: 'auto' }}>{JSON.stringify(v, null, 2)}</pre> }
]}
/>
</div>
)
},
{
key: 'update',
label: `更新(${previewData.totalUpdates || 0})`,
children: (
<div>
<Text type="secondary"></Text>
<Table
size="small"
style={{ marginTop: 8 }}
rowKey={(r: any) => r.pk}
dataSource={(previewData.updates || []).map((r: any) => ({ ...r, key: r.pk }))}
pagination={false}
rowSelection={{
selectedRowKeys: (tableOptions[previewTable]?.selectedUpdatePks || []) as any,
onChange: (keys) => updateTableOption(previewTable, 'selectedUpdatePks', keys as string[]),
getCheckboxProps: () => ({ disabled: !tableOptions[previewTable]?.update }),
}}
columns={[
{ title: previewData.pkColumn || '主键', dataIndex: 'pk', key: 'pk', width: 200, ellipsis: true },
{ title: '变更字段', dataIndex: 'changedColumns', key: 'changedColumns', render: (v: any) => Array.isArray(v) ? v.join(', ') : '' },
{
title: '详情',
key: 'detail',
width: 80,
render: (_: any, r: any) => (
<Button size="small" onClick={() => {
Modal.info({
title: `更新详情:${previewTable} / ${r.pk}`,
width: 900,
content: (
<div style={{ display: 'flex', gap: 12 }}>
<div style={{ flex: 1 }}>
<Title level={5}></Title>
<pre style={{ maxHeight: 360, overflow: 'auto', background: '#f5f5f5', padding: 8 }}>{JSON.stringify(r.source, null, 2)}</pre>
</div>
<div style={{ flex: 1 }}>
<Title level={5}></Title>
<pre style={{ maxHeight: 360, overflow: 'auto', background: '#f5f5f5', padding: 8 }}>{JSON.stringify(r.target, null, 2)}</pre>
</div>
</div>
)
});
}}></Button>
)
}
]}
/>
</div>
)
},
{
key: 'delete',
label: `删除(${previewData.totalDeletes || 0})`,
children: (
<div>
<Alert type="warning" showIcon message="删除默认不勾选。请确认业务允许后再开启删除操作。" />
<Text type="secondary"></Text>
<Table
size="small"
style={{ marginTop: 8 }}
rowKey={(r: any) => r.pk}
dataSource={(previewData.deletes || []).map((r: any) => ({ ...r, key: r.pk }))}
pagination={false}
rowSelection={{
selectedRowKeys: (tableOptions[previewTable]?.selectedDeletePks || []) as any,
onChange: (keys) => updateTableOption(previewTable, 'selectedDeletePks', keys as string[]),
getCheckboxProps: () => ({ disabled: !tableOptions[previewTable]?.delete }),
}}
columns={[
{ title: previewData.pkColumn || '主键', dataIndex: 'pk', key: 'pk', width: 200, ellipsis: true },
{ title: '数据', dataIndex: 'row', key: 'row', render: (v: any) => <pre style={{ margin: 0, maxHeight: 140, overflow: 'auto' }}>{JSON.stringify(v, null, 2)}</pre> }
]}
/>
</div>
)
}
]}
/>
</div>
)}
</Drawer>
</>
);
};

View File

@@ -1,21 +1,29 @@
import React, { useEffect, useState, useCallback } from 'react';
import React, { useEffect, useState, useCallback, useRef } from 'react';
import { message } from 'antd';
import { TabData, ColumnDefinition } from '../types';
import { useStore } from '../store';
import { DBQuery, DBGetColumns } from '../../wailsjs/go/app/App';
import DataGrid from './DataGrid';
import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
import { buildWhereSQL, quoteIdentPart, quoteQualifiedIdent } from '../utils/sql';
const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
const [data, setData] = useState<any[]>([]);
const [columnNames, setColumnNames] = useState<string[]>([]);
const [pkColumns, setPkColumns] = useState<string[]>([]);
const [loading, setLoading] = useState(false);
const { connections, addSqlLog } = useStore();
const connections = useStore(state => state.connections);
const addSqlLog = useStore(state => state.addSqlLog);
const fetchSeqRef = useRef(0);
const countSeqRef = useRef(0);
const countKeyRef = useRef<string>('');
const pkSeqRef = useRef(0);
const pkKeyRef = useRef<string>('');
const [pagination, setPagination] = useState({
current: 1,
pageSize: 100,
total: 0
total: 0,
totalKnown: false
});
const [sortInfo, setSortInfo] = useState<{ columnKey: string, order: string } | null>(null);
@@ -23,12 +31,20 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
const [showFilter, setShowFilter] = useState(false);
const [filterConditions, setFilterConditions] = useState<any[]>([]);
useEffect(() => {
setPkColumns([]);
pkKeyRef.current = '';
countKeyRef.current = '';
setPagination(prev => ({ ...prev, current: 1, total: 0, totalKnown: false }));
}, [tab.connectionId, tab.dbName, tab.tableName]);
const fetchData = useCallback(async (page = pagination.current, size = pagination.pageSize) => {
const seq = ++fetchSeqRef.current;
setLoading(true);
const conn = connections.find(c => c.id === tab.connectionId);
if (!conn) {
message.error("Connection not found");
setLoading(false);
if (fetchSeqRef.current === seq) setLoading(false);
return;
}
@@ -41,61 +57,31 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
};
const quoteIdent = (ident: string) => {
if (!ident) return ident;
if (config.type === 'mysql') return `\`${ident.replace(/`/g, '``')}\``;
return `"${ident.replace(/"/g, '""')}"`;
};
const escapeLiteral = (val: string) => val.replace(/'/g, "''");
const dbType = config.type || '';
const dbName = tab.dbName || '';
const tableName = tab.tableName || '';
const whereParts: string[] = [];
filterConditions.forEach(cond => {
if (cond.column && cond.value) {
if (cond.op === 'LIKE') {
whereParts.push(`${quoteIdent(cond.column)} LIKE '%${escapeLiteral(cond.value)}%'`);
} else {
whereParts.push(`${quoteIdent(cond.column)} ${cond.op} '${escapeLiteral(cond.value)}'`);
}
}
});
const whereSQL = whereParts.length > 0 ? `WHERE ${whereParts.join(' AND ')}` : "";
const whereSQL = buildWhereSQL(dbType, filterConditions);
const countSql = `SELECT COUNT(*) as total FROM ${quoteIdent(tableName)} ${whereSQL}`;
const countSql = `SELECT COUNT(*) as total FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
let sql = `SELECT * FROM ${quoteIdent(tableName)} ${whereSQL}`;
let sql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
if (sortInfo && sortInfo.order) {
sql += ` ORDER BY ${quoteIdent(sortInfo.columnKey)} ${sortInfo.order === 'ascend' ? 'ASC' : 'DESC'}`;
sql += ` ORDER BY ${quoteIdentPart(dbType, sortInfo.columnKey)} ${sortInfo.order === 'ascend' ? 'ASC' : 'DESC'}`;
}
const offset = (page - 1) * size;
sql += ` LIMIT ${size} OFFSET ${offset}`;
// 大表性能:打开表不阻塞在 COUNT(*),先通过多取 1 条判断是否还有下一页;总数在后台统计并异步回填。
sql += ` LIMIT ${size + 1} OFFSET ${offset}`;
const startTime = Date.now();
try {
const pCount = DBQuery(config as any, dbName, countSql);
const pData = DBQuery(config as any, dbName, sql);
let pCols = null;
if (pkColumns.length === 0) {
pCols = DBGetColumns(config as any, dbName, tableName);
}
const [resCount, resData] = await Promise.all([pCount, pData]);
const resData = await pData;
const duration = Date.now() - startTime;
// Log Execution
addSqlLog({
id: `log-${Date.now()}-count`,
timestamp: Date.now(),
sql: countSql,
status: resCount.success ? 'success' : 'error',
duration: duration / 2, // Estimate
message: resCount.success ? '' : resCount.message,
dbName
});
addSqlLog({
id: `log-${Date.now()}-data`,
timestamp: Date.now(),
@@ -107,36 +93,104 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
dbName
});
if (pCols) {
const resCols = await pCols;
if (resCols.success) {
const pks = (resCols.data as ColumnDefinition[]).filter(c => c.key === 'PRI').map(c => c.name);
setPkColumns(pks);
if (pkColumns.length === 0) {
const pkKey = `${tab.connectionId}|${dbName}|${tableName}`;
if (pkKeyRef.current !== pkKey) {
pkKeyRef.current = pkKey;
const pkSeq = ++pkSeqRef.current;
DBGetColumns(config as any, dbName, tableName)
.then((resCols: any) => {
if (pkSeqRef.current !== pkSeq) return;
if (pkKeyRef.current !== pkKey) return;
if (!resCols?.success) return;
const pks = (resCols.data as ColumnDefinition[]).filter((c: any) => c.key === 'PRI').map((c: any) => c.name);
setPkColumns(pks);
})
.catch(() => {
if (pkSeqRef.current !== pkSeq) return;
if (pkKeyRef.current !== pkKey) return;
});
}
}
let totalRecords = 0;
if (resCount.success && Array.isArray(resCount.data) && resCount.data.length > 0) {
totalRecords = Number(resCount.data[0]['total']);
}
if (resData.success) {
let resultData = resData.data as any[];
if (!Array.isArray(resultData)) resultData = [];
const hasMore = resultData.length > size;
if (hasMore) resultData = resultData.slice(0, size);
let fieldNames = resData.fields || [];
if (fieldNames.length === 0 && resultData.length > 0) {
fieldNames = Object.keys(resultData[0]);
}
if (fetchSeqRef.current !== seq) return;
setColumnNames(fieldNames);
setData(resultData.map((row: any, i: number) => ({ ...row, key: `row-${i}` })));
setPagination(prev => ({ ...prev, current: page, pageSize: size, total: totalRecords }));
resultData.forEach((row: any, i: number) => {
if (row && typeof row === 'object') row[GONAVI_ROW_KEY] = `row-${offset + i}`;
});
setData(resultData);
const countKey = `${tab.connectionId}|${dbName}|${tableName}|${whereSQL}`;
const derivedTotalKnown = !hasMore;
const derivedTotal = derivedTotalKnown ? offset + resultData.length : page * size + 1;
if (derivedTotalKnown) countKeyRef.current = countKey;
setPagination(prev => {
if (derivedTotalKnown) {
return { ...prev, current: page, pageSize: size, total: derivedTotal, totalKnown: true };
}
if (prev.totalKnown && countKeyRef.current === countKey) {
return { ...prev, current: page, pageSize: size };
}
return { ...prev, current: page, pageSize: size, total: derivedTotal, totalKnown: false };
});
if (!derivedTotalKnown) {
if (countKeyRef.current !== countKey) {
countKeyRef.current = countKey;
const countSeq = ++countSeqRef.current;
const countStart = Date.now();
// 大表 COUNT(*) 可能非常慢,且在部分运行时环境下会影响后续操作响应;
// 这里为统计请求设置更短的超时,避免“后台统计”长期占用资源。
const countConfig: any = { ...(config as any), timeout: 5 };
DBQuery(countConfig, dbName, countSql)
.then((resCount: any) => {
const countDuration = Date.now() - countStart;
addSqlLog({
id: `log-${Date.now()}-count`,
timestamp: Date.now(),
sql: countSql,
status: resCount.success ? 'success' : 'error',
duration: countDuration,
message: resCount.success ? '' : resCount.message,
dbName
});
if (countSeqRef.current !== countSeq) return;
if (countKeyRef.current !== countKey) return;
if (!resCount.success) return;
if (!Array.isArray(resCount.data) || resCount.data.length === 0) return;
const total = Number(resCount.data[0]?.['total']);
if (!Number.isFinite(total) || total < 0) return;
setPagination(prev => ({ ...prev, total, totalKnown: true }));
})
.catch(() => {
if (countSeqRef.current !== countSeq) return;
if (countKeyRef.current !== countKey) return;
// 统计失败不影响主流程,不弹窗;可在日志里查看。
});
}
}
} else {
message.error(resData.message);
}
} catch (e: any) {
if (fetchSeqRef.current !== seq) return;
message.error("Error fetching data: " + e.message);
addSqlLog({
id: `log-${Date.now()}-error`,
@@ -148,7 +202,7 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
dbName
});
}
setLoading(false);
if (fetchSeqRef.current === seq) setLoading(false);
}, [connections, tab, sortInfo, filterConditions, pkColumns.length]);
// Depend on pkColumns.length to avoid loop? No, pkColumns is updated inside.
// Actually, 'pkColumns' state shouldn't trigger re-fetch.
@@ -158,7 +212,9 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
// So it's fine.
// Handlers memoized
const handleReload = useCallback(() => fetchData(), [fetchData]);
const handleReload = useCallback(() => {
fetchData(pagination.current, pagination.pageSize);
}, [fetchData, pagination.current, pagination.pageSize]);
const handleSort = useCallback((field: string, order: string) => setSortInfo({ columnKey: field, order }), []);
const handlePageChange = useCallback((page: number, size: number) => fetchData(page, size), [fetchData]);
const handleToggleFilter = useCallback(() => setShowFilter(prev => !prev), []);
@@ -169,7 +225,7 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
}, [tab, sortInfo, filterConditions]); // Initial load and re-load on sort/filter
return (
<div style={{ height: '100%', width: '100%', overflow: 'hidden' }}>
<div style={{ flex: '1 1 auto', minHeight: 0, height: '100%', width: '100%', overflow: 'hidden', display: 'flex', flexDirection: 'column' }}>
<DataGrid
data={data}
columnNames={columnNames}

View File

@@ -10,7 +10,9 @@ interface LogPanelProps {
}
const LogPanel: React.FC<LogPanelProps> = ({ height, onClose, onResizeStart }) => {
const { sqlLogs, clearSqlLogs, darkMode } = useStore();
const sqlLogs = useStore(state => state.sqlLogs);
const clearSqlLogs = useStore(state => state.clearSqlLogs);
const darkMode = useStore(state => state.darkMode);
const columns = [
{
@@ -111,4 +113,4 @@ const LogPanel: React.FC<LogPanelProps> = ({ height, onClose, onResizeStart }) =
);
};
export default LogPanel;
export default LogPanel;

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,205 @@
import React, { useState, useCallback, useRef } from 'react';
import { Button, Space, message } from 'antd';
import { PlayCircleOutlined, ClearOutlined } from '@ant-design/icons';
import { useStore } from '../store';
import Editor, { OnMount } from '@monaco-editor/react';
interface RedisCommandEditorProps {
connectionId: string;
redisDB: number;
}
interface CommandResult {
command: string;
result: any;
error?: string;
timestamp: number;
}
const RedisCommandEditor: React.FC<RedisCommandEditorProps> = ({ connectionId, redisDB }) => {
const { connections } = useStore();
const connection = connections.find(c => c.id === connectionId);
const [command, setCommand] = useState('');
const [results, setResults] = useState<CommandResult[]>([]);
const [loading, setLoading] = useState(false);
const editorRef = useRef<any>(null);
const getConfig = useCallback(() => {
if (!connection) return null;
return {
...connection.config,
port: Number(connection.config.port),
password: connection.config.password || "",
useSSH: connection.config.useSSH || false,
ssh: connection.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" },
redisDB: redisDB
};
}, [connection, redisDB]);
const handleEditorMount: OnMount = (editor) => {
editorRef.current = editor;
// Add keyboard shortcut for execute
editor.addCommand(
// Ctrl/Cmd + Enter
2048 | 3, // KeyMod.CtrlCmd | KeyCode.Enter
() => handleExecute()
);
};
const handleExecute = async () => {
const config = getConfig();
if (!config) return;
const cmdToExecute = command.trim();
if (!cmdToExecute) {
message.warning('请输入命令');
return;
}
// Support multiple commands separated by newlines
const commands = cmdToExecute.split('\n').filter(c => c.trim() && !c.trim().startsWith('//') && !c.trim().startsWith('#'));
setLoading(true);
const newResults: CommandResult[] = [];
for (const cmd of commands) {
const trimmedCmd = cmd.trim();
if (!trimmedCmd) continue;
try {
const res = await (window as any).go.app.App.RedisExecuteCommand(config, trimmedCmd);
newResults.push({
command: trimmedCmd,
result: res.success ? res.data : null,
error: res.success ? undefined : res.message,
timestamp: Date.now()
});
} catch (e: any) {
newResults.push({
command: trimmedCmd,
result: null,
error: e?.message || String(e),
timestamp: Date.now()
});
}
}
setResults(prev => [...newResults, ...prev]);
setLoading(false);
};
const handleClear = () => {
setResults([]);
};
const formatResult = (result: any): string => {
if (result === null || result === undefined) {
return '(nil)';
}
if (typeof result === 'string') {
return `"${result}"`;
}
if (typeof result === 'number') {
return `(integer) ${result}`;
}
if (Array.isArray(result)) {
if (result.length === 0) {
return '(empty array)';
}
return result.map((item, index) => `${index + 1}) ${formatResult(item)}`).join('\n');
}
if (typeof result === 'object') {
return JSON.stringify(result, null, 2);
}
return String(result);
};
if (!connection) {
return <div style={{ padding: 20 }}></div>;
}
return (
<div style={{ display: 'flex', flexDirection: 'column', height: '100%' }}>
{/* Command Input */}
<div style={{ borderBottom: '1px solid #f0f0f0' }}>
<div style={{ padding: '8px 12px', borderBottom: '1px solid #f0f0f0', display: 'flex', justifyContent: 'space-between', alignItems: 'center' }}>
<Space>
<span style={{ fontWeight: 500 }}>Redis </span>
<span style={{ color: '#999', fontSize: 12 }}>db{redisDB}</span>
</Space>
<Space>
<Button
type="primary"
icon={<PlayCircleOutlined />}
onClick={handleExecute}
loading={loading}
>
(Ctrl+Enter)
</Button>
<Button icon={<ClearOutlined />} onClick={handleClear}></Button>
</Space>
</div>
<Editor
height="150px"
defaultLanguage="plaintext"
value={command}
onChange={(value) => setCommand(value || '')}
onMount={handleEditorMount}
options={{
minimap: { enabled: false },
lineNumbers: 'on',
fontSize: 14,
wordWrap: 'on',
scrollBeyondLastLine: false,
automaticLayout: true,
tabSize: 2
}}
/>
</div>
{/* Results */}
<div style={{ flex: 1, overflow: 'auto', background: '#1e1e1e', color: '#d4d4d4', fontFamily: 'monospace' }}>
{results.length === 0 ? (
<div style={{ padding: 20, color: '#666', textAlign: 'center' }}>
Redis Ctrl+Enter
<br />
<span style={{ fontSize: 12 }}></span>
</div>
) : (
results.map((item, index) => (
<div key={item.timestamp + index} style={{ padding: '8px 12px', borderBottom: '1px solid #333' }}>
<div style={{ color: '#569cd6', marginBottom: 4 }}>
&gt; {item.command}
</div>
{item.error ? (
<div style={{ color: '#f14c4c', whiteSpace: 'pre-wrap' }}>
(error) {item.error}
</div>
) : (
<div style={{ color: '#ce9178', whiteSpace: 'pre-wrap' }}>
{formatResult(item.result)}
</div>
)}
</div>
))
)}
</div>
{/* Common Commands Help */}
<div style={{ padding: '8px 12px', borderTop: '1px solid #f0f0f0', background: '#fafafa', fontSize: 12, color: '#666' }}>
:
<span style={{ marginLeft: 8 }}>
<code>KEYS *</code> |
<code style={{ marginLeft: 8 }}>GET key</code> |
<code style={{ marginLeft: 8 }}>SET key value</code> |
<code style={{ marginLeft: 8 }}>HGETALL key</code> |
<code style={{ marginLeft: 8 }}>INFO</code> |
<code style={{ marginLeft: 8 }}>DBSIZE</code>
</span>
</div>
</div>
);
};
export default RedisCommandEditor;

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,12 +1,22 @@
import React, { useMemo } from 'react';
import { Tabs, Button } from 'antd';
import { Tabs, Dropdown } from 'antd';
import type { MenuProps } from 'antd';
import { useStore } from '../store';
import DataViewer from './DataViewer';
import QueryEditor from './QueryEditor';
import TableDesigner from './TableDesigner';
import RedisViewer from './RedisViewer';
import RedisCommandEditor from './RedisCommandEditor';
const TabManager: React.FC = () => {
const { tabs, activeTabId, setActiveTab, closeTab } = useStore();
const tabs = useStore(state => state.tabs);
const activeTabId = useStore(state => state.activeTabId);
const setActiveTab = useStore(state => state.setActiveTab);
const closeTab = useStore(state => state.closeTab);
const closeOtherTabs = useStore(state => state.closeOtherTabs);
const closeTabsToLeft = useStore(state => state.closeTabsToLeft);
const closeTabsToRight = useStore(state => state.closeTabsToRight);
const closeAllTabs = useStore(state => state.closeAllTabs);
const onChange = (newActiveKey: string) => {
setActiveTab(newActiveKey);
@@ -18,7 +28,7 @@ const TabManager: React.FC = () => {
}
};
const items = useMemo(() => tabs.map(tab => {
const items = useMemo(() => tabs.map((tab, index) => {
let content;
if (tab.type === 'query') {
content = <QueryEditor tab={tab} />;
@@ -26,28 +36,100 @@ const TabManager: React.FC = () => {
content = <DataViewer tab={tab} />;
} else if (tab.type === 'design') {
content = <TableDesigner tab={tab} />;
} else if (tab.type === 'redis-keys') {
content = <RedisViewer connectionId={tab.connectionId} redisDB={tab.redisDB ?? 0} />;
} else if (tab.type === 'redis-command') {
content = <RedisCommandEditor connectionId={tab.connectionId} redisDB={tab.redisDB ?? 0} />;
}
const menuItems: MenuProps['items'] = [
{
key: 'close-other',
label: '关闭其他页',
disabled: tabs.length <= 1,
onClick: () => closeOtherTabs(tab.id),
},
{
key: 'close-left',
label: '关闭左侧',
disabled: index === 0,
onClick: () => closeTabsToLeft(tab.id),
},
{
key: 'close-right',
label: '关闭右侧',
disabled: index === tabs.length - 1,
onClick: () => closeTabsToRight(tab.id),
},
{ type: 'divider' },
{
key: 'close-all',
label: '关闭所有',
disabled: tabs.length === 0,
onClick: () => closeAllTabs(),
},
];
return {
label: tab.title,
label: (
<Dropdown menu={{ items: menuItems }} trigger={['contextMenu']}>
<span onContextMenu={(e) => e.preventDefault()}>{tab.title}</span>
</Dropdown>
),
key: tab.id,
children: content,
};
}), [tabs]);
}), [tabs, closeOtherTabs, closeTabsToLeft, closeTabsToRight, closeAllTabs]);
return (
<>
<style>{`
.ant-tabs-content { height: 100%; }
.ant-tabs-tabpane { height: 100%; }
.main-tabs {
height: 100%;
flex: 1 1 auto;
min-height: 0;
display: flex;
flex-direction: column;
overflow: hidden;
}
.main-tabs .ant-tabs-nav {
flex: 0 0 auto;
}
.main-tabs .ant-tabs-content-holder {
flex: 1 1 auto;
min-height: 0;
overflow: hidden;
display: flex;
flex-direction: column;
}
.main-tabs .ant-tabs-content {
flex: 1 1 auto;
min-height: 0;
display: flex;
flex-direction: column;
}
.main-tabs .ant-tabs-tabpane {
flex: 1 1 auto;
min-height: 0;
display: flex;
flex-direction: column;
overflow: hidden;
}
.main-tabs .ant-tabs-tabpane > div {
flex: 1 1 auto;
min-height: 0;
}
.main-tabs .ant-tabs-tabpane-hidden {
display: none !important;
}
`}</style>
<Tabs
className="main-tabs"
type="editable-card"
onChange={onChange}
activeKey={activeTabId || undefined}
onEdit={onEdit}
items={items}
style={{ height: '100%' }}
hideAdd
/>
</>

View File

@@ -550,7 +550,6 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
<div ref={containerRef} className="table-designer-wrapper" style={{ height: '100%', overflow: 'hidden', position: 'relative' }}>
<style>{`
.table-designer-wrapper .ant-table-body {
height: ${tableHeight}px !important;
max-height: ${tableHeight}px !important;
}
`}</style>

View File

@@ -21,6 +21,7 @@ interface AppState {
savedQueries: SavedQuery[];
darkMode: boolean;
sqlFormatOptions: { keywordCase: 'upper' | 'lower' };
queryOptions: { maxRows: number };
sqlLogs: SqlLog[];
addConnection: (conn: SavedConnection) => void;
@@ -29,6 +30,10 @@ interface AppState {
addTab: (tab: TabData) => void;
closeTab: (id: string) => void;
closeOtherTabs: (id: string) => void;
closeTabsToLeft: (id: string) => void;
closeTabsToRight: (id: string) => void;
closeAllTabs: () => void;
setActiveTab: (id: string) => void;
setActiveContext: (context: { connectionId: string; dbName: string } | null) => void;
@@ -37,6 +42,7 @@ interface AppState {
toggleDarkMode: () => void;
setSqlFormatOptions: (options: { keywordCase: 'upper' | 'lower' }) => void;
setQueryOptions: (options: Partial<{ maxRows: number }>) => void;
addSqlLog: (log: SqlLog) => void;
clearSqlLogs: () => void;
@@ -52,6 +58,7 @@ export const useStore = create<AppState>()(
savedQueries: [],
darkMode: false,
sqlFormatOptions: { keywordCase: 'upper' },
queryOptions: { maxRows: 5000 },
sqlLogs: [],
addConnection: (conn) => set((state) => ({ connections: [...state.connections, conn] })),
@@ -79,6 +86,30 @@ export const useStore = create<AppState>()(
}
return { tabs: newTabs, activeTabId: newActiveId };
}),
closeOtherTabs: (id) => set((state) => {
const keep = state.tabs.find(t => t.id === id);
if (!keep) return state;
return { tabs: [keep], activeTabId: id };
}),
closeTabsToLeft: (id) => set((state) => {
const index = state.tabs.findIndex(t => t.id === id);
if (index === -1) return state;
const newTabs = state.tabs.slice(index);
const activeStillExists = state.activeTabId ? newTabs.some(t => t.id === state.activeTabId) : false;
return { tabs: newTabs, activeTabId: activeStillExists ? state.activeTabId : id };
}),
closeTabsToRight: (id) => set((state) => {
const index = state.tabs.findIndex(t => t.id === id);
if (index === -1) return state;
const newTabs = state.tabs.slice(0, index + 1);
const activeStillExists = state.activeTabId ? newTabs.some(t => t.id === state.activeTabId) : false;
return { tabs: newTabs, activeTabId: activeStillExists ? state.activeTabId : id };
}),
closeAllTabs: () => set(() => ({ tabs: [], activeTabId: null })),
setActiveTab: (id) => set({ activeTabId: id }),
setActiveContext: (context) => set({ activeContext: context }),
@@ -96,13 +127,14 @@ export const useStore = create<AppState>()(
toggleDarkMode: () => set((state) => ({ darkMode: !state.darkMode })),
setSqlFormatOptions: (options) => set({ sqlFormatOptions: options }),
setQueryOptions: (options) => set((state) => ({ queryOptions: { ...state.queryOptions, ...options } })),
addSqlLog: (log) => set((state) => ({ sqlLogs: [log, ...state.sqlLogs].slice(0, 1000) })), // Keep last 1000 logs
clearSqlLogs: () => set({ sqlLogs: [] }),
}),
{
name: 'lite-db-storage', // name of the item in the storage (must be unique)
partialize: (state) => ({ connections: state.connections, savedQueries: state.savedQueries, darkMode: state.darkMode, sqlFormatOptions: state.sqlFormatOptions }), // Don't persist logs
partialize: (state) => ({ connections: state.connections, savedQueries: state.savedQueries, darkMode: state.darkMode, sqlFormatOptions: state.sqlFormatOptions, queryOptions: state.queryOptions }), // Don't persist logs
}
)
);
);

View File

@@ -15,6 +15,7 @@ export interface ConnectionConfig {
database?: string;
useSSH?: boolean;
ssh?: SSHConfig;
redisDB?: number; // Redis database index (0-15)
}
export interface SavedConnection {
@@ -22,6 +23,7 @@ export interface SavedConnection {
name: string;
config: ConnectionConfig;
includeDatabases?: string[];
includeRedisDatabases?: number[]; // Redis databases to show (0-15)
}
export interface ColumnDefinition {
@@ -60,13 +62,14 @@ export interface TriggerDefinition {
export interface TabData {
id: string;
title: string;
type: 'query' | 'table' | 'design';
type: 'query' | 'table' | 'design' | 'redis-keys' | 'redis-command';
connectionId: string;
dbName?: string;
tableName?: string;
query?: string;
initialTab?: string;
readOnly?: boolean;
redisDB?: number; // Redis database index for redis tabs
}
export interface DatabaseNode {
@@ -85,3 +88,32 @@ export interface SavedQuery {
dbName: string;
createdAt: number;
}
// Redis types
export interface RedisKeyInfo {
key: string;
type: string;
ttl: number;
}
export interface RedisScanResult {
keys: RedisKeyInfo[];
cursor: number;
}
export interface RedisValue {
type: 'string' | 'hash' | 'list' | 'set' | 'zset';
ttl: number;
value: any;
length: number;
}
export interface RedisDBInfo {
index: number;
keys: number;
}
export interface ZSetMember {
member: string;
score: number;
}

198
frontend/src/utils/sql.ts Normal file
View File

@@ -0,0 +1,198 @@
export type FilterCondition = {
id?: number;
column?: string;
op?: string;
value?: string;
value2?: string;
};
const normalizeIdentPart = (ident: string) => {
let raw = (ident || '').trim();
if (!raw) return raw;
const first = raw[0];
const last = raw[raw.length - 1];
if ((first === '"' && last === '"') || (first === '`' && last === '`')) {
raw = raw.slice(1, -1).trim();
}
raw = raw.replace(/["`]/g, '').trim();
return raw;
};
// 检查标识符是否需要引号(包含特殊字符或是保留字)
const needsQuote = (ident: string): boolean => {
if (!ident) return false;
// 如果包含特殊字符(非字母、数字、下划线)则需要引号
if (!/^[a-zA-Z_][a-zA-Z0-9_]*$/.test(ident)) return true;
// 常见 SQL 保留字列表(简化版)
const reserved = ['select', 'from', 'where', 'table', 'index', 'user', 'order', 'group', 'by', 'limit', 'offset', 'and', 'or', 'not', 'null', 'true', 'false', 'key', 'primary', 'foreign', 'references', 'default', 'constraint', 'create', 'drop', 'alter', 'insert', 'update', 'delete', 'set', 'values', 'into', 'join', 'left', 'right', 'inner', 'outer', 'on', 'as', 'is', 'in', 'like', 'between', 'case', 'when', 'then', 'else', 'end', 'having', 'distinct', 'all', 'any', 'exists', 'union', 'except', 'intersect'];
return reserved.includes(ident.toLowerCase());
};
export const quoteIdentPart = (dbType: string, ident: string) => {
const raw = normalizeIdentPart(ident);
if (!raw) return raw;
const dbTypeLower = (dbType || '').toLowerCase();
if (dbTypeLower === 'mysql') {
return `\`${raw.replace(/`/g, '``')}\``;
}
// 对于 KingBase/PostgreSQL只在必要时加引号
if (dbTypeLower === 'kingbase' || dbTypeLower === 'postgres') {
if (needsQuote(raw)) {
return `"${raw.replace(/"/g, '""')}"`;
}
// 不加引号,保持原样(数据库会自动转小写处理)
return raw;
}
// 其他数据库默认加双引号
return `"${raw.replace(/"/g, '""')}"`;
};
export const quoteQualifiedIdent = (dbType: string, ident: string) => {
const raw = (ident || '').trim();
if (!raw) return raw;
const parts = raw.split('.').map(normalizeIdentPart).filter(Boolean);
if (parts.length <= 1) return quoteIdentPart(dbType, raw);
return parts.map(p => quoteIdentPart(dbType, p)).join('.');
};
export const escapeLiteral = (val: string) => (val || '').replace(/'/g, "''");
export const parseListValues = (val: string) => {
const raw = (val || '').trim();
if (!raw) return [];
return raw
.split(/[\n,]+/)
.map(s => s.trim())
.filter(Boolean);
};
export const buildWhereSQL = (dbType: string, conditions: FilterCondition[]) => {
const whereParts: string[] = [];
(conditions || []).forEach((cond) => {
const op = (cond?.op || '').trim();
const column = (cond?.column || '').trim();
const value = (cond?.value ?? '').toString();
const value2 = (cond?.value2 ?? '').toString();
if (op === 'CUSTOM') {
const expr = value.trim();
if (expr) whereParts.push(`(${expr})`);
return;
}
if (!column) return;
const col = quoteIdentPart(dbType, column);
switch (op) {
case 'IS_NULL':
whereParts.push(`${col} IS NULL`);
return;
case 'IS_NOT_NULL':
whereParts.push(`${col} IS NOT NULL`);
return;
case 'IS_EMPTY':
// 兼容:空值通常理解为 NULL 或空字符串
whereParts.push(`(${col} IS NULL OR ${col} = '')`);
return;
case 'IS_NOT_EMPTY':
whereParts.push(`(${col} IS NOT NULL AND ${col} <> '')`);
return;
case 'BETWEEN': {
const v1 = value.trim();
const v2 = value2.trim();
if (!v1 || !v2) return;
whereParts.push(`${col} BETWEEN '${escapeLiteral(v1)}' AND '${escapeLiteral(v2)}'`);
return;
}
case 'NOT_BETWEEN': {
const v1 = value.trim();
const v2 = value2.trim();
if (!v1 || !v2) return;
whereParts.push(`${col} NOT BETWEEN '${escapeLiteral(v1)}' AND '${escapeLiteral(v2)}'`);
return;
}
case 'IN': {
const items = parseListValues(value);
if (items.length === 0) return;
const list = items.map(v => `'${escapeLiteral(v)}'`).join(', ');
whereParts.push(`${col} IN (${list})`);
return;
}
case 'NOT_IN': {
const items = parseListValues(value);
if (items.length === 0) return;
const list = items.map(v => `'${escapeLiteral(v)}'`).join(', ');
whereParts.push(`${col} NOT IN (${list})`);
return;
}
case 'CONTAINS': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} LIKE '%${escapeLiteral(v)}%'`);
return;
}
case 'NOT_CONTAINS': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} NOT LIKE '%${escapeLiteral(v)}%'`);
return;
}
case 'STARTS_WITH': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} LIKE '${escapeLiteral(v)}%'`);
return;
}
case 'NOT_STARTS_WITH': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} NOT LIKE '${escapeLiteral(v)}%'`);
return;
}
case 'ENDS_WITH': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} LIKE '%${escapeLiteral(v)}'`);
return;
}
case 'NOT_ENDS_WITH': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} NOT LIKE '%${escapeLiteral(v)}'`);
return;
}
case '=':
case '!=':
case '<':
case '<=':
case '>':
case '>=': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} ${op} '${escapeLiteral(v)}'`);
return;
}
default: {
// 兼容旧值LIKE
if (op.toUpperCase() === 'LIKE') {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} LIKE '%${escapeLiteral(v)}%'`);
return;
}
const v = value.trim();
if (!v) return;
whereParts.push(`${col} ${op} '${escapeLiteral(v)}'`);
}
}
});
return whereParts.length > 0 ? `WHERE ${whereParts.join(' AND ')}` : '';
};

View File

@@ -2,6 +2,7 @@
// This file is automatically generated. DO NOT EDIT
import {connection} from '../models';
import {sync} from '../models';
import {redis} from '../models';
export function ApplyChanges(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:connection.ChangeSet):Promise<connection.QueryResult>;
@@ -29,10 +30,20 @@ export function DBShowCreateTable(arg1:connection.ConnectionConfig,arg2:string,a
export function DataSync(arg1:sync.SyncConfig):Promise<sync.SyncResult>;
export function DataSyncAnalyze(arg1:sync.SyncConfig):Promise<connection.QueryResult>;
export function DataSyncPreview(arg1:sync.SyncConfig,arg2:string,arg3:number):Promise<connection.QueryResult>;
export function ExportData(arg1:Array<Record<string, any>>,arg2:Array<string>,arg3:string,arg4:string):Promise<connection.QueryResult>;
export function ExportDatabaseSQL(arg1:connection.ConnectionConfig,arg2:string,arg3:boolean):Promise<connection.QueryResult>;
export function ExportQuery(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string,arg5:string):Promise<connection.QueryResult>;
export function ExportTable(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
export function ExportTablesSQL(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>,arg4:boolean):Promise<connection.QueryResult>;
export function ImportConfigFile():Promise<connection.QueryResult>;
export function ImportData(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise<connection.QueryResult>;
@@ -49,4 +60,46 @@ export function MySQLShowCreateTable(arg1:connection.ConnectionConfig,arg2:strin
export function OpenSQLFile():Promise<connection.QueryResult>;
export function RedisConnect(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;
export function RedisDeleteHashField(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>):Promise<connection.QueryResult>;
export function RedisDeleteKeys(arg1:connection.ConnectionConfig,arg2:Array<string>):Promise<connection.QueryResult>;
export function RedisExecuteCommand(arg1:connection.ConnectionConfig,arg2:string):Promise<connection.QueryResult>;
export function RedisFlushDB(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;
export function RedisGetDatabases(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;
export function RedisGetServerInfo(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;
export function RedisGetValue(arg1:connection.ConnectionConfig,arg2:string):Promise<connection.QueryResult>;
export function RedisListPush(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>):Promise<connection.QueryResult>;
export function RedisListSet(arg1:connection.ConnectionConfig,arg2:string,arg3:number,arg4:string):Promise<connection.QueryResult>;
export function RedisRenameKey(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise<connection.QueryResult>;
export function RedisScanKeys(arg1:connection.ConnectionConfig,arg2:string,arg3:number,arg4:number):Promise<connection.QueryResult>;
export function RedisSelectDB(arg1:connection.ConnectionConfig,arg2:number):Promise<connection.QueryResult>;
export function RedisSetAdd(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>):Promise<connection.QueryResult>;
export function RedisSetHashField(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
export function RedisSetRemove(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>):Promise<connection.QueryResult>;
export function RedisSetString(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:number):Promise<connection.QueryResult>;
export function RedisSetTTL(arg1:connection.ConnectionConfig,arg2:string,arg3:number):Promise<connection.QueryResult>;
export function RedisTestConnection(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;
export function RedisZSetAdd(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<redis.ZSetMember>):Promise<connection.QueryResult>;
export function RedisZSetRemove(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>):Promise<connection.QueryResult>;
export function TestConnection(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;

View File

@@ -54,14 +54,34 @@ export function DataSync(arg1) {
return window['go']['app']['App']['DataSync'](arg1);
}
export function DataSyncAnalyze(arg1) {
return window['go']['app']['App']['DataSyncAnalyze'](arg1);
}
export function DataSyncPreview(arg1, arg2, arg3) {
return window['go']['app']['App']['DataSyncPreview'](arg1, arg2, arg3);
}
export function ExportData(arg1, arg2, arg3, arg4) {
return window['go']['app']['App']['ExportData'](arg1, arg2, arg3, arg4);
}
export function ExportDatabaseSQL(arg1, arg2, arg3) {
return window['go']['app']['App']['ExportDatabaseSQL'](arg1, arg2, arg3);
}
export function ExportQuery(arg1, arg2, arg3, arg4, arg5) {
return window['go']['app']['App']['ExportQuery'](arg1, arg2, arg3, arg4, arg5);
}
export function ExportTable(arg1, arg2, arg3, arg4) {
return window['go']['app']['App']['ExportTable'](arg1, arg2, arg3, arg4);
}
export function ExportTablesSQL(arg1, arg2, arg3, arg4) {
return window['go']['app']['App']['ExportTablesSQL'](arg1, arg2, arg3, arg4);
}
export function ImportConfigFile() {
return window['go']['app']['App']['ImportConfigFile']();
}
@@ -94,6 +114,90 @@ export function OpenSQLFile() {
return window['go']['app']['App']['OpenSQLFile']();
}
export function RedisConnect(arg1) {
return window['go']['app']['App']['RedisConnect'](arg1);
}
export function RedisDeleteHashField(arg1, arg2, arg3) {
return window['go']['app']['App']['RedisDeleteHashField'](arg1, arg2, arg3);
}
export function RedisDeleteKeys(arg1, arg2) {
return window['go']['app']['App']['RedisDeleteKeys'](arg1, arg2);
}
export function RedisExecuteCommand(arg1, arg2) {
return window['go']['app']['App']['RedisExecuteCommand'](arg1, arg2);
}
export function RedisFlushDB(arg1) {
return window['go']['app']['App']['RedisFlushDB'](arg1);
}
export function RedisGetDatabases(arg1) {
return window['go']['app']['App']['RedisGetDatabases'](arg1);
}
export function RedisGetServerInfo(arg1) {
return window['go']['app']['App']['RedisGetServerInfo'](arg1);
}
export function RedisGetValue(arg1, arg2) {
return window['go']['app']['App']['RedisGetValue'](arg1, arg2);
}
export function RedisListPush(arg1, arg2, arg3) {
return window['go']['app']['App']['RedisListPush'](arg1, arg2, arg3);
}
export function RedisListSet(arg1, arg2, arg3, arg4) {
return window['go']['app']['App']['RedisListSet'](arg1, arg2, arg3, arg4);
}
export function RedisRenameKey(arg1, arg2, arg3) {
return window['go']['app']['App']['RedisRenameKey'](arg1, arg2, arg3);
}
export function RedisScanKeys(arg1, arg2, arg3, arg4) {
return window['go']['app']['App']['RedisScanKeys'](arg1, arg2, arg3, arg4);
}
export function RedisSelectDB(arg1, arg2) {
return window['go']['app']['App']['RedisSelectDB'](arg1, arg2);
}
export function RedisSetAdd(arg1, arg2, arg3) {
return window['go']['app']['App']['RedisSetAdd'](arg1, arg2, arg3);
}
export function RedisSetHashField(arg1, arg2, arg3, arg4) {
return window['go']['app']['App']['RedisSetHashField'](arg1, arg2, arg3, arg4);
}
export function RedisSetRemove(arg1, arg2, arg3) {
return window['go']['app']['App']['RedisSetRemove'](arg1, arg2, arg3);
}
export function RedisSetString(arg1, arg2, arg3, arg4) {
return window['go']['app']['App']['RedisSetString'](arg1, arg2, arg3, arg4);
}
export function RedisSetTTL(arg1, arg2, arg3) {
return window['go']['app']['App']['RedisSetTTL'](arg1, arg2, arg3);
}
export function RedisTestConnection(arg1) {
return window['go']['app']['App']['RedisTestConnection'](arg1);
}
export function RedisZSetAdd(arg1, arg2, arg3) {
return window['go']['app']['App']['RedisZSetAdd'](arg1, arg2, arg3);
}
export function RedisZSetRemove(arg1, arg2, arg3) {
return window['go']['app']['App']['RedisZSetRemove'](arg1, arg2, arg3);
}
export function TestConnection(arg1) {
return window['go']['app']['App']['TestConnection'](arg1);
}

View File

@@ -80,6 +80,7 @@ export namespace connection {
driver?: string;
dsn?: string;
timeout?: number;
redisDB?: number;
static createFrom(source: any = {}) {
return new ConnectionConfig(source);
@@ -98,6 +99,7 @@ export namespace connection {
this.driver = source["driver"];
this.dsn = source["dsn"];
this.timeout = source["timeout"];
this.redisDB = source["redisDB"];
}
convertValues(a: any, classs: any, asMap: boolean = false): any {
@@ -140,13 +142,58 @@ export namespace connection {
}
export namespace redis {
export class ZSetMember {
member: string;
score: number;
static createFrom(source: any = {}) {
return new ZSetMember(source);
}
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.member = source["member"];
this.score = source["score"];
}
}
}
export namespace sync {
export class TableOptions {
insert?: boolean;
update?: boolean;
delete?: boolean;
selectedInsertPks?: string[];
selectedUpdatePks?: string[];
selectedDeletePks?: string[];
static createFrom(source: any = {}) {
return new TableOptions(source);
}
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.insert = source["insert"];
this.update = source["update"];
this.delete = source["delete"];
this.selectedInsertPks = source["selectedInsertPks"];
this.selectedUpdatePks = source["selectedUpdatePks"];
this.selectedDeletePks = source["selectedDeletePks"];
}
}
export class SyncConfig {
sourceConfig: connection.ConnectionConfig;
targetConfig: connection.ConnectionConfig;
tables: string[];
content?: string;
mode: string;
jobId?: string;
autoAddColumns?: boolean;
tableOptions?: Record<string, TableOptions>;
static createFrom(source: any = {}) {
return new SyncConfig(source);
@@ -157,7 +204,11 @@ export namespace sync {
this.sourceConfig = this.convertValues(source["sourceConfig"], connection.ConnectionConfig);
this.targetConfig = this.convertValues(source["targetConfig"], connection.ConnectionConfig);
this.tables = source["tables"];
this.content = source["content"];
this.mode = source["mode"];
this.jobId = source["jobId"];
this.autoAddColumns = source["autoAddColumns"];
this.tableOptions = this.convertValues(source["tableOptions"], TableOptions, true);
}
convertValues(a: any, classs: any, asMap: boolean = false): any {

3
go.mod
View File

@@ -7,6 +7,7 @@ require (
gitee.com/chunanyong/dm v1.8.22
github.com/go-sql-driver/mysql v1.9.3
github.com/lib/pq v1.11.1
github.com/redis/go-redis/v9 v9.17.3
github.com/sijms/go-ora/v2 v2.9.0
github.com/wailsapp/wails/v2 v2.11.0
golang.org/x/crypto v0.47.0
@@ -16,6 +17,8 @@ require (
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/bep/debounce v1.2.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
github.com/godbus/dbus/v5 v5.1.0 // indirect

10
go.sum
View File

@@ -6,8 +6,16 @@ gitee.com/chunanyong/dm v1.8.22 h1:H7fsrnUIvEA0jlDWew7vwELry1ff+tLMIu2Fk2cIBSg=
gitee.com/chunanyong/dm v1.8.22/go.mod h1:EPRJnuPFgbyOFgJ0TRYCTGzhq+ZT4wdyaj/GW/LLcNg=
github.com/bep/debounce v1.2.1 h1:v67fRdBA9UQu2NhLFXrSg0Brw7CexQekrBwDMM8bzeY=
github.com/bep/debounce v1.2.1/go.mod h1:H8yggRPQKLUhUoqrJC1bO2xNya7vanpDl7xR3ISbCJ0=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
@@ -61,6 +69,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.17.3 h1:fN29NdNrE17KttK5Ndf20buqfDZwGNgoUr9qjl1DQx4=
github.com/redis/go-redis/v9 v9.17.3/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=

View File

@@ -10,23 +10,31 @@ import (
"net"
"strings"
"sync"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/db"
"GoNavi-Wails/internal/logger"
)
const dbCachePingInterval = 30 * time.Second
type cachedDatabase struct {
inst db.Database
lastPing time.Time
}
// App struct
type App struct {
ctx context.Context
dbCache map[string]db.Database // Cache for DB connections
mu sync.Mutex // Mutex for cache access
dbCache map[string]cachedDatabase // Cache for DB connections
mu sync.RWMutex // Mutex for cache access
}
// NewApp creates a new App application struct
func NewApp() *App {
return &App{
dbCache: make(map[string]db.Database),
dbCache: make(map[string]cachedDatabase),
}
}
@@ -44,10 +52,12 @@ func (a *App) Shutdown(ctx context.Context) {
a.mu.Lock()
defer a.mu.Unlock()
for _, dbInst := range a.dbCache {
if err := dbInst.Close(); err != nil {
if err := dbInst.inst.Close(); err != nil {
logger.Error(err, "关闭数据库连接失败")
}
}
// Close all Redis connections
CloseAllRedisClients()
logger.Infof("资源释放完成,应用已关闭")
logger.Close()
}
@@ -134,35 +144,63 @@ func formatConnSummary(config connection.ConnectionConfig) string {
return b.String()
}
func (a *App) getDatabaseForcePing(config connection.ConnectionConfig) (db.Database, error) {
return a.getDatabaseWithPing(config, true)
}
// Helper: Get or create a database connection
func (a *App) getDatabase(config connection.ConnectionConfig) (db.Database, error) {
return a.getDatabaseWithPing(config, false)
}
func (a *App) getDatabaseWithPing(config connection.ConnectionConfig, forcePing bool) (db.Database, error) {
key := getCacheKey(config)
shortKey := key
if len(shortKey) > 12 {
shortKey = shortKey[:12]
}
if config.UseSSH && config.Type != "mysql" {
logger.Warnf("当前仅 MySQL 支持内置 SSH 直连,其他类型请使用本地端口转发:%s", formatConnSummary(config))
}
logger.Infof("获取数据库连接:%s 缓存Key=%s", formatConnSummary(config), shortKey)
a.mu.Lock()
defer a.mu.Unlock()
a.mu.RLock()
entry, ok := a.dbCache[key]
a.mu.RUnlock()
if ok {
needPing := forcePing
if !needPing {
lastPing := entry.lastPing
if lastPing.IsZero() || time.Since(lastPing) >= dbCachePingInterval {
needPing = true
}
}
if dbInst, ok := a.dbCache[key]; ok {
logger.Infof("命中连接缓存开始检测可用性缓存Key=%s", shortKey)
if err := dbInst.Ping(); err == nil {
logger.Infof("缓存连接可用缓存Key=%s", shortKey)
return dbInst, nil
if !needPing {
return entry.inst, nil
}
if err := entry.inst.Ping(); err == nil {
// Update lastPing (best effort)
a.mu.Lock()
if cur, exists := a.dbCache[key]; exists && cur.inst == entry.inst {
cur.lastPing = time.Now()
a.dbCache[key] = cur
}
a.mu.Unlock()
return entry.inst, nil
} else {
logger.Error(err, "缓存连接不可用准备重建缓存Key=%s", shortKey)
logger.Error(err, "缓存连接不可用,准备重建:%s 缓存Key=%s", formatConnSummary(config), shortKey)
}
if err := dbInst.Close(); err != nil {
logger.Error(err, "关闭失效缓存连接失败缓存Key=%s", shortKey)
// Ping failed: remove cached instance (best effort)
a.mu.Lock()
if cur, exists := a.dbCache[key]; exists && cur.inst == entry.inst {
if err := cur.inst.Close(); err != nil {
logger.Error(err, "关闭失效缓存连接失败缓存Key=%s", shortKey)
}
delete(a.dbCache, key)
}
delete(a.dbCache, key)
a.mu.Unlock()
}
logger.Infof("获取数据库连接:%s 缓存Key=%s", formatConnSummary(config), shortKey)
logger.Infof("创建数据库驱动实例:类型=%s 缓存Key=%s", config.Type, shortKey)
dbInst, err := db.NewDatabase(config.Type)
if err != nil {
@@ -176,7 +214,18 @@ func (a *App) getDatabase(config connection.ConnectionConfig) (db.Database, erro
return nil, wrapped
}
a.dbCache[key] = dbInst
now := time.Now()
a.mu.Lock()
if existing, exists := a.dbCache[key]; exists && existing.inst != nil {
a.mu.Unlock()
// Prefer existing cached connection to avoid cache racing duplicates.
_ = dbInst.Close()
return existing.inst, nil
}
a.dbCache[key] = cachedDatabase{inst: dbInst, lastPing: now}
a.mu.Unlock()
logger.Infof("数据库连接成功并写入缓存:%s 缓存Key=%s", formatConnSummary(config), shortKey)
return dbInst, nil
}

View File

@@ -0,0 +1,56 @@
package app
import (
"strings"
"GoNavi-Wails/internal/connection"
)
func normalizeRunConfig(config connection.ConnectionConfig, dbName string) connection.ConnectionConfig {
runConfig := config
name := strings.TrimSpace(dbName)
if name == "" {
return runConfig
}
switch strings.ToLower(strings.TrimSpace(config.Type)) {
case "mysql", "postgres", "kingbase":
// 这些类型的 dbName 表示“数据库”,需要写入连接配置以选择目标库。
runConfig.Database = name
case "dameng":
// 达梦使用 schema 参数沿用现有行为dbName 表示 schema。
runConfig.Database = name
default:
// oracle: dbName 表示 schema/owner不能覆盖 config.Database服务名
// sqlite: 无需设置 Database
// custom: 语义不明确,避免污染缓存 key
}
return runConfig
}
func normalizeSchemaAndTable(config connection.ConnectionConfig, dbName string, tableName string) (string, string) {
rawTable := strings.TrimSpace(tableName)
rawDB := strings.TrimSpace(dbName)
if rawTable == "" {
return rawDB, rawTable
}
if parts := strings.SplitN(rawTable, ".", 2); len(parts) == 2 {
schema := strings.TrimSpace(parts[0])
table := strings.TrimSpace(parts[1])
if schema != "" && table != "" {
return schema, table
}
}
switch strings.ToLower(strings.TrimSpace(config.Type)) {
case "postgres", "kingbase":
// PG/金仓dbName 在 UI 里是“数据库”schema 需从 tableName 或使用默认 public。
return "public", rawTable
default:
// MySQLdbName 表示数据库Oracle/达梦dbName 表示 schema/owner。
return rawDB, rawTable
}
}

View File

@@ -1,18 +1,21 @@
package app
import (
"context"
"fmt"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/utils"
)
// Generic DB Methods
func (a *App) DBConnect(config connection.ConnectionConfig) connection.QueryResult {
// getDatabase checks cache and Pings. If valid, reuses. If not, connects.
_, err := a.getDatabase(config)
// 连接测试需要强制 ping避免缓存命中但连接已失效时误判成功。
_, err := a.getDatabaseForcePing(config)
if err != nil {
logger.Error(err, "DBConnect 连接失败:%s", formatConnSummary(config))
return connection.QueryResult{Success: false, Message: err.Error()}
@@ -23,7 +26,7 @@ func (a *App) DBConnect(config connection.ConnectionConfig) connection.QueryResu
}
func (a *App) TestConnection(config connection.ConnectionConfig) connection.QueryResult {
_, err := a.getDatabase(config)
_, err := a.getDatabaseForcePing(config)
if err != nil {
logger.Error(err, "TestConnection 连接测试失败:%s", formatConnSummary(config))
return connection.QueryResult{Success: false, Message: err.Error()}
@@ -83,10 +86,7 @@ func (a *App) MySQLShowCreateTable(config connection.ConnectionConfig, dbName st
}
func (a *App) DBQuery(config connection.ConnectionConfig, dbName string, query string) connection.QueryResult {
runConfig := config
if dbName != "" {
runConfig.Database = dbName
}
runConfig := normalizeRunConfig(config, dbName)
dbInst, err := a.getDatabase(runConfig)
if err != nil {
@@ -94,16 +94,39 @@ func (a *App) DBQuery(config connection.ConnectionConfig, dbName string, query s
return connection.QueryResult{Success: false, Message: err.Error()}
}
query = sanitizeSQLForPgLike(runConfig.Type, query)
timeoutSeconds := runConfig.Timeout
if timeoutSeconds <= 0 {
timeoutSeconds = 30
}
ctx, cancel := utils.ContextWithTimeout(time.Duration(timeoutSeconds) * time.Second)
defer cancel()
lowerQuery := strings.TrimSpace(strings.ToLower(query))
if strings.HasPrefix(lowerQuery, "select") || strings.HasPrefix(lowerQuery, "show") || strings.HasPrefix(lowerQuery, "describe") || strings.HasPrefix(lowerQuery, "explain") {
data, columns, err := dbInst.Query(query)
var data []map[string]interface{}
var columns []string
if q, ok := dbInst.(interface {
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
}); ok {
data, columns, err = q.QueryContext(ctx, query)
} else {
data, columns, err = dbInst.Query(query)
}
if err != nil {
logger.Error(err, "DBQuery 查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Data: data, Fields: columns}
} else {
affected, err := dbInst.Exec(query)
var affected int64
if e, ok := dbInst.(interface {
ExecContext(context.Context, string) (int64, error)
}); ok {
affected, err = e.ExecContext(ctx, query)
} else {
affected, err = dbInst.Exec(query)
}
if err != nil {
logger.Error(err, "DBQuery 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: err.Error()}
@@ -143,10 +166,7 @@ func (a *App) DBGetDatabases(config connection.ConnectionConfig) connection.Quer
}
func (a *App) DBGetTables(config connection.ConnectionConfig, dbName string) connection.QueryResult {
runConfig := config
if dbName != "" {
runConfig.Database = dbName
}
runConfig := normalizeRunConfig(config, dbName)
dbInst, err := a.getDatabase(runConfig)
if err != nil {
@@ -169,10 +189,7 @@ func (a *App) DBGetTables(config connection.ConnectionConfig, dbName string) con
}
func (a *App) DBShowCreateTable(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult {
runConfig := config
if dbName != "" {
runConfig.Database = dbName
}
runConfig := normalizeRunConfig(config, dbName)
dbInst, err := a.getDatabase(runConfig)
if err != nil {
@@ -180,7 +197,8 @@ func (a *App) DBShowCreateTable(config connection.ConnectionConfig, dbName strin
return connection.QueryResult{Success: false, Message: err.Error()}
}
sqlStr, err := dbInst.GetCreateStatement(dbName, tableName)
schemaName, pureTableName := normalizeSchemaAndTable(config, dbName, tableName)
sqlStr, err := dbInst.GetCreateStatement(schemaName, pureTableName)
if err != nil {
logger.Error(err, "DBShowCreateTable 获取建表语句失败:%s 表=%s", formatConnSummary(runConfig), tableName)
return connection.QueryResult{Success: false, Message: err.Error()}
@@ -190,17 +208,15 @@ func (a *App) DBShowCreateTable(config connection.ConnectionConfig, dbName strin
}
func (a *App) DBGetColumns(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult {
runConfig := config
if dbName != "" {
runConfig.Database = dbName
}
runConfig := normalizeRunConfig(config, dbName)
dbInst, err := a.getDatabase(runConfig)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
columns, err := dbInst.GetColumns(dbName, tableName)
schemaName, pureTableName := normalizeSchemaAndTable(config, dbName, tableName)
columns, err := dbInst.GetColumns(schemaName, pureTableName)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
@@ -209,17 +225,15 @@ func (a *App) DBGetColumns(config connection.ConnectionConfig, dbName string, ta
}
func (a *App) DBGetIndexes(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult {
runConfig := config
if dbName != "" {
runConfig.Database = dbName
}
runConfig := normalizeRunConfig(config, dbName)
dbInst, err := a.getDatabase(runConfig)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
indexes, err := dbInst.GetIndexes(dbName, tableName)
schemaName, pureTableName := normalizeSchemaAndTable(config, dbName, tableName)
indexes, err := dbInst.GetIndexes(schemaName, pureTableName)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
@@ -228,17 +242,15 @@ func (a *App) DBGetIndexes(config connection.ConnectionConfig, dbName string, ta
}
func (a *App) DBGetForeignKeys(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult {
runConfig := config
if dbName != "" {
runConfig.Database = dbName
}
runConfig := normalizeRunConfig(config, dbName)
dbInst, err := a.getDatabase(runConfig)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
fks, err := dbInst.GetForeignKeys(dbName, tableName)
schemaName, pureTableName := normalizeSchemaAndTable(config, dbName, tableName)
fks, err := dbInst.GetForeignKeys(schemaName, pureTableName)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
@@ -247,17 +259,15 @@ func (a *App) DBGetForeignKeys(config connection.ConnectionConfig, dbName string
}
func (a *App) DBGetTriggers(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult {
runConfig := config
if dbName != "" {
runConfig.Database = dbName
}
runConfig := normalizeRunConfig(config, dbName)
dbInst, err := a.getDatabase(runConfig)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
triggers, err := dbInst.GetTriggers(dbName, tableName)
schemaName, pureTableName := normalizeSchemaAndTable(config, dbName, tableName)
triggers, err := dbInst.GetTriggers(schemaName, pureTableName)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
@@ -266,10 +276,7 @@ func (a *App) DBGetTriggers(config connection.ConnectionConfig, dbName string, t
}
func (a *App) DBGetAllColumns(config connection.ConnectionConfig, dbName string) connection.QueryResult {
runConfig := config
if dbName != "" {
runConfig.Database = dbName
}
runConfig := normalizeRunConfig(config, dbName)
dbInst, err := a.getDatabase(runConfig)
if err != nil {

View File

@@ -1,11 +1,16 @@
package app
import (
"bufio"
"encoding/csv"
"encoding/json"
"fmt"
"math"
"os"
"sort"
"strconv"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/db"
@@ -135,10 +140,7 @@ func (a *App) ImportData(config connection.ConnectionConfig, dbName, tableName s
return connection.QueryResult{Success: true, Message: "No data to import"}
}
runConfig := config
if dbName != "" {
runConfig.Database = dbName
}
runConfig := normalizeRunConfig(config, dbName)
dbInst, err := a.getDatabase(runConfig)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
@@ -164,21 +166,16 @@ func (a *App) ImportData(config connection.ConnectionConfig, dbName, tableName s
values = append(values, fmt.Sprintf("'%s'", vStr))
}
}
query := fmt.Sprintf("INSERT INTO `%s` (%s) VALUES (%s)",
tableName,
strings.Join(cols, ", "),
strings.Join(values, ", "))
if runConfig.Type == "postgres" {
pgCols := make([]string, len(cols))
for i, c := range cols { pgCols[i] = fmt.Sprintf("\"%s\"", c) }
query = fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES (%s)",
tableName,
strings.Join(pgCols, ", "),
strings.Join(values, ", "))
quotedCols := make([]string, len(cols))
for i, c := range cols {
quotedCols[i] = quoteIdentByType(runConfig.Type, c)
}
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
quoteQualifiedIdentByType(runConfig.Type, tableName),
strings.Join(quotedCols, ", "),
strings.Join(values, ", "))
_, err := dbInst.Exec(query)
if err != nil {
errCount++
@@ -192,10 +189,7 @@ func (a *App) ImportData(config connection.ConnectionConfig, dbName, tableName s
}
func (a *App) ApplyChanges(config connection.ConnectionConfig, dbName, tableName string, changes connection.ChangeSet) connection.QueryResult {
runConfig := config
if dbName != "" {
runConfig.Database = dbName
}
runConfig := normalizeRunConfig(config, dbName)
dbInst, err := a.getDatabase(runConfig)
if err != nil {
@@ -207,10 +201,10 @@ func (a *App) ApplyChanges(config connection.ConnectionConfig, dbName, tableName
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Message: "Changes applied successfully"}
return connection.QueryResult{Success: true, Message: "事务提交成功"}
}
return connection.QueryResult{Success: false, Message: "Batch updates not supported for this database type"}
return connection.QueryResult{Success: false, Message: "当前数据库类型不支持批量提交"}
}
func (a *App) ExportTable(config connection.ConnectionConfig, dbName string, tableName string, format string) connection.QueryResult {
@@ -223,20 +217,38 @@ func (a *App) ExportTable(config connection.ConnectionConfig, dbName string, tab
return connection.QueryResult{Success: false, Message: "Cancelled"}
}
runConfig := config
if dbName != "" {
runConfig.Database = dbName
}
runConfig := normalizeRunConfig(config, dbName)
dbInst, err := a.getDatabase(runConfig)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
query := fmt.Sprintf("SELECT * FROM `%s`", tableName)
if runConfig.Type == "postgres" {
query = fmt.Sprintf("SELECT * FROM \"%s\"", tableName)
format = strings.ToLower(format)
if format == "sql" {
f, err := os.Create(filename)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
defer f.Close()
w := bufio.NewWriterSize(f, 1024*1024)
defer w.Flush()
if err := writeSQLHeader(w, runConfig, dbName); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
if err := dumpTableSQL(w, dbInst, runConfig, dbName, tableName, true); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
if err := writeSQLFooter(w, runConfig); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Message: "Export successful"}
}
query := fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(runConfig.Type, tableName))
data, columns, err := dbInst.Query(query)
if err != nil {
@@ -248,76 +260,340 @@ data, columns, err := dbInst.Query(query)
return connection.QueryResult{Success: false, Message: err.Error()}
}
defer f.Close()
format = strings.ToLower(format)
var csvWriter *csv.Writer
var jsonEncoder *json.Encoder
var isJsonFirstRow = true
switch format {
case "csv", "xlsx":
f.Write([]byte{0xEF, 0xBB, 0xBF})
csvWriter = csv.NewWriter(f)
defer csvWriter.Flush()
if err := csvWriter.Write(columns); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
case "json":
f.WriteString("[\n")
jsonEncoder = json.NewEncoder(f)
jsonEncoder.SetIndent(" ", " ")
case "md":
fmt.Fprintf(f, "| %s |\n", strings.Join(columns, " | "))
seps := make([]string, len(columns))
for i := range seps {
seps[i] = "---"
}
fmt.Fprintf(f, "| %s |\n", strings.Join(seps, " | "))
default:
return connection.QueryResult{Success: false, Message: "Unsupported format: " + format}
}
for _, rowMap := range data {
record := make([]string, len(columns))
for i, col := range columns {
val := rowMap[col]
if val == nil {
record[i] = "NULL"
} else {
s := fmt.Sprintf("%v", val)
if format == "md" {
s = strings.ReplaceAll(s, "|", "\\|")
s = strings.ReplaceAll(s, "\n", "<br>")
}
record[i] = s
}
}
switch format {
case "csv", "xlsx":
if err := csvWriter.Write(record); err != nil {
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
}
case "json":
if !isJsonFirstRow {
f.WriteString(",\n")
}
if err := jsonEncoder.Encode(rowMap); err != nil {
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
}
isJsonFirstRow = false
case "md":
fmt.Fprintf(f, "| %s |\n", strings.Join(record, " | "))
}
}
if format == "json" {
f.WriteString("\n]")
if err := writeRowsToFile(f, data, columns, format); err != nil {
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
}
return connection.QueryResult{Success: true, Message: "Export successful"}
}
func (a *App) ExportTablesSQL(config connection.ConnectionConfig, dbName string, tableNames []string, includeData bool) connection.QueryResult {
safeDbName := strings.TrimSpace(dbName)
if safeDbName == "" {
safeDbName = "export"
}
suffix := "schema"
if includeData {
suffix = "backup"
}
defaultFilename := fmt.Sprintf("%s_%s_%dtables.sql", safeDbName, suffix, len(tableNames))
if len(tableNames) == 1 && strings.TrimSpace(tableNames[0]) != "" {
defaultFilename = fmt.Sprintf("%s_%s.sql", strings.TrimSpace(tableNames[0]), suffix)
}
filename, err := runtime.SaveFileDialog(a.ctx, runtime.SaveDialogOptions{
Title: "Export Tables (SQL)",
DefaultFilename: defaultFilename,
})
if err != nil || filename == "" {
return connection.QueryResult{Success: false, Message: "Cancelled"}
}
runConfig := normalizeRunConfig(config, dbName)
dbInst, err := a.getDatabase(runConfig)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
tables := make([]string, 0, len(tableNames))
seen := make(map[string]struct{}, len(tableNames))
for _, t := range tableNames {
t = strings.TrimSpace(t)
if t == "" {
continue
}
if _, ok := seen[t]; ok {
continue
}
seen[t] = struct{}{}
tables = append(tables, t)
}
sort.Strings(tables)
f, err := os.Create(filename)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
defer f.Close()
w := bufio.NewWriterSize(f, 1024*1024)
defer w.Flush()
if err := writeSQLHeader(w, runConfig, dbName); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
for _, t := range tables {
if err := dumpTableSQL(w, dbInst, runConfig, dbName, t, includeData); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
}
if err := writeSQLFooter(w, runConfig); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Message: "Export successful"}
}
func (a *App) ExportDatabaseSQL(config connection.ConnectionConfig, dbName string, includeData bool) connection.QueryResult {
safeDbName := strings.TrimSpace(dbName)
if safeDbName == "" {
return connection.QueryResult{Success: false, Message: "dbName required"}
}
suffix := "schema"
if includeData {
suffix = "backup"
}
filename, err := runtime.SaveFileDialog(a.ctx, runtime.SaveDialogOptions{
Title: fmt.Sprintf("Export %s (SQL)", safeDbName),
DefaultFilename: fmt.Sprintf("%s_%s.sql", safeDbName, suffix),
})
if err != nil || filename == "" {
return connection.QueryResult{Success: false, Message: "Cancelled"}
}
runConfig := normalizeRunConfig(config, dbName)
dbInst, err := a.getDatabase(runConfig)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
tables, err := dbInst.GetTables(dbName)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
sort.Strings(tables)
f, err := os.Create(filename)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
defer f.Close()
w := bufio.NewWriterSize(f, 1024*1024)
defer w.Flush()
if err := writeSQLHeader(w, runConfig, dbName); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
for _, t := range tables {
if err := dumpTableSQL(w, dbInst, runConfig, dbName, t, includeData); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
}
if err := writeSQLFooter(w, runConfig); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Message: "Export successful"}
}
func quoteIdentByType(dbType string, ident string) string {
if ident == "" {
return ident
}
switch dbType {
case "mysql":
return "`" + strings.ReplaceAll(ident, "`", "``") + "`"
default:
return `"` + strings.ReplaceAll(ident, `"`, `""`) + `"`
}
}
func quoteQualifiedIdentByType(dbType string, ident string) string {
raw := strings.TrimSpace(ident)
if raw == "" {
return raw
}
parts := strings.Split(raw, ".")
if len(parts) <= 1 {
return quoteIdentByType(dbType, raw)
}
quotedParts := make([]string, 0, len(parts))
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
quotedParts = append(quotedParts, quoteIdentByType(dbType, part))
}
if len(quotedParts) == 0 {
return quoteIdentByType(dbType, raw)
}
return strings.Join(quotedParts, ".")
}
func writeSQLHeader(w *bufio.Writer, config connection.ConnectionConfig, dbName string) error {
now := time.Now().Format("2006-01-02 15:04:05")
if _, err := w.WriteString(fmt.Sprintf("-- GoNavi SQL Export\n-- Time: %s\n", now)); err != nil {
return err
}
if strings.TrimSpace(dbName) != "" {
if _, err := w.WriteString(fmt.Sprintf("-- Database: %s\n\n", dbName)); err != nil {
return err
}
}
if strings.ToLower(strings.TrimSpace(config.Type)) == "mysql" && strings.TrimSpace(dbName) != "" {
if _, err := w.WriteString(fmt.Sprintf("USE %s;\n\n", quoteIdentByType("mysql", dbName))); err != nil {
return err
}
if _, err := w.WriteString("SET FOREIGN_KEY_CHECKS=0;\n\n"); err != nil {
return err
}
}
return nil
}
func writeSQLFooter(w *bufio.Writer, config connection.ConnectionConfig) error {
if strings.ToLower(strings.TrimSpace(config.Type)) == "mysql" {
if _, err := w.WriteString("\nSET FOREIGN_KEY_CHECKS=1;\n"); err != nil {
return err
}
}
return nil
}
func qualifyTable(schemaName, tableName string) string {
schemaName = strings.TrimSpace(schemaName)
tableName = strings.TrimSpace(tableName)
if schemaName == "" {
return tableName
}
return schemaName + "." + tableName
}
func ensureSQLTerminator(sql string) string {
trimmed := strings.TrimSpace(sql)
if trimmed == "" {
return sql
}
if strings.HasSuffix(trimmed, ";") {
return sql
}
return sql + ";"
}
func isMySQLHexLiteral(s string) bool {
if len(s) < 3 || !(strings.HasPrefix(s, "0x") || strings.HasPrefix(s, "0X")) {
return false
}
for i := 2; i < len(s); i++ {
c := s[i]
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')) {
return false
}
}
return true
}
func formatSQLValue(dbType string, v interface{}) string {
if v == nil {
return "NULL"
}
switch val := v.(type) {
case bool:
if val {
return "1"
}
return "0"
case int:
return strconv.Itoa(val)
case int8, int16, int32, int64:
return fmt.Sprintf("%d", val)
case uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("%d", val)
case float32:
f := float64(val)
if math.IsNaN(f) || math.IsInf(f, 0) {
return "NULL"
}
return strconv.FormatFloat(f, 'f', -1, 32)
case float64:
if math.IsNaN(val) || math.IsInf(val, 0) {
return "NULL"
}
return strconv.FormatFloat(val, 'f', -1, 64)
case time.Time:
return "'" + val.Format("2006-01-02 15:04:05") + "'"
case string:
if strings.ToLower(strings.TrimSpace(dbType)) == "mysql" && isMySQLHexLiteral(val) {
return val
}
escaped := strings.ReplaceAll(val, "'", "''")
return "'" + escaped + "'"
default:
escaped := strings.ReplaceAll(fmt.Sprintf("%v", v), "'", "''")
return "'" + escaped + "'"
}
}
func dumpTableSQL(w *bufio.Writer, dbInst db.Database, config connection.ConnectionConfig, dbName, tableName string, includeData bool) error {
schemaName, pureTableName := normalizeSchemaAndTable(config, dbName, tableName)
if _, err := w.WriteString("\n-- ----------------------------\n"); err != nil {
return err
}
if _, err := w.WriteString(fmt.Sprintf("-- Table: %s\n", qualifyTable(schemaName, pureTableName))); err != nil {
return err
}
if _, err := w.WriteString("-- ----------------------------\n\n"); err != nil {
return err
}
createSQL, err := dbInst.GetCreateStatement(schemaName, pureTableName)
if err != nil {
return err
}
if _, err := w.WriteString(ensureSQLTerminator(createSQL)); err != nil {
return err
}
if _, err := w.WriteString("\n\n"); err != nil {
return err
}
if !includeData {
return nil
}
qualified := qualifyTable(schemaName, pureTableName)
selectSQL := fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(config.Type, qualified))
data, columns, err := dbInst.Query(selectSQL)
if err != nil {
return err
}
if len(data) == 0 {
if _, err := w.WriteString("-- (0 rows)\n"); err != nil {
return err
}
return nil
}
quotedCols := make([]string, 0, len(columns))
for _, c := range columns {
quotedCols = append(quotedCols, quoteIdentByType(config.Type, c))
}
quotedTable := quoteQualifiedIdentByType(config.Type, qualified)
for _, row := range data {
values := make([]string, 0, len(columns))
for _, c := range columns {
values = append(values, formatSQLValue(config.Type, row[c]))
}
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s);\n", quotedTable, strings.Join(quotedCols, ", "), strings.Join(values, ", "))); err != nil {
return err
}
}
return nil
}
// ExportData exports provided data to a file
func (a *App) ExportData(data []map[string]interface{}, columns []string, defaultName string, format string) connection.QueryResult {
if defaultName == "" {
@@ -337,33 +613,101 @@ func (a *App) ExportData(data []map[string]interface{}, columns []string, defaul
return connection.QueryResult{Success: false, Message: err.Error()}
}
defer f.Close()
if err := writeRowsToFile(f, data, columns, format); err != nil {
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
}
return connection.QueryResult{Success: true, Message: "Export successful"}
}
// ExportQuery exports by executing the provided SELECT query on backend side.
// This avoids frontend IPC payload limits when exporting very large/long-text columns (e.g. base64).
func (a *App) ExportQuery(config connection.ConnectionConfig, dbName string, query string, defaultName string, format string) connection.QueryResult {
query = strings.TrimSpace(query)
if query == "" {
return connection.QueryResult{Success: false, Message: "query required"}
}
if defaultName == "" {
defaultName = "export"
}
filename, err := runtime.SaveFileDialog(a.ctx, runtime.SaveDialogOptions{
Title: "Export Query Result",
DefaultFilename: fmt.Sprintf("%s.%s", defaultName, strings.ToLower(format)),
})
if err != nil || filename == "" {
return connection.QueryResult{Success: false, Message: "Cancelled"}
}
runConfig := normalizeRunConfig(config, dbName)
dbInst, err := a.getDatabase(runConfig)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
query = sanitizeSQLForPgLike(runConfig.Type, query)
lowerQuery := strings.ToLower(strings.TrimSpace(query))
if !(strings.HasPrefix(lowerQuery, "select") || strings.HasPrefix(lowerQuery, "with")) {
return connection.QueryResult{Success: false, Message: "Only SELECT/WITH queries are supported"}
}
data, columns, err := dbInst.Query(query)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
f, err := os.Create(filename)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
defer f.Close()
if err := writeRowsToFile(f, data, columns, format); err != nil {
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
}
return connection.QueryResult{Success: true, Message: "Export successful"}
}
func writeRowsToFile(f *os.File, data []map[string]interface{}, columns []string, format string) error {
format = strings.ToLower(strings.TrimSpace(format))
if f == nil {
return fmt.Errorf("file required")
}
format = strings.ToLower(format)
var csvWriter *csv.Writer
var jsonEncoder *json.Encoder
var isJsonFirstRow = true
isJsonFirstRow := true
switch format {
case "csv", "xlsx":
f.Write([]byte{0xEF, 0xBB, 0xBF})
if _, err := f.Write([]byte{0xEF, 0xBB, 0xBF}); err != nil {
return err
}
csvWriter = csv.NewWriter(f)
defer csvWriter.Flush()
if err := csvWriter.Write(columns); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
return err
}
case "json":
f.WriteString("[\n")
if _, err := f.WriteString("[\n"); err != nil {
return err
}
jsonEncoder = json.NewEncoder(f)
jsonEncoder.SetIndent(" ", " ")
case "md":
fmt.Fprintf(f, "| %s |\n", strings.Join(columns, " | "))
if _, err := fmt.Fprintf(f, "| %s |\n", strings.Join(columns, " | ")); err != nil {
return err
}
seps := make([]string, len(columns))
for i := range seps {
seps[i] = "---"
}
fmt.Fprintf(f, "| %s |\n", strings.Join(seps, " | "))
if _, err := fmt.Fprintf(f, "| %s |\n", strings.Join(seps, " | ")); err != nil {
return err
}
default:
return connection.QueryResult{Success: false, Message: "Unsupported format: " + format}
return fmt.Errorf("unsupported format: %s", format)
}
for _, rowMap := range data {
@@ -372,37 +716,51 @@ func (a *App) ExportData(data []map[string]interface{}, columns []string, defaul
val := rowMap[col]
if val == nil {
record[i] = "NULL"
} else {
s := fmt.Sprintf("%v", val)
if format == "md" {
s = strings.ReplaceAll(s, "|", "\\|")
s = strings.ReplaceAll(s, "\n", "<br>")
}
record[i] = s
continue
}
s := fmt.Sprintf("%v", val)
if format == "md" {
s = strings.ReplaceAll(s, "|", "\\|")
s = strings.ReplaceAll(s, "\n", "<br>")
}
record[i] = s
}
switch format {
case "csv", "xlsx":
if err := csvWriter.Write(record); err != nil {
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
return err
}
case "json":
if !isJsonFirstRow {
f.WriteString(",\n")
if _, err := f.WriteString(",\n"); err != nil {
return err
}
}
if err := jsonEncoder.Encode(rowMap); err != nil {
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
return err
}
isJsonFirstRow = false
case "md":
fmt.Fprintf(f, "| %s |\n", strings.Join(record, " | "))
if _, err := fmt.Fprintf(f, "| %s |\n", strings.Join(record, " | ")); err != nil {
return err
}
}
}
if format == "csv" || format == "xlsx" {
csvWriter.Flush()
if err := csvWriter.Error(); err != nil {
return err
}
}
if format == "json" {
f.WriteString("\n]")
if _, err := f.WriteString("\n]"); err != nil {
return err
}
}
return connection.QueryResult{Success: true, Message: "Export successful"}
return nil
}

View File

@@ -0,0 +1,481 @@
package app
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"strings"
"sync"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/redis"
)
// Redis client cache
var (
redisCache = make(map[string]redis.RedisClient)
redisCacheMu sync.Mutex
)
// getRedisClient gets or creates a Redis client from cache
func (a *App) getRedisClient(config connection.ConnectionConfig) (redis.RedisClient, error) {
key := getRedisClientCacheKey(config)
shortKey := key
if len(shortKey) > 12 {
shortKey = shortKey[:12]
}
logger.Infof("获取 Redis 连接:%s 缓存Key=%s", formatRedisConnSummary(config), shortKey)
redisCacheMu.Lock()
defer redisCacheMu.Unlock()
if client, ok := redisCache[key]; ok {
logger.Infof("命中 Redis 连接缓存开始检测可用性缓存Key=%s", shortKey)
if err := client.Ping(); err == nil {
logger.Infof("缓存 Redis 连接可用缓存Key=%s", shortKey)
return client, nil
} else {
logger.Error(err, "缓存 Redis 连接不可用准备重建缓存Key=%s", shortKey)
}
client.Close()
delete(redisCache, key)
}
logger.Infof("创建 Redis 客户端实例缓存Key=%s", shortKey)
client := redis.NewRedisClient()
if err := client.Connect(config); err != nil {
logger.Error(err, "Redis 连接失败:%s 缓存Key=%s", formatRedisConnSummary(config), shortKey)
return nil, err
}
redisCache[key] = client
logger.Infof("Redis 连接成功并写入缓存:%s 缓存Key=%s", formatRedisConnSummary(config), shortKey)
return client, nil
}
func getRedisClientCacheKey(config connection.ConnectionConfig) string {
if !config.UseSSH {
config.SSH = connection.SSHConfig{}
}
b, _ := json.Marshal(config)
sum := sha256.Sum256(b)
return hex.EncodeToString(sum[:])
}
func formatRedisConnSummary(config connection.ConnectionConfig) string {
timeoutSeconds := config.Timeout
if timeoutSeconds <= 0 {
timeoutSeconds = 30
}
var b strings.Builder
b.WriteString("类型=redis 地址=")
b.WriteString(config.Host)
b.WriteString(":")
b.WriteString(string(rune(config.Port + '0')))
b.WriteString(" DB=")
b.WriteString(string(rune(config.RedisDB + '0')))
if config.UseSSH {
b.WriteString(" SSH=")
b.WriteString(config.SSH.Host)
b.WriteString(":")
b.WriteString(string(rune(config.SSH.Port + '0')))
b.WriteString(" 用户=")
b.WriteString(config.SSH.User)
}
return b.String()
}
// RedisConnect tests a Redis connection
func (a *App) RedisConnect(config connection.ConnectionConfig) connection.QueryResult {
config.Type = "redis"
_, err := a.getRedisClient(config)
if err != nil {
logger.Error(err, "RedisConnect 连接失败:%s", formatRedisConnSummary(config))
return connection.QueryResult{Success: false, Message: err.Error()}
}
logger.Infof("RedisConnect 连接成功:%s", formatRedisConnSummary(config))
return connection.QueryResult{Success: true, Message: "连接成功"}
}
// RedisTestConnection tests a Redis connection (alias for RedisConnect)
func (a *App) RedisTestConnection(config connection.ConnectionConfig) connection.QueryResult {
return a.RedisConnect(config)
}
// RedisScanKeys scans keys matching a pattern
func (a *App) RedisScanKeys(config connection.ConnectionConfig, pattern string, cursor uint64, count int64) connection.QueryResult {
config.Type = "redis"
client, err := a.getRedisClient(config)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
result, err := client.ScanKeys(pattern, cursor, count)
if err != nil {
logger.Error(err, "RedisScanKeys 扫描失败pattern=%s", pattern)
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Data: result}
}
// RedisGetValue gets the value of a key
func (a *App) RedisGetValue(config connection.ConnectionConfig, key string) connection.QueryResult {
config.Type = "redis"
client, err := a.getRedisClient(config)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
value, err := client.GetValue(key)
if err != nil {
logger.Error(err, "RedisGetValue 获取失败key=%s", key)
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Data: value}
}
// RedisSetString sets a string value
func (a *App) RedisSetString(config connection.ConnectionConfig, key, value string, ttl int64) connection.QueryResult {
config.Type = "redis"
client, err := a.getRedisClient(config)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
if err := client.SetString(key, value, ttl); err != nil {
logger.Error(err, "RedisSetString 设置失败key=%s", key)
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Message: "设置成功"}
}
// RedisSetHashField sets a field in a hash
func (a *App) RedisSetHashField(config connection.ConnectionConfig, key, field, value string) connection.QueryResult {
config.Type = "redis"
client, err := a.getRedisClient(config)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
if err := client.SetHashField(key, field, value); err != nil {
logger.Error(err, "RedisSetHashField 设置失败key=%s field=%s", key, field)
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Message: "设置成功"}
}
// RedisDeleteKeys deletes one or more keys
func (a *App) RedisDeleteKeys(config connection.ConnectionConfig, keys []string) connection.QueryResult {
config.Type = "redis"
client, err := a.getRedisClient(config)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
deleted, err := client.DeleteKeys(keys)
if err != nil {
logger.Error(err, "RedisDeleteKeys 删除失败keys=%v", keys)
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Data: map[string]int64{"deleted": deleted}}
}
// RedisSetTTL sets the TTL of a key
func (a *App) RedisSetTTL(config connection.ConnectionConfig, key string, ttl int64) connection.QueryResult {
config.Type = "redis"
client, err := a.getRedisClient(config)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
if err := client.SetTTL(key, ttl); err != nil {
logger.Error(err, "RedisSetTTL 设置失败key=%s ttl=%d", key, ttl)
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Message: "设置成功"}
}
// RedisExecuteCommand executes a raw Redis command
func (a *App) RedisExecuteCommand(config connection.ConnectionConfig, command string) connection.QueryResult {
config.Type = "redis"
client, err := a.getRedisClient(config)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
// Parse command string into args
args := parseRedisCommand(command)
if len(args) == 0 {
return connection.QueryResult{Success: false, Message: "命令不能为空"}
}
result, err := client.ExecuteCommand(args)
if err != nil {
logger.Error(err, "RedisExecuteCommand 执行失败command=%s", command)
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Data: result}
}
// parseRedisCommand parses a Redis command string into arguments
func parseRedisCommand(command string) []string {
command = strings.TrimSpace(command)
if command == "" {
return nil
}
var args []string
var current strings.Builder
inQuote := false
quoteChar := rune(0)
for _, ch := range command {
if inQuote {
if ch == quoteChar {
inQuote = false
args = append(args, current.String())
current.Reset()
} else {
current.WriteRune(ch)
}
} else {
if ch == '"' || ch == '\'' {
inQuote = true
quoteChar = ch
} else if ch == ' ' || ch == '\t' {
if current.Len() > 0 {
args = append(args, current.String())
current.Reset()
}
} else {
current.WriteRune(ch)
}
}
}
if current.Len() > 0 {
args = append(args, current.String())
}
return args
}
// RedisGetServerInfo returns server information
func (a *App) RedisGetServerInfo(config connection.ConnectionConfig) connection.QueryResult {
config.Type = "redis"
client, err := a.getRedisClient(config)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
info, err := client.GetServerInfo()
if err != nil {
logger.Error(err, "RedisGetServerInfo 获取失败")
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Data: info}
}
// RedisGetDatabases returns information about all databases
func (a *App) RedisGetDatabases(config connection.ConnectionConfig) connection.QueryResult {
config.Type = "redis"
client, err := a.getRedisClient(config)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
dbs, err := client.GetDatabases()
if err != nil {
logger.Error(err, "RedisGetDatabases 获取失败")
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Data: dbs}
}
// RedisSelectDB selects a database
func (a *App) RedisSelectDB(config connection.ConnectionConfig, dbIndex int) connection.QueryResult {
config.Type = "redis"
config.RedisDB = dbIndex
client, err := a.getRedisClient(config)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
if err := client.SelectDB(dbIndex); err != nil {
logger.Error(err, "RedisSelectDB 切换失败db=%d", dbIndex)
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Message: "切换成功"}
}
// RedisRenameKey renames a key
func (a *App) RedisRenameKey(config connection.ConnectionConfig, oldKey, newKey string) connection.QueryResult {
config.Type = "redis"
client, err := a.getRedisClient(config)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
if err := client.RenameKey(oldKey, newKey); err != nil {
logger.Error(err, "RedisRenameKey 重命名失败:%s -> %s", oldKey, newKey)
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Message: "重命名成功"}
}
// RedisDeleteHashField deletes fields from a hash
func (a *App) RedisDeleteHashField(config connection.ConnectionConfig, key string, fields []string) connection.QueryResult {
config.Type = "redis"
client, err := a.getRedisClient(config)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
if err := client.DeleteHashField(key, fields...); err != nil {
logger.Error(err, "RedisDeleteHashField 删除失败key=%s fields=%v", key, fields)
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Message: "删除成功"}
}
// RedisListPush pushes values to a list
func (a *App) RedisListPush(config connection.ConnectionConfig, key string, values []string) connection.QueryResult {
config.Type = "redis"
client, err := a.getRedisClient(config)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
if err := client.ListPush(key, values...); err != nil {
logger.Error(err, "RedisListPush 添加失败key=%s", key)
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Message: "添加成功"}
}
// RedisListSet sets a value at an index in a list
func (a *App) RedisListSet(config connection.ConnectionConfig, key string, index int64, value string) connection.QueryResult {
config.Type = "redis"
client, err := a.getRedisClient(config)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
if err := client.ListSet(key, index, value); err != nil {
logger.Error(err, "RedisListSet 设置失败key=%s index=%d", key, index)
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Message: "设置成功"}
}
// RedisSetAdd adds members to a set
func (a *App) RedisSetAdd(config connection.ConnectionConfig, key string, members []string) connection.QueryResult {
config.Type = "redis"
client, err := a.getRedisClient(config)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
if err := client.SetAdd(key, members...); err != nil {
logger.Error(err, "RedisSetAdd 添加失败key=%s", key)
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Message: "添加成功"}
}
// RedisSetRemove removes members from a set
func (a *App) RedisSetRemove(config connection.ConnectionConfig, key string, members []string) connection.QueryResult {
config.Type = "redis"
client, err := a.getRedisClient(config)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
if err := client.SetRemove(key, members...); err != nil {
logger.Error(err, "RedisSetRemove 删除失败key=%s", key)
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Message: "删除成功"}
}
// RedisZSetAdd adds members to a sorted set
func (a *App) RedisZSetAdd(config connection.ConnectionConfig, key string, members []redis.ZSetMember) connection.QueryResult {
config.Type = "redis"
client, err := a.getRedisClient(config)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
if err := client.ZSetAdd(key, members...); err != nil {
logger.Error(err, "RedisZSetAdd 添加失败key=%s", key)
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Message: "添加成功"}
}
// RedisZSetRemove removes members from a sorted set
func (a *App) RedisZSetRemove(config connection.ConnectionConfig, key string, members []string) connection.QueryResult {
config.Type = "redis"
client, err := a.getRedisClient(config)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
if err := client.ZSetRemove(key, members...); err != nil {
logger.Error(err, "RedisZSetRemove 删除失败key=%s", key)
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Message: "删除成功"}
}
// RedisFlushDB flushes the current database
func (a *App) RedisFlushDB(config connection.ConnectionConfig) connection.QueryResult {
config.Type = "redis"
client, err := a.getRedisClient(config)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
if err := client.FlushDB(); err != nil {
logger.Error(err, "RedisFlushDB 清空失败")
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Message: "清空成功"}
}
// CloseAllRedisClients closes all cached Redis clients (called on shutdown)
func CloseAllRedisClients() {
redisCacheMu.Lock()
defer redisCacheMu.Unlock()
for key, client := range redisCache {
if client != nil {
client.Close()
logger.Infof("已关闭 Redis 连接:%s", key[:12])
}
}
redisCache = make(map[string]redis.RedisClient)
}

View File

@@ -1,11 +1,99 @@
package app
import (
"fmt"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/sync"
"github.com/wailsapp/wails/v2/pkg/runtime"
)
// DataSync executes a data synchronization task
func (a *App) DataSync(config sync.SyncConfig) sync.SyncResult {
engine := sync.NewSyncEngine()
return engine.RunSync(config)
jobID := strings.TrimSpace(config.JobID)
if jobID == "" {
jobID = fmt.Sprintf("sync-%d", time.Now().UnixNano())
config.JobID = jobID
}
reporter := sync.Reporter{
OnLog: func(event sync.SyncLogEvent) {
runtime.EventsEmit(a.ctx, sync.EventSyncLog, event)
},
OnProgress: func(event sync.SyncProgressEvent) {
runtime.EventsEmit(a.ctx, sync.EventSyncProgress, event)
},
}
runtime.EventsEmit(a.ctx, sync.EventSyncStart, map[string]any{
"jobId": jobID,
"total": len(config.Tables),
})
engine := sync.NewSyncEngine(reporter)
res := engine.RunSync(config)
runtime.EventsEmit(a.ctx, sync.EventSyncDone, map[string]any{
"jobId": jobID,
"result": res,
})
return res
}
// DataSyncAnalyze analyzes differences between source and target for the given tables (dry-run).
func (a *App) DataSyncAnalyze(config sync.SyncConfig) connection.QueryResult {
jobID := strings.TrimSpace(config.JobID)
if jobID == "" {
jobID = fmt.Sprintf("analyze-%d", time.Now().UnixNano())
config.JobID = jobID
}
reporter := sync.Reporter{
OnLog: func(event sync.SyncLogEvent) {
runtime.EventsEmit(a.ctx, sync.EventSyncLog, event)
},
OnProgress: func(event sync.SyncProgressEvent) {
runtime.EventsEmit(a.ctx, sync.EventSyncProgress, event)
},
}
runtime.EventsEmit(a.ctx, sync.EventSyncStart, map[string]any{
"jobId": jobID,
"total": len(config.Tables),
"type": "analyze",
})
engine := sync.NewSyncEngine(reporter)
res := engine.Analyze(config)
runtime.EventsEmit(a.ctx, sync.EventSyncDone, map[string]any{
"jobId": jobID,
"result": res,
"type": "analyze",
})
if !res.Success {
return connection.QueryResult{Success: false, Message: res.Message, Data: res}
}
return connection.QueryResult{Success: true, Message: res.Message, Data: res}
}
// DataSyncPreview returns a limited preview of diff rows for one table.
func (a *App) DataSyncPreview(config sync.SyncConfig, tableName string, limit int) connection.QueryResult {
jobID := strings.TrimSpace(config.JobID)
if jobID == "" {
jobID = fmt.Sprintf("preview-%d", time.Now().UnixNano())
config.JobID = jobID
}
engine := sync.NewSyncEngine(sync.Reporter{})
preview, err := engine.Preview(config, tableName, limit)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Message: "OK", Data: preview}
}

View File

@@ -0,0 +1,236 @@
package app
import (
"strings"
"unicode"
)
func sanitizeSQLForPgLike(dbType string, query string) string {
switch strings.ToLower(strings.TrimSpace(dbType)) {
case "postgres", "kingbase":
// 有些情况下会出现多层重复引用(例如 """"schema"""" 或 ""schema"""),单次修复不一定收敛。
// 这里做有限次数的迭代,直到输出不再变化。
out := query
for i := 0; i < 3; i++ {
fixed := fixBrokenDoubleDoubleQuotedIdent(out)
if fixed == out {
break
}
out = fixed
}
return out
default:
return query
}
}
// fixBrokenDoubleDoubleQuotedIdent fixes accidental identifiers like:
//
// SELECT * FROM ""schema"".""table""
//
// which can be produced when a quoted identifier gets wrapped by quotes again.
//
// It is intentionally conservative:
// - only runs outside strings/comments/dollar-quoted blocks
// - does not touch valid escaped-quote sequences inside quoted identifiers (e.g. "a""b")
func fixBrokenDoubleDoubleQuotedIdent(query string) string {
if !strings.Contains(query, `""`) {
return query
}
var b strings.Builder
b.Grow(len(query))
inSingle := false
inDoubleIdent := false
inLineComment := false
inBlockComment := false
dollarTag := ""
for i := 0; i < len(query); i++ {
ch := query[i]
next := byte(0)
if i+1 < len(query) {
next = query[i+1]
}
if inLineComment {
b.WriteByte(ch)
if ch == '\n' {
inLineComment = false
}
continue
}
if inBlockComment {
b.WriteByte(ch)
if ch == '*' && next == '/' {
b.WriteByte('/')
i++
inBlockComment = false
}
continue
}
if dollarTag != "" {
if strings.HasPrefix(query[i:], dollarTag) {
b.WriteString(dollarTag)
i += len(dollarTag) - 1
dollarTag = ""
continue
}
b.WriteByte(ch)
continue
}
if inSingle {
b.WriteByte(ch)
if ch == '\'' {
// escaped single quote
if next == '\'' {
b.WriteByte('\'')
i++
continue
}
inSingle = false
}
continue
}
if inDoubleIdent {
b.WriteByte(ch)
if ch == '"' {
// escaped quote inside identifier
if next == '"' {
b.WriteByte('"')
i++
continue
}
inDoubleIdent = false
}
continue
}
// --- Outside of all string/comment blocks ---
if ch == '-' && next == '-' {
b.WriteByte(ch)
b.WriteByte('-')
i++
inLineComment = true
continue
}
if ch == '/' && next == '*' {
b.WriteByte(ch)
b.WriteByte('*')
i++
inBlockComment = true
continue
}
if ch == '\'' {
b.WriteByte(ch)
inSingle = true
continue
}
if ch == '$' {
if tag := parseDollarTag(query[i:]); tag != "" {
b.WriteString(tag)
i += len(tag) - 1
dollarTag = tag
continue
}
}
if ch == '"' {
// Fix: ""ident"" -> "ident" (only when it looks like a plain identifier)
// Also handle variants like ""ident""" / """"ident"""" (extra quotes at either side).
if next == '"' {
if replacement, advance, ok := tryFixDoubleDoubleQuotedIdent(query, i); ok {
b.WriteString(replacement)
i = advance - 1
continue
}
}
b.WriteByte(ch)
inDoubleIdent = true
continue
}
b.WriteByte(ch)
}
return b.String()
}
func tryFixDoubleDoubleQuotedIdent(query string, start int) (replacement string, advance int, ok bool) {
// start points at the first quote of a broken identifier, usually like:
// ""ident"" / ""ident""" / """"ident""""
if start < 0 || start+1 >= len(query) {
return "", 0, false
}
if query[start] != '"' || query[start+1] != '"' {
return "", 0, false
}
if start > 0 && query[start-1] == '"' {
return "", 0, false
}
runLen := 0
for start+runLen < len(query) && query[start+runLen] == '"' {
runLen++
}
if runLen < 2 || runLen%2 == 1 {
// Odd run (e.g. """...) can be a valid quoted identifier with escaped quotes.
return "", 0, false
}
contentStart := start + runLen
j := contentStart
for j < len(query) {
if query[j] == '"' {
endRunLen := 0
for j+endRunLen < len(query) && query[j+endRunLen] == '"' {
endRunLen++
}
if endRunLen >= 2 {
content := strings.TrimSpace(query[contentStart:j])
if looksLikeIdentifierContent(content) {
return `"` + content + `"`, j + endRunLen, true
}
return "", 0, false
}
}
// Fast abort: identifier-like content should not span lines.
if query[j] == '\n' || query[j] == '\r' {
break
}
j++
}
return "", 0, false
}
func looksLikeIdentifierContent(s string) bool {
if strings.TrimSpace(s) == "" {
return false
}
for _, r := range s {
if r == '_' || r == '$' || r == '-' || unicode.IsLetter(r) || unicode.IsDigit(r) {
continue
}
return false
}
return true
}
func parseDollarTag(s string) string {
// Match: $tag$ where tag is [A-Za-z0-9_]* (can be empty => $$)
if len(s) < 2 || s[0] != '$' {
return ""
}
for i := 1; i < len(s); i++ {
c := s[i]
if c == '$' {
return s[:i+1]
}
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') {
return ""
}
}
return ""
}

View File

@@ -0,0 +1,55 @@
package app
import "testing"
func TestSanitizeSQLForPgLike_FixesBrokenDoubleDoubleQuotes(t *testing.T) {
in := `SELECT * FROM ""ldf_server"".""t_user"" LIMIT 1`
out := sanitizeSQLForPgLike("kingbase", in)
want := `SELECT * FROM "ldf_server"."t_user" LIMIT 1`
if out != want {
t.Fatalf("unexpected sanitize output:\nIN: %s\nOUT: %s\nWANT: %s", in, out, want)
}
}
func TestSanitizeSQLForPgLike_FixesBrokenDoubleDoubleQuotes_WithExtraQuotes(t *testing.T) {
in := `SELECT * FROM ""ldf_server""".""t_user"" LIMIT 1`
out := sanitizeSQLForPgLike("kingbase", in)
want := `SELECT * FROM "ldf_server"."t_user" LIMIT 1`
if out != want {
t.Fatalf("unexpected sanitize output:\nIN: %s\nOUT: %s\nWANT: %s", in, out, want)
}
}
func TestSanitizeSQLForPgLike_FixesBrokenDoubleDoubleQuotes_WithQuadQuotes(t *testing.T) {
in := `SELECT * FROM """"ldf_server"""".""t_user"" LIMIT 1`
out := sanitizeSQLForPgLike("kingbase", in)
want := `SELECT * FROM "ldf_server"."t_user" LIMIT 1`
if out != want {
t.Fatalf("unexpected sanitize output:\nIN: %s\nOUT: %s\nWANT: %s", in, out, want)
}
}
func TestSanitizeSQLForPgLike_DoesNotTouchEscapedQuotesInsideIdentifier(t *testing.T) {
in := `SELECT "a""b" FROM "t""x"`
out := sanitizeSQLForPgLike("postgres", in)
if out != in {
t.Fatalf("should keep valid escaped quotes inside identifier:\nIN: %s\nOUT: %s", in, out)
}
}
func TestSanitizeSQLForPgLike_DoesNotTouchDollarQuotedStrings(t *testing.T) {
in := "SELECT $$\"\"ldf_server\"\"$$, \"\"ldf_server\"\""
out := sanitizeSQLForPgLike("postgres", in)
want := "SELECT $$\"\"ldf_server\"\"$$, \"ldf_server\""
if out != want {
t.Fatalf("unexpected sanitize output for dollar quoted string:\nIN: %s\nOUT: %s\nWANT: %s", in, out, want)
}
}
func TestSanitizeSQLForPgLike_DoesNotModifyOtherDBTypes(t *testing.T) {
in := `SELECT * FROM ""ldf_server""`
out := sanitizeSQLForPgLike("mysql", in)
if out != in {
t.Fatalf("non-PG-like db should not be sanitized:\nIN: %s\nOUT: %s", in, out)
}
}

View File

@@ -19,9 +19,10 @@ type ConnectionConfig struct {
Database string `json:"database"`
UseSSH bool `json:"useSSH"`
SSH SSHConfig `json:"ssh"`
Driver string `json:"driver,omitempty"` // For custom connection
DSN string `json:"dsn,omitempty"` // For custom connection
Driver string `json:"driver,omitempty"` // For custom connection
DSN string `json:"dsn,omitempty"` // For custom connection
Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds (default: 30)
RedisDB int `json:"redisDB,omitempty"` // Redis database index (0-15)
}
// QueryResult is the standard response format for Wails methods

View File

@@ -1,6 +1,7 @@
package db
import (
"context"
"database/sql"
"fmt"
"strings"
@@ -57,6 +58,20 @@ func (c *CustomDB) Ping() error {
return c.conn.PingContext(ctx)
}
func (c *CustomDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if c.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
}
rows, err := c.conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (c *CustomDB) Query(query string) ([]map[string]interface{}, []string, error) {
if c.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
@@ -67,41 +82,18 @@ func (c *CustomDB) Query(query string) ([]map[string]interface{}, []string, erro
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
columns, err := rows.Columns()
func (c *CustomDB) ExecContext(ctx context.Context, query string) (int64, error) {
if c.conn == nil {
return 0, fmt.Errorf("connection not open")
}
res, err := c.conn.ExecContext(ctx, query)
if err != nil {
return nil, nil, err
return 0, err
}
var resultData []map[string]interface{}
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
entry := make(map[string]interface{})
for i, col := range columns {
var v interface{}
val := values[i]
b, ok := val.([]byte)
if ok {
v = string(b)
} else {
v = val
}
entry[col] = v
}
resultData = append(resultData, entry)
}
return resultData, columns, nil
return res.RowsAffected()
}
func (c *CustomDB) Exec(query string) (int64, error) {
@@ -136,13 +128,22 @@ func (c *CustomDB) GetTables(dbName string) ([]string, error) {
query = fmt.Sprintf("SHOW TABLES FROM `%s`", dbName)
}
} else if c.driver == "postgres" || c.driver == "kingbase" {
if dbName != "" && dbName != "public" {
query = fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema = '%s'", dbName)
query = `
SELECT table_schema AS schemaname, table_name AS tablename
FROM information_schema.tables
WHERE table_type = 'BASE TABLE'
AND table_schema NOT IN ('pg_catalog', 'information_schema')`
if dbName != "" {
query += fmt.Sprintf(" AND table_schema = '%s'", dbName)
}
query += " ORDER BY table_schema, table_name"
} else if c.driver == "sqlite" {
query = "SELECT name FROM sqlite_master WHERE type='table'"
} else if c.driver == "oracle" || c.driver == "dm" {
query = "SELECT table_name FROM user_tables"
if dbName != "" {
query = fmt.Sprintf("SELECT owner, table_name FROM all_tables WHERE owner = '%s' ORDER BY table_name", strings.ToUpper(dbName))
}
}
// Fallback generic execution
@@ -153,6 +154,18 @@ func (c *CustomDB) GetTables(dbName string) ([]string, error) {
var tables []string
for _, row := range data {
if schema, okSchema := row["schemaname"]; okSchema {
if name, okName := row["tablename"]; okName {
tables = append(tables, fmt.Sprintf("%v.%v", schema, name))
continue
}
}
if owner, okOwner := row["OWNER"]; okOwner {
if name, okName := row["TABLE_NAME"]; okName {
tables = append(tables, fmt.Sprintf("%v.%v", owner, name))
continue
}
}
// iterate keys to find likely column
for k, v := range row {
if strings.Contains(strings.ToLower(k), "name") || strings.Contains(strings.ToLower(k), "table") {
@@ -235,7 +248,141 @@ func (c *CustomDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDe
}
func (c *CustomDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
return fmt.Errorf("read-only mode for custom")
if c.conn == nil {
return fmt.Errorf("connection not open")
}
tx, err := c.conn.Begin()
if err != nil {
return err
}
defer tx.Rollback()
driver := strings.ToLower(strings.TrimSpace(c.driver))
isMySQL := strings.Contains(driver, "mysql")
isPostgres := strings.Contains(driver, "postgres") || strings.Contains(driver, "kingbase") || strings.Contains(driver, "pg")
isOracle := strings.Contains(driver, "oracle") || strings.Contains(driver, "ora") || strings.Contains(driver, "dm") || strings.Contains(driver, "dameng")
quoteIdent := func(name string) string {
n := strings.TrimSpace(name)
if isMySQL {
n = strings.Trim(n, "`")
n = strings.ReplaceAll(n, "`", "``")
if n == "" {
return "``"
}
return "`" + n + "`"
}
n = strings.Trim(n, "\"")
n = strings.ReplaceAll(n, "\"", "\"\"")
if n == "" {
return "\"\""
}
return `"` + n + `"`
}
placeholder := func(idx int) string {
if isPostgres {
return fmt.Sprintf("$%d", idx)
}
if isOracle {
return fmt.Sprintf(":%d", idx)
}
// MySQL / SQLite / default
return "?"
}
schema := ""
table := strings.TrimSpace(tableName)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
schema = strings.TrimSpace(parts[0])
table = strings.TrimSpace(parts[1])
}
qualifiedTable := ""
if schema != "" {
qualifiedTable = fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table))
} else {
qualifiedTable = quoteIdent(table)
}
// 1. Deletes
for _, pk := range changes.Deletes {
var wheres []string
var args []interface{}
idx := 0
for k, v := range pk {
idx++
wheres = append(wheres, fmt.Sprintf("%s = %s", quoteIdent(k), placeholder(idx)))
args = append(args, v)
}
if len(wheres) == 0 {
continue
}
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("delete error: %v", err)
}
}
// 2. Updates
for _, update := range changes.Updates {
var sets []string
var args []interface{}
idx := 0
for k, v := range update.Values {
idx++
sets = append(sets, fmt.Sprintf("%s = %s", quoteIdent(k), placeholder(idx)))
args = append(args, v)
}
if len(sets) == 0 {
continue
}
var wheres []string
for k, v := range update.Keys {
idx++
wheres = append(wheres, fmt.Sprintf("%s = %s", quoteIdent(k), placeholder(idx)))
args = append(args, v)
}
if len(wheres) == 0 {
return fmt.Errorf("update requires keys")
}
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("update error: %v", err)
}
}
// 3. Inserts
for _, row := range changes.Inserts {
var cols []string
var placeholders []string
var args []interface{}
idx := 0
for k, v := range row {
idx++
cols = append(cols, quoteIdent(k))
placeholders = append(placeholders, placeholder(idx))
args = append(args, v)
}
if len(cols) == 0 {
continue
}
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("insert error: %v", err)
}
}
return tx.Commit()
}
func (c *CustomDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {

View File

@@ -1,6 +1,7 @@
package db
import (
"context"
"database/sql"
"fmt"
"net"
@@ -10,6 +11,7 @@ import (
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
@@ -19,6 +21,7 @@ import (
type DamengDB struct {
conn *sql.DB
pingTimeout time.Duration
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
}
func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
@@ -26,16 +29,6 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
// or dm://user:password@host:port
address := net.JoinHostPort(config.Host, strconv.Itoa(config.Port))
if config.UseSSH {
// SSH logic similar to others, assumes port forwarding
_, err := ssh.RegisterSSHNetwork(config.SSH)
if err == nil {
// DM driver likely uses standard net.Dial, so we might need a local listener
// or assume port forwarding is handled externally or implicitly via "tcp" override if driver allows.
// Similar to Oracle, we skip complex custom dialer injection for now.
}
}
escapedPassword := url.PathEscape(config.Password)
q := url.Values{}
if config.Database != "" {
@@ -55,7 +48,42 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
}
func (d *DamengDB) Connect(config connection.ConnectionConfig) error {
dsn := d.getDSN(config)
var dsn string
var err error
if config.UseSSH {
// Create SSH tunnel with local port forwarding
logger.Infof("达梦数据库使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
if err != nil {
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
}
d.forwarder = forwarder
// Parse local address
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
if err != nil {
return fmt.Errorf("解析本地转发地址失败:%w", err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("解析本地端口失败:%w", err)
}
// Create a modified config pointing to local forwarder
localConfig := config
localConfig.Host = host
localConfig.Port = port
localConfig.UseSSH = false
dsn = d.getDSN(localConfig)
logger.Infof("达梦数据库通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = d.getDSN(config)
}
db, err := sql.Open("dm", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
@@ -69,6 +97,15 @@ func (d *DamengDB) Connect(config connection.ConnectionConfig) error {
}
func (d *DamengDB) Close() error {
// Close SSH forwarder first if exists
if d.forwarder != nil {
if err := d.forwarder.Close(); err != nil {
logger.Warnf("关闭达梦数据库 SSH 端口转发失败:%v", err)
}
d.forwarder = nil
}
// Then close database connection
if d.conn != nil {
return d.conn.Close()
}
@@ -88,6 +125,20 @@ func (d *DamengDB) Ping() error {
return d.conn.PingContext(ctx)
}
func (d *DamengDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if d.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
}
rows, err := d.conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (d *DamengDB) Query(query string) ([]map[string]interface{}, []string, error) {
if d.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
@@ -98,41 +149,18 @@ func (d *DamengDB) Query(query string) ([]map[string]interface{}, []string, erro
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
columns, err := rows.Columns()
func (d *DamengDB) ExecContext(ctx context.Context, query string) (int64, error) {
if d.conn == nil {
return 0, fmt.Errorf("connection not open")
}
res, err := d.conn.ExecContext(ctx, query)
if err != nil {
return nil, nil, err
return 0, err
}
var resultData []map[string]interface{}
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
entry := make(map[string]interface{})
for i, col := range columns {
var v interface{}
val := values[i]
b, ok := val.([]byte)
if ok {
v = string(b)
} else {
v = val
}
entry[col] = v
}
resultData = append(resultData, entry)
}
return resultData, columns, nil
return res.RowsAffected()
}
func (d *DamengDB) Exec(query string) (int64, error) {
@@ -166,7 +194,7 @@ func (d *DamengDB) GetDatabases() ([]string, error) {
}
func (d *DamengDB) GetTables(dbName string) ([]string, error) {
query := fmt.Sprintf("SELECT table_name FROM all_tables WHERE owner = '%s'", strings.ToUpper(dbName))
query := fmt.Sprintf("SELECT owner, table_name FROM all_tables WHERE owner = '%s' ORDER BY table_name", strings.ToUpper(dbName))
if dbName == "" {
query = "SELECT table_name FROM user_tables"
}
@@ -178,6 +206,14 @@ func (d *DamengDB) GetTables(dbName string) ([]string, error) {
var tables []string
for _, row := range data {
if dbName != "" {
if owner, okOwner := row["OWNER"]; okOwner {
if name, okName := row["TABLE_NAME"]; okName {
tables = append(tables, fmt.Sprintf("%v.%v", owner, name))
continue
}
}
}
if val, ok := row["TABLE_NAME"]; ok {
tables = append(tables, fmt.Sprintf("%v", val))
}
@@ -337,7 +373,117 @@ func (d *DamengDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDe
}
func (d *DamengDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
return fmt.Errorf("read-only mode implemented for Dameng so far")
if d.conn == nil {
return fmt.Errorf("connection not open")
}
tx, err := d.conn.Begin()
if err != nil {
return err
}
defer tx.Rollback()
quoteIdent := func(name string) string {
n := strings.TrimSpace(name)
n = strings.Trim(n, "\"")
n = strings.ReplaceAll(n, "\"", "\"\"")
if n == "" {
return "\"\""
}
return `"` + n + `"`
}
schema := ""
table := strings.TrimSpace(tableName)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
schema = strings.TrimSpace(parts[0])
table = strings.TrimSpace(parts[1])
}
qualifiedTable := ""
if schema != "" {
qualifiedTable = fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table))
} else {
qualifiedTable = quoteIdent(table)
}
// 1. Deletes
for _, pk := range changes.Deletes {
var wheres []string
var args []interface{}
idx := 0
for k, v := range pk {
idx++
wheres = append(wheres, fmt.Sprintf("%s = :%d", quoteIdent(k), idx))
args = append(args, v)
}
if len(wheres) == 0 {
continue
}
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("delete error: %v", err)
}
}
// 2. Updates
for _, update := range changes.Updates {
var sets []string
var args []interface{}
idx := 0
for k, v := range update.Values {
idx++
sets = append(sets, fmt.Sprintf("%s = :%d", quoteIdent(k), idx))
args = append(args, v)
}
if len(sets) == 0 {
continue
}
var wheres []string
for k, v := range update.Keys {
idx++
wheres = append(wheres, fmt.Sprintf("%s = :%d", quoteIdent(k), idx))
args = append(args, v)
}
if len(wheres) == 0 {
return fmt.Errorf("update requires keys")
}
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("update error: %v", err)
}
}
// 3. Inserts
for _, row := range changes.Inserts {
var cols []string
var placeholders []string
var args []interface{}
idx := 0
for k, v := range row {
idx++
cols = append(cols, quoteIdent(k))
placeholders = append(placeholders, fmt.Sprintf(":%d", idx))
args = append(args, v)
}
if len(cols) == 0 {
continue
}
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("insert error: %v", err)
}
}
return tx.Commit()
}
func (d *DamengDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {

View File

@@ -1,12 +1,16 @@
package db
import (
"context"
"database/sql"
"fmt"
"net"
"strconv"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
@@ -16,6 +20,7 @@ import (
type KingbaseDB struct {
conn *sql.DB
pingTimeout time.Duration
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
}
func quoteConnValue(v string) string {
@@ -57,20 +62,6 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
address := config.Host
port := config.Port
if config.UseSSH {
netName, err := ssh.RegisterSSHNetwork(config.SSH)
if err == nil {
// Kingbase/Postgres lib/pq allows custom dialer via "host" if using unix socket,
// but for custom network it's harder.
// Ideally we use a local forwarder.
// For now, we assume standard TCP or handle SSH externally.
// If we implement the net.Dial override for "kingbase" driver (which might use lib/pq internally),
// we might need to check if it supports "cloudsql" style or similar custom dialers.
// Similar to others, skipping SSH deep integration here for now.
_ = netName
}
}
// Construct DSN
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable connect_timeout=%d",
quoteConnValue(address),
@@ -85,7 +76,42 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
}
func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
dsn := k.getDSN(config)
var dsn string
var err error
if config.UseSSH {
// Create SSH tunnel with local port forwarding
logger.Infof("人大金仓使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
if err != nil {
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
}
k.forwarder = forwarder
// Parse local address
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
if err != nil {
return fmt.Errorf("解析本地转发地址失败:%w", err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("解析本地端口失败:%w", err)
}
// Create a modified config pointing to local forwarder
localConfig := config
localConfig.Host = host
localConfig.Port = port
localConfig.UseSSH = false
dsn = k.getDSN(localConfig)
logger.Infof("人大金仓通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = k.getDSN(config)
}
// Open using "kingbase" driver
db, err := sql.Open("kingbase", dsn)
if err != nil {
@@ -100,6 +126,15 @@ func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
}
func (k *KingbaseDB) Close() error {
// Close SSH forwarder first if exists
if k.forwarder != nil {
if err := k.forwarder.Close(); err != nil {
logger.Warnf("关闭人大金仓 SSH 端口转发失败:%v", err)
}
k.forwarder = nil
}
// Then close database connection
if k.conn != nil {
return k.conn.Close()
}
@@ -119,6 +154,20 @@ func (k *KingbaseDB) Ping() error {
return k.conn.PingContext(ctx)
}
func (k *KingbaseDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if k.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
}
rows, err := k.conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, error) {
if k.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
@@ -129,41 +178,18 @@ func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, er
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
columns, err := rows.Columns()
func (k *KingbaseDB) ExecContext(ctx context.Context, query string) (int64, error) {
if k.conn == nil {
return 0, fmt.Errorf("connection not open")
}
res, err := k.conn.ExecContext(ctx, query)
if err != nil {
return nil, nil, err
return 0, err
}
var resultData []map[string]interface{}
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
entry := make(map[string]interface{})
for i, col := range columns {
var v interface{}
val := values[i]
b, ok := val.([]byte)
if ok {
v = string(b)
} else {
v = val
}
entry[col] = v
}
resultData = append(resultData, entry)
}
return resultData, columns, nil
return res.RowsAffected()
}
func (k *KingbaseDB) Exec(query string) (int64, error) {
@@ -193,15 +219,14 @@ func (k *KingbaseDB) GetDatabases() ([]string, error) {
}
func (k *KingbaseDB) GetTables(dbName string) ([]string, error) {
// Usually restricted to current database connection in PG/Kingbase
// dbName param is often Schema in PG context, or ignored if we are connected to a specific DB.
// But in PG, cross-database queries are not standard without dblink.
// We assume dbName here might mean Schema (public, etc.)
query := "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"
if dbName != "" && dbName != "public" {
query = fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema = '%s'", dbName)
}
// Kingbase: tables are scoped by the current DB connection; include schema to avoid search_path issues.
query := `
SELECT table_schema AS schemaname, table_name AS tablename
FROM information_schema.tables
WHERE table_type = 'BASE TABLE'
AND table_schema NOT IN ('pg_catalog', 'information_schema')
AND table_schema NOT LIKE 'pg_%'
ORDER BY table_schema, table_name`
data, _, err := k.Query(query)
if err != nil {
@@ -210,6 +235,12 @@ func (k *KingbaseDB) GetTables(dbName string) ([]string, error) {
var tables []string
for _, row := range data {
schema, okSchema := row["schemaname"]
name, okName := row["tablename"]
if okSchema && okName {
tables = append(tables, fmt.Sprintf("%v.%v", schema, name))
continue
}
if val, ok := row["table_name"]; ok {
tables = append(tables, fmt.Sprintf("%v", val))
}
@@ -226,15 +257,84 @@ func (k *KingbaseDB) GetCreateStatement(dbName, tableName string) (string, error
}
func (k *KingbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
schema := "public"
if dbName != "" {
schema = dbName
// 解析 schema.table 格式
schema := strings.TrimSpace(dbName)
table := strings.TrimSpace(tableName)
// 如果 tableName 包含 schema (格式: schema.table)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
parsedSchema := strings.TrimSpace(parts[0])
parsedTable := strings.TrimSpace(parts[1])
if parsedSchema != "" && parsedTable != "" {
schema = parsedSchema
table = parsedTable
}
}
query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_schema = '%s' AND table_name = '%s'
ORDER BY ordinal_position`, schema, tableName)
// 如果仍然没有 schema,使用 current_schema()
// 这样可以自动匹配当前连接的 search_path
if schema == "" {
return k.getColumnsWithCurrentSchema(table)
}
if table == "" {
return nil, fmt.Errorf("table name required")
}
// 转义函数:处理单引号,移除双引号
esc := func(s string) string {
// 移除前后的双引号(如果存在)
s = strings.Trim(s, "\"")
// 转义单引号
return strings.ReplaceAll(s, "'", "''")
}
query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_schema = '%s' AND table_name = '%s'
ORDER BY ordinal_position`, esc(schema), esc(table))
data, _, err := k.Query(query)
if err != nil {
return nil, err
}
var columns []connection.ColumnDefinition
for _, row := range data {
col := connection.ColumnDefinition{
Name: fmt.Sprintf("%v", row["column_name"]),
Type: fmt.Sprintf("%v", row["data_type"]),
Nullable: fmt.Sprintf("%v", row["is_nullable"]),
}
if row["column_default"] != nil {
def := fmt.Sprintf("%v", row["column_default"])
col.Default = &def
}
columns = append(columns, col)
}
return columns, nil
}
// getColumnsWithCurrentSchema 使用 current_schema() 查询当前schema的表
func (k *KingbaseDB) getColumnsWithCurrentSchema(tableName string) ([]connection.ColumnDefinition, error) {
table := strings.TrimSpace(tableName)
if table == "" {
return nil, fmt.Errorf("table name required")
}
// 转义函数
esc := func(s string) string {
s = strings.Trim(s, "\"")
return strings.ReplaceAll(s, "'", "''")
}
// 使用 current_schema() 获取当前schema
query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_schema = current_schema() AND table_name = '%s'
ORDER BY ordinal_position`, esc(table))
data, _, err := k.Query(query)
if err != nil {
@@ -260,32 +360,76 @@ func (k *KingbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDe
}
func (k *KingbaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
// Postgres/Kingbase index query
query := fmt.Sprintf(`
SELECT
i.relname as index_name,
a.attname as column_name,
ix.indisunique as is_unique
FROM
pg_class t,
pg_class i,
pg_index ix,
pg_attribute a,
pg_namespace n
WHERE
t.oid = ix.indrelid
AND i.oid = ix.indexrelid
AND a.attrelid = t.oid
AND a.attnum = ANY(ix.indkey)
AND t.relkind = 'r'
AND t.relname = '%s'
AND n.oid = t.relnamespace
AND n.nspname = '%s'
`, tableName, "public") // Default to public if dbName (schema) not clear.
// 解析 schema.table 格式
schema := strings.TrimSpace(dbName)
table := strings.TrimSpace(tableName)
if dbName != "" {
// Update query to use dbName as schema
query = strings.Replace(query, "'public'", fmt.Sprintf("'%s'", dbName), 1)
// 如果 tableName 包含 schema (格式: schema.table)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
parsedSchema := strings.TrimSpace(parts[0])
parsedTable := strings.TrimSpace(parts[1])
if parsedSchema != "" && parsedTable != "" {
schema = parsedSchema
table = parsedTable
}
}
if table == "" {
return nil, fmt.Errorf("table name required")
}
// 转义函数:处理单引号,移除双引号
esc := func(s string) string {
s = strings.Trim(s, "\"")
return strings.ReplaceAll(s, "'", "''")
}
// 构建查询如果没有指定schema,使用current_schema()
var query string
if schema != "" {
query = fmt.Sprintf(`
SELECT
i.relname as index_name,
a.attname as column_name,
ix.indisunique as is_unique
FROM
pg_class t,
pg_class i,
pg_index ix,
pg_attribute a,
pg_namespace n
WHERE
t.oid = ix.indrelid
AND i.oid = ix.indexrelid
AND a.attrelid = t.oid
AND a.attnum = ANY(ix.indkey)
AND t.relkind = 'r'
AND t.relname = '%s'
AND n.oid = t.relnamespace
AND n.nspname = '%s'
`, esc(table), esc(schema))
} else {
query = fmt.Sprintf(`
SELECT
i.relname as index_name,
a.attname as column_name,
ix.indisunique as is_unique
FROM
pg_class t,
pg_class i,
pg_index ix,
pg_attribute a,
pg_namespace n
WHERE
t.oid = ix.indrelid
AND i.oid = ix.indexrelid
AND a.attrelid = t.oid
AND a.attnum = ANY(ix.indkey)
AND t.relkind = 'r'
AND t.relname = '%s'
AND n.oid = t.relnamespace
AND n.nspname = current_schema()
`, esc(table))
}
data, _, err := k.Query(query)
@@ -314,27 +458,67 @@ func (k *KingbaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDef
}
func (k *KingbaseDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
schema := "public"
if dbName != "" {
schema = dbName
// 解析 schema.table 格式
schema := strings.TrimSpace(dbName)
table := strings.TrimSpace(tableName)
// 如果 tableName 包含 schema (格式: schema.table)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
parsedSchema := strings.TrimSpace(parts[0])
parsedTable := strings.TrimSpace(parts[1])
if parsedSchema != "" && parsedTable != "" {
schema = parsedSchema
table = parsedTable
}
}
query := fmt.Sprintf(`
SELECT
tc.constraint_name,
kcu.column_name,
ccu.table_name AS foreign_table_name,
ccu.column_name AS foreign_column_name
FROM
information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
JOIN information_schema.constraint_column_usage AS ccu
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema='%s'`,
tableName, schema)
if table == "" {
return nil, fmt.Errorf("table name required")
}
// 转义函数:处理单引号,移除双引号
esc := func(s string) string {
s = strings.Trim(s, "\"")
return strings.ReplaceAll(s, "'", "''")
}
// 构建查询如果没有指定schema,使用current_schema()
var query string
if schema != "" {
query = fmt.Sprintf(`
SELECT
tc.constraint_name,
kcu.column_name,
ccu.table_name AS foreign_table_name,
ccu.column_name AS foreign_column_name
FROM
information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
JOIN information_schema.constraint_column_usage AS ccu
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema='%s'`,
esc(table), esc(schema))
} else {
query = fmt.Sprintf(`
SELECT
tc.constraint_name,
kcu.column_name,
ccu.table_name AS foreign_table_name,
ccu.column_name AS foreign_column_name
FROM
information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
JOIN information_schema.constraint_column_usage AS ccu
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema=current_schema()`,
esc(table))
}
data, _, err := k.Query(query)
if err != nil {
@@ -356,9 +540,43 @@ func (k *KingbaseDB) GetForeignKeys(dbName, tableName string) ([]connection.Fore
}
func (k *KingbaseDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
query := fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation
FROM information_schema.triggers
WHERE event_object_table = '%s'`, tableName)
// 解析 schema.table 格式
schema := strings.TrimSpace(dbName)
table := strings.TrimSpace(tableName)
// 如果 tableName 包含 schema (格式: schema.table)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
parsedSchema := strings.TrimSpace(parts[0])
parsedTable := strings.TrimSpace(parts[1])
if parsedSchema != "" && parsedTable != "" {
schema = parsedSchema
table = parsedTable
}
}
if table == "" {
return nil, fmt.Errorf("table name required")
}
// 转义函数:处理单引号,移除双引号
esc := func(s string) string {
s = strings.Trim(s, "\"")
return strings.ReplaceAll(s, "'", "''")
}
// 构建查询如果指定了schema,也加上schema条件
var query string
if schema != "" {
query = fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation
FROM information_schema.triggers
WHERE event_object_table = '%s' AND event_object_schema = '%s'`,
esc(table), esc(schema))
} else {
query = fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation
FROM information_schema.triggers
WHERE event_object_table = '%s' AND event_object_schema = current_schema()`,
esc(table))
}
data, _, err := k.Query(query)
if err != nil {
@@ -379,18 +597,127 @@ func (k *KingbaseDB) GetTriggers(dbName, tableName string) ([]connection.Trigger
}
func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
return fmt.Errorf("read-only mode implemented for Kingbase so far")
if k.conn == nil {
return fmt.Errorf("connection not open")
}
tx, err := k.conn.Begin()
if err != nil {
return err
}
defer tx.Rollback()
quoteIdent := func(name string) string {
n := strings.TrimSpace(name)
n = strings.Trim(n, "\"")
n = strings.ReplaceAll(n, "\"", "\"\"")
if n == "" {
return "\"\""
}
return `"` + n + `"`
}
schema := ""
table := strings.TrimSpace(tableName)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
schema = strings.TrimSpace(parts[0])
table = strings.TrimSpace(parts[1])
}
qualifiedTable := ""
if schema != "" {
qualifiedTable = fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table))
} else {
qualifiedTable = quoteIdent(table)
}
// 1. Deletes
for _, pk := range changes.Deletes {
var wheres []string
var args []interface{}
idx := 0
for k, v := range pk {
idx++
wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteIdent(k), idx))
args = append(args, v)
}
if len(wheres) == 0 {
continue
}
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("delete error: %v", err)
}
}
// 2. Updates
for _, update := range changes.Updates {
var sets []string
var args []interface{}
idx := 0
for k, v := range update.Values {
idx++
sets = append(sets, fmt.Sprintf("%s = $%d", quoteIdent(k), idx))
args = append(args, v)
}
if len(sets) == 0 {
continue
}
var wheres []string
for k, v := range update.Keys {
idx++
wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteIdent(k), idx))
args = append(args, v)
}
if len(wheres) == 0 {
return fmt.Errorf("update requires keys")
}
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("update error: %v", err)
}
}
// 3. Inserts
for _, row := range changes.Inserts {
var cols []string
var placeholders []string
var args []interface{}
idx := 0
for k, v := range row {
idx++
cols = append(cols, quoteIdent(k))
placeholders = append(placeholders, fmt.Sprintf("$%d", idx))
args = append(args, v)
}
if len(cols) == 0 {
continue
}
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("insert error: %v", err)
}
}
return tx.Commit()
}
func (k *KingbaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
schema := "public"
if dbName != "" {
schema = dbName
}
query := fmt.Sprintf(`SELECT table_name, column_name, data_type
FROM information_schema.columns
WHERE table_schema = '%s'`, schema)
// dbName 在本项目语义里是“数据库”schema 由 table_schema 决定;这里返回全部用户 schema 的列用于查询提示。
query := `
SELECT table_schema, table_name, column_name, data_type
FROM information_schema.columns
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
AND table_schema NOT LIKE 'pg_%'
ORDER BY table_schema, table_name, ordinal_position`
data, _, err := k.Query(query)
if err != nil {
@@ -399,8 +726,14 @@ func (k *KingbaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinition
var cols []connection.ColumnDefinitionWithTable
for _, row := range data {
schema := fmt.Sprintf("%v", row["table_schema"])
table := fmt.Sprintf("%v", row["table_name"])
tableName := table
if strings.TrimSpace(schema) != "" {
tableName = fmt.Sprintf("%s.%s", schema, table)
}
col := connection.ColumnDefinitionWithTable{
TableName: fmt.Sprintf("%v", row["table_name"]),
TableName: tableName,
Name: fmt.Sprintf("%v", row["column_name"]),
Type: fmt.Sprintf("%v", row["data_type"]),
}

View File

@@ -1,6 +1,7 @@
package db
import (
"context"
"database/sql"
"fmt"
"strings"
@@ -48,7 +49,7 @@ func (m *MySQLDB) Connect(config connection.ConnectionConfig) error {
}
m.conn = db
m.pingTimeout = getConnectTimeout(config)
// Force verification
if err := m.Ping(); err != nil {
return fmt.Errorf("连接建立后验证失败:%w", err)
@@ -76,6 +77,20 @@ func (m *MySQLDB) Ping() error {
return m.conn.PingContext(ctx)
}
func (m *MySQLDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if m.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
}
rows, err := m.conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (m *MySQLDB) Query(query string) ([]map[string]interface{}, []string, error) {
if m.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
@@ -86,41 +101,18 @@ func (m *MySQLDB) Query(query string) ([]map[string]interface{}, []string, error
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
columns, err := rows.Columns()
func (m *MySQLDB) ExecContext(ctx context.Context, query string) (int64, error) {
if m.conn == nil {
return 0, fmt.Errorf("connection not open")
}
res, err := m.conn.ExecContext(ctx, query)
if err != nil {
return nil, nil, err
return 0, err
}
var resultData []map[string]interface{}
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
entry := make(map[string]interface{})
for i, col := range columns {
var v interface{}
val := values[i]
b, ok := val.([]byte)
if ok {
v = string(b)
} else {
v = val
}
entry[col] = v
}
resultData = append(resultData, entry)
}
return resultData, columns, nil
return res.RowsAffected()
}
func (m *MySQLDB) Exec(query string) (int64, error) {
@@ -155,12 +147,12 @@ func (m *MySQLDB) GetTables(dbName string) ([]string, error) {
if dbName != "" {
query = fmt.Sprintf("SHOW TABLES FROM `%s`", dbName)
}
data, _, err := m.Query(query)
if err != nil {
return nil, err
}
var tables []string
for _, row := range data {
for _, v := range row {
@@ -181,7 +173,7 @@ func (m *MySQLDB) GetCreateStatement(dbName, tableName string) (string, error) {
if err != nil {
return "", err
}
if len(data) > 0 {
if val, ok := data[0]["Create Table"]; ok {
return fmt.Sprintf("%v", val), nil
@@ -211,12 +203,12 @@ func (m *MySQLDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefin
Extra: fmt.Sprintf("%v", row["Extra"]),
Comment: fmt.Sprintf("%v", row["Comment"]),
}
if row["Default"] != nil {
d := fmt.Sprintf("%v", row["Default"])
col.Default = &d
}
columns = append(columns, col)
}
return columns, nil
@@ -244,14 +236,14 @@ func (m *MySQLDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefini
}
}
seq := 0
if val, ok := row["Seq_in_index"]; ok {
seq := 0
if val, ok := row["Seq_in_index"]; ok {
if f, ok := val.(float64); ok {
seq = int(f)
} else if i, ok := val.(int64); ok {
seq = int(i)
}
}
}
idx := connection.IndexDefinition{
Name: fmt.Sprintf("%v", row["Key_name"]),
@@ -326,27 +318,31 @@ func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
var args []interface{}
for k, v := range pk {
wheres = append(wheres, fmt.Sprintf("`%s` = ?", k))
args = append(args, v)
args = append(args, normalizeMySQLDateTimeValue(v))
}
if len(wheres) == 0 {
continue
}
query := fmt.Sprintf("DELETE FROM `%s` WHERE %s", tableName, strings.Join(wheres, " AND "))
if _, err := tx.Exec(query, args...); err != nil {
res, err := tx.Exec(query, args...)
if err != nil {
return fmt.Errorf("delete error: %v", err)
}
if affected, err := res.RowsAffected(); err == nil && affected == 0 {
return fmt.Errorf("删除未生效:未匹配到任何行")
}
}
// 2. Updates
for _, update := range changes.Updates {
var sets []string
var args []interface{}
for k, v := range update.Values {
sets = append(sets, fmt.Sprintf("`%s` = ?", k))
args = append(args, v)
args = append(args, normalizeMySQLDateTimeValue(v))
}
if len(sets) == 0 {
continue
}
@@ -354,17 +350,21 @@ func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
var wheres []string
for k, v := range update.Keys {
wheres = append(wheres, fmt.Sprintf("`%s` = ?", k))
args = append(args, v)
args = append(args, normalizeMySQLDateTimeValue(v))
}
if len(wheres) == 0 {
return fmt.Errorf("update requires keys")
}
query := fmt.Sprintf("UPDATE `%s` SET %s WHERE %s", tableName, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
if _, err := tx.Exec(query, args...); err != nil {
res, err := tx.Exec(query, args...)
if err != nil {
return fmt.Errorf("update error: %v", err)
}
if affected, err := res.RowsAffected(); err == nil && affected == 0 {
return fmt.Errorf("更新未生效:未匹配到任何行")
}
}
// 3. Inserts
@@ -372,26 +372,105 @@ func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
var cols []string
var placeholders []string
var args []interface{}
for k, v := range row {
cols = append(cols, fmt.Sprintf("`%s`", k))
placeholders = append(placeholders, "?")
args = append(args, v)
args = append(args, normalizeMySQLDateTimeValue(v))
}
if len(cols) == 0 {
continue
}
query := fmt.Sprintf("INSERT INTO `%s` (%s) VALUES (%s)", tableName, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
if _, err := tx.Exec(query, args...); err != nil {
res, err := tx.Exec(query, args...)
if err != nil {
return fmt.Errorf("insert error: %v", err)
}
if affected, err := res.RowsAffected(); err == nil && affected == 0 {
return fmt.Errorf("插入未生效:未影响任何行")
}
}
return tx.Commit()
}
func normalizeMySQLDateTimeValue(value interface{}) interface{} {
text, ok := value.(string)
if !ok {
return value
}
raw := strings.TrimSpace(text)
if raw == "" {
return value
}
cleaned := strings.ReplaceAll(raw, "+ ", "+")
cleaned = strings.ReplaceAll(cleaned, "- ", "-")
if len(cleaned) >= 19 && cleaned[10] == 'T' {
if strings.HasSuffix(cleaned, "Z") || hasTimezoneOffset(cleaned) {
if t, err := time.Parse(time.RFC3339Nano, cleaned); err == nil {
return formatMySQLDateTime(t)
}
if t, err := time.Parse(time.RFC3339, cleaned); err == nil {
return formatMySQLDateTime(t)
}
}
return strings.Replace(cleaned, "T", " ", 1)
}
if strings.Contains(cleaned, " ") && (strings.HasSuffix(cleaned, "Z") || hasTimezoneOffset(cleaned)) {
candidate := strings.Replace(cleaned, " ", "T", 1)
if t, err := time.Parse(time.RFC3339Nano, candidate); err == nil {
return formatMySQLDateTime(t)
}
if t, err := time.Parse(time.RFC3339, candidate); err == nil {
return formatMySQLDateTime(t)
}
}
return value
}
func hasTimezoneOffset(text string) bool {
pos := strings.LastIndexAny(text, "+-")
if pos < 0 || pos < 10 || pos+1 >= len(text) {
return false
}
offset := text[pos+1:]
if len(offset) == 5 && offset[2] == ':' {
return isAllDigits(offset[:2]) && isAllDigits(offset[3:])
}
if len(offset) == 4 {
return isAllDigits(offset)
}
return false
}
func isAllDigits(text string) bool {
if text == "" {
return false
}
for _, r := range text {
if r < '0' || r > '9' {
return false
}
}
return true
}
func formatMySQLDateTime(t time.Time) string {
base := t.Format("2006-01-02 15:04:05")
nanos := t.Nanosecond()
if nanos == 0 {
return base
}
micro := nanos / 1000
return fmt.Sprintf("%s.%06d", base, micro)
}
func (m *MySQLDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
query := fmt.Sprintf("SELECT TABLE_NAME, COLUMN_NAME, COLUMN_TYPE FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = '%s'", dbName)
if dbName == "" {

View File

@@ -1,6 +1,7 @@
package db
import (
"context"
"database/sql"
"fmt"
"net"
@@ -10,6 +11,7 @@ import (
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
@@ -19,6 +21,7 @@ import (
type OracleDB struct {
conn *sql.DB
pingTimeout time.Duration
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
}
func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
@@ -28,28 +31,6 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
database = config.User // Default to user service/schema if empty?
}
if config.UseSSH {
_, err := ssh.RegisterSSHNetwork(config.SSH)
if err == nil {
// Oracle driver might not support custom dialer via DSN easily without extra config
// But go-ora v2 supports some advanced options.
// For simplicity, we assume standard TCP or we might need a workaround for SSH.
// go-ora v2 is pure Go, so we can potentially use a custom dialer if we manually open.
// But for now, let's just use the address.
// SSH tunneling via net.Dialer override is complex in sql.Open("oracle", ...).
// We might need to forward a local port if using SSH.
// Since ssh.RegisterSSHNetwork creates a custom network "ssh-via-...",
// we need to see if go-ora supports custom networks.
// Checking go-ora docs (simulated): It supports "unix" and "tcp".
// We might need to map the custom network to a local proxy.
// For now, we will assume direct connection or handle SSH separately later.
// We'll leave the protocol implementation as is in MySQL for now, hoping go-ora uses standard net.Dial.
// Note: go-ora connection string: oracle://user:pass@host:port/service
// It parses host/port. It doesn't easily take a custom "network" parameter in URL.
// We will proceed with standard TCP string.
}
}
u := &url.URL{
Scheme: "oracle",
Host: net.JoinHostPort(config.Host, strconv.Itoa(config.Port)),
@@ -61,7 +42,42 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
}
func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
dsn := o.getDSN(config)
var dsn string
var err error
if config.UseSSH {
// Create SSH tunnel with local port forwarding
logger.Infof("Oracle 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
if err != nil {
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
}
o.forwarder = forwarder
// Parse local address
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
if err != nil {
return fmt.Errorf("解析本地转发地址失败:%w", err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("解析本地端口失败:%w", err)
}
// Create a modified config pointing to local forwarder
localConfig := config
localConfig.Host = host
localConfig.Port = port
localConfig.UseSSH = false
dsn = o.getDSN(localConfig)
logger.Infof("Oracle 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = o.getDSN(config)
}
db, err := sql.Open("oracle", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
@@ -75,6 +91,15 @@ func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
}
func (o *OracleDB) Close() error {
// Close SSH forwarder first if exists
if o.forwarder != nil {
if err := o.forwarder.Close(); err != nil {
logger.Warnf("关闭 Oracle SSH 端口转发失败:%v", err)
}
o.forwarder = nil
}
// Then close database connection
if o.conn != nil {
return o.conn.Close()
}
@@ -94,6 +119,20 @@ func (o *OracleDB) Ping() error {
return o.conn.PingContext(ctx)
}
func (o *OracleDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if o.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
}
rows, err := o.conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (o *OracleDB) Query(query string) ([]map[string]interface{}, []string, error) {
if o.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
@@ -104,41 +143,18 @@ func (o *OracleDB) Query(query string) ([]map[string]interface{}, []string, erro
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
columns, err := rows.Columns()
func (o *OracleDB) ExecContext(ctx context.Context, query string) (int64, error) {
if o.conn == nil {
return 0, fmt.Errorf("connection not open")
}
res, err := o.conn.ExecContext(ctx, query)
if err != nil {
return nil, nil, err
return 0, err
}
var resultData []map[string]interface{}
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
entry := make(map[string]interface{})
for i, col := range columns {
var v interface{}
val := values[i]
b, ok := val.([]byte)
if ok {
v = string(b)
} else {
v = val
}
entry[col] = v
}
resultData = append(resultData, entry)
}
return resultData, columns, nil
return res.RowsAffected()
}
func (o *OracleDB) Exec(query string) (int64, error) {
@@ -171,7 +187,7 @@ func (o *OracleDB) GetTables(dbName string) ([]string, error) {
// dbName is Schema/Owner
query := "SELECT table_name FROM user_tables"
if dbName != "" {
query = fmt.Sprintf("SELECT table_name FROM all_tables WHERE owner = '%s'", strings.ToUpper(dbName))
query = fmt.Sprintf("SELECT owner, table_name FROM all_tables WHERE owner = '%s' ORDER BY table_name", strings.ToUpper(dbName))
}
data, _, err := o.Query(query)
@@ -181,6 +197,14 @@ func (o *OracleDB) GetTables(dbName string) ([]string, error) {
var tables []string
for _, row := range data {
if dbName != "" {
if owner, okOwner := row["OWNER"]; okOwner {
if name, okName := row["TABLE_NAME"]; okName {
tables = append(tables, fmt.Sprintf("%v.%v", owner, name))
continue
}
}
}
if val, ok := row["TABLE_NAME"]; ok {
tables = append(tables, fmt.Sprintf("%v", val))
}
@@ -339,8 +363,117 @@ func (o *OracleDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDe
}
func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
// TODO: Implement batch application for Oracle using correct syntax
return fmt.Errorf("read-only mode implemented for Oracle so far")
if o.conn == nil {
return fmt.Errorf("connection not open")
}
tx, err := o.conn.Begin()
if err != nil {
return err
}
defer tx.Rollback()
quoteIdent := func(name string) string {
n := strings.TrimSpace(name)
n = strings.Trim(n, "\"")
n = strings.ReplaceAll(n, "\"", "\"\"")
if n == "" {
return "\"\""
}
return `"` + n + `"`
}
schema := ""
table := strings.TrimSpace(tableName)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
schema = strings.TrimSpace(parts[0])
table = strings.TrimSpace(parts[1])
}
qualifiedTable := ""
if schema != "" {
qualifiedTable = fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table))
} else {
qualifiedTable = quoteIdent(table)
}
// 1. Deletes
for _, pk := range changes.Deletes {
var wheres []string
var args []interface{}
idx := 0
for k, v := range pk {
idx++
wheres = append(wheres, fmt.Sprintf("%s = :%d", quoteIdent(k), idx))
args = append(args, v)
}
if len(wheres) == 0 {
continue
}
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("delete error: %v", err)
}
}
// 2. Updates
for _, update := range changes.Updates {
var sets []string
var args []interface{}
idx := 0
for k, v := range update.Values {
idx++
sets = append(sets, fmt.Sprintf("%s = :%d", quoteIdent(k), idx))
args = append(args, v)
}
if len(sets) == 0 {
continue
}
var wheres []string
for k, v := range update.Keys {
idx++
wheres = append(wheres, fmt.Sprintf("%s = :%d", quoteIdent(k), idx))
args = append(args, v)
}
if len(wheres) == 0 {
return fmt.Errorf("update requires keys")
}
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("update error: %v", err)
}
}
// 3. Inserts
for _, row := range changes.Inserts {
var cols []string
var placeholders []string
var args []interface{}
idx := 0
for k, v := range row {
idx++
cols = append(cols, quoteIdent(k))
placeholders = append(placeholders, fmt.Sprintf(":%d", idx))
args = append(args, v)
}
if len(cols) == 0 {
continue
}
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("insert error: %v", err)
}
}
return tx.Commit()
}
func (o *OracleDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {

View File

@@ -1,24 +1,31 @@
package db
import (
"context"
"database/sql"
"fmt"
"net"
"net/url"
"strconv"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
_ "github.com/lib/pq"
)
type PostgresDB struct {
conn *sql.DB
pingTimeout time.Duration
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
}
func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
// postgres://user:password@host:port/dbname?sslmode=disable
dbname := config.Database
@@ -41,14 +48,49 @@ func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
}
func (p *PostgresDB) Connect(config connection.ConnectionConfig) error {
dsn := p.getDSN(config)
var dsn string
var err error
if config.UseSSH {
// Create SSH tunnel with local port forwarding
logger.Infof("PostgreSQL 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
if err != nil {
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
}
p.forwarder = forwarder
// Parse local address
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
if err != nil {
return fmt.Errorf("解析本地转发地址失败:%w", err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("解析本地端口失败:%w", err)
}
// Create a modified config pointing to local forwarder
localConfig := config
localConfig.Host = host
localConfig.Port = port
localConfig.UseSSH = false // Disable SSH flag for DSN generation
dsn = p.getDSN(localConfig)
logger.Infof("PostgreSQL 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = p.getDSN(config)
}
db, err := sql.Open("postgres", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
}
p.conn = db
p.pingTimeout = getConnectTimeout(config)
// Force verification
if err := p.Ping(); err != nil {
return fmt.Errorf("连接建立后验证失败:%w", err)
@@ -56,7 +98,17 @@ func (p *PostgresDB) Connect(config connection.ConnectionConfig) error {
return nil
}
func (p *PostgresDB) Close() error {
// Close SSH forwarder first if exists
if p.forwarder != nil {
if err := p.forwarder.Close(); err != nil {
logger.Warnf("关闭 PostgreSQL SSH 端口转发失败:%v", err)
}
p.forwarder = nil
}
// Then close database connection
if p.conn != nil {
return p.conn.Close()
}
@@ -76,52 +128,42 @@ func (p *PostgresDB) Ping() error {
return p.conn.PingContext(ctx)
}
func (p *PostgresDB) Query(query string) ([]map[string]interface{}, []string, error) {
func (p *PostgresDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if p.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
}
rows, err := p.conn.Query(query)
rows, err := p.conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
columns, err := rows.Columns()
return scanRows(rows)
}
func (p *PostgresDB) Query(query string) ([]map[string]interface{}, []string, error) {
if p.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
}
rows, err := p.conn.Query(query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
var resultData []map[string]interface{}
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
entry := make(map[string]interface{})
for i, col := range columns {
var v interface{}
val := values[i]
b, ok := val.([]byte)
if ok {
v = string(b)
} else {
v = val
}
entry[col] = v
}
resultData = append(resultData, entry)
func (p *PostgresDB) ExecContext(ctx context.Context, query string) (int64, error) {
if p.conn == nil {
return 0, fmt.Errorf("connection not open")
}
return resultData, columns, nil
res, err := p.conn.ExecContext(ctx, query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (p *PostgresDB) Exec(query string) (int64, error) {
@@ -150,16 +192,22 @@ func (p *PostgresDB) GetDatabases() ([]string, error) {
}
func (p *PostgresDB) GetTables(dbName string) ([]string, error) {
query := "SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
query := "SELECT schemaname, tablename FROM pg_catalog.pg_tables WHERE schemaname != 'information_schema' AND schemaname NOT LIKE 'pg_%' ORDER BY schemaname, tablename"
data, _, err := p.Query(query)
if err != nil {
return nil, err
}
var tables []string
for _, row := range data {
if val, ok := row["tablename"]; ok {
tables = append(tables, fmt.Sprintf("%v", val))
schema, okSchema := row["schemaname"]
name, okName := row["tablename"]
if okSchema && okName {
tables = append(tables, fmt.Sprintf("%v.%v", schema, name))
continue
}
if okName {
tables = append(tables, fmt.Sprintf("%v", name))
}
}
return tables, nil
@@ -170,21 +218,420 @@ func (p *PostgresDB) GetCreateStatement(dbName, tableName string) (string, error
}
func (p *PostgresDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
return []connection.ColumnDefinition{}, nil
schema := strings.TrimSpace(dbName)
if schema == "" {
schema = "public"
}
table := strings.TrimSpace(tableName)
if table == "" {
return nil, fmt.Errorf("table name required")
}
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
query := fmt.Sprintf(`
SELECT
a.attname AS column_name,
pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
col_description(a.attrelid, a.attnum) AS comment,
CASE WHEN pk.attname IS NOT NULL THEN 'PRI' ELSE '' END AS column_key
FROM pg_class c
JOIN pg_namespace n ON n.oid = c.relnamespace
JOIN pg_attribute a ON a.attrelid = c.oid
LEFT JOIN pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum
LEFT JOIN (
SELECT i.indrelid, a3.attname
FROM pg_index i
JOIN pg_attribute a3 ON a3.attrelid = i.indrelid AND a3.attnum = ANY(i.indkey)
WHERE i.indisprimary
) pk ON pk.indrelid = c.oid AND pk.attname = a.attname
WHERE c.relkind IN ('r', 'p')
AND n.nspname = '%s'
AND c.relname = '%s'
AND a.attnum > 0
AND NOT a.attisdropped
ORDER BY a.attnum`, esc(schema), esc(table))
data, _, err := p.Query(query)
if err != nil {
return nil, err
}
var columns []connection.ColumnDefinition
for _, row := range data {
col := connection.ColumnDefinition{
Name: fmt.Sprintf("%v", row["column_name"]),
Type: fmt.Sprintf("%v", row["data_type"]),
Nullable: fmt.Sprintf("%v", row["is_nullable"]),
Key: fmt.Sprintf("%v", row["column_key"]),
Extra: "",
Comment: "",
}
if v, ok := row["comment"]; ok && v != nil {
col.Comment = fmt.Sprintf("%v", v)
}
if v, ok := row["column_default"]; ok && v != nil {
def := fmt.Sprintf("%v", v)
col.Default = &def
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(def)), "nextval(") {
col.Extra = "auto_increment"
}
}
columns = append(columns, col)
}
return columns, nil
}
func (p *PostgresDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
return []connection.IndexDefinition{}, nil
schema := strings.TrimSpace(dbName)
if schema == "" {
schema = "public"
}
table := strings.TrimSpace(tableName)
if table == "" {
return nil, fmt.Errorf("table name required")
}
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
query := fmt.Sprintf(`
SELECT
i.relname AS index_name,
a.attname AS column_name,
ix.indisunique AS is_unique,
x.ordinality AS seq_in_index,
am.amname AS index_type
FROM pg_class t
JOIN pg_namespace n ON n.oid = t.relnamespace
JOIN pg_index ix ON t.oid = ix.indrelid
JOIN pg_class i ON i.oid = ix.indexrelid
JOIN pg_am am ON i.relam = am.oid
JOIN unnest(ix.indkey) WITH ORDINALITY AS x(attnum, ordinality) ON TRUE
JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = x.attnum
WHERE t.relkind IN ('r', 'p')
AND t.relname = '%s'
AND n.nspname = '%s'
ORDER BY i.relname, x.ordinality`, esc(table), esc(schema))
data, _, err := p.Query(query)
if err != nil {
return nil, err
}
parseBool := func(v interface{}) bool {
switch val := v.(type) {
case bool:
return val
case string:
s := strings.ToLower(strings.TrimSpace(val))
return s == "t" || s == "true" || s == "1" || s == "y" || s == "yes"
default:
s := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
return s == "t" || s == "true" || s == "1" || s == "y" || s == "yes"
}
}
parseInt := func(v interface{}) int {
switch val := v.(type) {
case int:
return val
case int64:
return int(val)
case float64:
return int(val)
case string:
// best effort
var n int
_, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n)
return n
default:
var n int
_, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n)
return n
}
}
var indexes []connection.IndexDefinition
for _, row := range data {
isUnique := false
if v, ok := row["is_unique"]; ok && v != nil {
isUnique = parseBool(v)
}
nonUnique := 1
if isUnique {
nonUnique = 0
}
seq := 0
if v, ok := row["seq_in_index"]; ok && v != nil {
seq = parseInt(v)
}
indexType := ""
if v, ok := row["index_type"]; ok && v != nil {
indexType = strings.ToUpper(fmt.Sprintf("%v", v))
}
if indexType == "" {
indexType = "BTREE"
}
idx := connection.IndexDefinition{
Name: fmt.Sprintf("%v", row["index_name"]),
ColumnName: fmt.Sprintf("%v", row["column_name"]),
NonUnique: nonUnique,
SeqInIndex: seq,
IndexType: indexType,
}
indexes = append(indexes, idx)
}
return indexes, nil
}
func (p *PostgresDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
return []connection.ForeignKeyDefinition{}, nil
schema := strings.TrimSpace(dbName)
if schema == "" {
schema = "public"
}
table := strings.TrimSpace(tableName)
if table == "" {
return nil, fmt.Errorf("table name required")
}
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
query := fmt.Sprintf(`
SELECT
tc.constraint_name AS constraint_name,
kcu.column_name AS column_name,
ccu.table_schema AS foreign_table_schema,
ccu.table_name AS foreign_table_name,
ccu.column_name AS foreign_column_name
FROM information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
JOIN information_schema.constraint_column_usage AS ccu
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY'
AND tc.table_name = '%s'
AND tc.table_schema = '%s'
ORDER BY tc.constraint_name, kcu.ordinal_position`, esc(table), esc(schema))
data, _, err := p.Query(query)
if err != nil {
return nil, err
}
var fks []connection.ForeignKeyDefinition
for _, row := range data {
refSchema := ""
if v, ok := row["foreign_table_schema"]; ok && v != nil {
refSchema = fmt.Sprintf("%v", v)
}
refTable := fmt.Sprintf("%v", row["foreign_table_name"])
refTableName := refTable
if strings.TrimSpace(refSchema) != "" {
refTableName = fmt.Sprintf("%s.%s", refSchema, refTable)
}
fk := connection.ForeignKeyDefinition{
Name: fmt.Sprintf("%v", row["constraint_name"]),
ColumnName: fmt.Sprintf("%v", row["column_name"]),
RefTableName: refTableName,
RefColumnName: fmt.Sprintf("%v", row["foreign_column_name"]),
ConstraintName: fmt.Sprintf("%v", row["constraint_name"]),
}
fks = append(fks, fk)
}
return fks, nil
}
func (p *PostgresDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
return []connection.TriggerDefinition{}, nil
schema := strings.TrimSpace(dbName)
if schema == "" {
schema = "public"
}
table := strings.TrimSpace(tableName)
if table == "" {
return nil, fmt.Errorf("table name required")
}
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
query := fmt.Sprintf(`
SELECT trigger_name, action_timing, event_manipulation, action_statement
FROM information_schema.triggers
WHERE event_object_table = '%s'
AND event_object_schema = '%s'
ORDER BY trigger_name, event_manipulation`, esc(table), esc(schema))
data, _, err := p.Query(query)
if err != nil {
return nil, err
}
var triggers []connection.TriggerDefinition
for _, row := range data {
trig := connection.TriggerDefinition{
Name: fmt.Sprintf("%v", row["trigger_name"]),
Timing: fmt.Sprintf("%v", row["action_timing"]),
Event: fmt.Sprintf("%v", row["event_manipulation"]),
Statement: fmt.Sprintf("%v", row["action_statement"]),
}
triggers = append(triggers, trig)
}
return triggers, nil
}
func (p *PostgresDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
return []connection.ColumnDefinitionWithTable{}, nil
query := `
SELECT table_schema, table_name, column_name, data_type
FROM information_schema.columns
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
AND table_schema NOT LIKE 'pg_%'
ORDER BY table_schema, table_name, ordinal_position`
data, _, err := p.Query(query)
if err != nil {
return nil, err
}
var cols []connection.ColumnDefinitionWithTable
for _, row := range data {
schema := fmt.Sprintf("%v", row["table_schema"])
table := fmt.Sprintf("%v", row["table_name"])
tableName := table
if strings.TrimSpace(schema) != "" {
tableName = fmt.Sprintf("%s.%s", schema, table)
}
col := connection.ColumnDefinitionWithTable{
TableName: tableName,
Name: fmt.Sprintf("%v", row["column_name"]),
Type: fmt.Sprintf("%v", row["data_type"]),
}
cols = append(cols, col)
}
return cols, nil
}
func (p *PostgresDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
if p.conn == nil {
return fmt.Errorf("connection not open")
}
tx, err := p.conn.Begin()
if err != nil {
return err
}
defer tx.Rollback()
quoteIdent := func(name string) string {
n := strings.TrimSpace(name)
n = strings.Trim(n, "\"")
n = strings.ReplaceAll(n, "\"", "\"\"")
if n == "" {
return "\"\""
}
return `"` + n + `"`
}
schema := ""
table := strings.TrimSpace(tableName)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
schema = strings.TrimSpace(parts[0])
table = strings.TrimSpace(parts[1])
}
qualifiedTable := ""
if schema != "" {
qualifiedTable = fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table))
} else {
qualifiedTable = quoteIdent(table)
}
// 1. Deletes
for _, pk := range changes.Deletes {
var wheres []string
var args []interface{}
idx := 0
for k, v := range pk {
idx++
wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteIdent(k), idx))
args = append(args, v)
}
if len(wheres) == 0 {
continue
}
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("delete error: %v", err)
}
}
// 2. Updates
for _, update := range changes.Updates {
var sets []string
var args []interface{}
idx := 0
for k, v := range update.Values {
idx++
sets = append(sets, fmt.Sprintf("%s = $%d", quoteIdent(k), idx))
args = append(args, v)
}
if len(sets) == 0 {
continue
}
var wheres []string
for k, v := range update.Keys {
idx++
wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteIdent(k), idx))
args = append(args, v)
}
if len(wheres) == 0 {
return fmt.Errorf("update requires keys")
}
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("update error: %v", err)
}
}
// 3. Inserts
for _, row := range changes.Inserts {
var cols []string
var placeholders []string
var args []interface{}
idx := 0
for k, v := range row {
idx++
cols = append(cols, quoteIdent(k))
placeholders = append(placeholders, fmt.Sprintf("$%d", idx))
args = append(args, v)
}
if len(cols) == 0 {
continue
}
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("insert error: %v", err)
}
}
return tx.Commit()
}

114
internal/db/query_value.go Normal file
View File

@@ -0,0 +1,114 @@
package db
import (
"encoding/hex"
"fmt"
"strings"
"unicode"
"unicode/utf8"
)
// normalizeQueryValue normalizes driver-returned values for UI/JSON transport.
// 当前主要处理 []byte如果是可读文本则转为 string否则转为十六进制字符串避免前端出现“空白值”。
func normalizeQueryValue(v interface{}) interface{} {
return normalizeQueryValueWithDBType(v, "")
}
func normalizeQueryValueWithDBType(v interface{}, databaseTypeName string) interface{} {
if b, ok := v.([]byte); ok {
return bytesToDisplayValue(b, databaseTypeName)
}
return v
}
func bytesToDisplayValue(b []byte, databaseTypeName string) interface{} {
if b == nil {
return nil
}
if len(b) == 0 {
return ""
}
dbType := strings.ToUpper(strings.TrimSpace(databaseTypeName))
if isBitLikeDBType(dbType) {
if u, ok := bytesToUint64(b); ok {
// JS number precision is limited; keep large bitmasks as string.
const maxSafeInteger = 9007199254740991 // 2^53 - 1
if u <= maxSafeInteger {
return int64(u)
}
return fmt.Sprintf("%d", u)
}
}
if utf8.Valid(b) {
s := string(b)
if isMostlyPrintable(s) {
return s
}
}
// Fallback: some drivers return BIT(1) as []byte{0} / []byte{1} without type info.
if dbType == "" && len(b) == 1 && (b[0] == 0 || b[0] == 1) {
return int64(b[0])
}
return bytesToReadableString(b)
}
func bytesToReadableString(b []byte) interface{} {
if b == nil {
return nil
}
if len(b) == 0 {
return ""
}
return "0x" + hex.EncodeToString(b)
}
func isBitLikeDBType(typeName string) bool {
if typeName == "" {
return false
}
switch typeName {
case "BIT", "VARBIT":
return true
default:
}
return strings.HasPrefix(typeName, "BIT")
}
func bytesToUint64(b []byte) (uint64, bool) {
if len(b) == 0 || len(b) > 8 {
return 0, false
}
var u uint64
for _, v := range b {
u = (u << 8) | uint64(v)
}
return u, true
}
func isMostlyPrintable(s string) bool {
if s == "" {
return true
}
total := 0
printable := 0
for _, r := range s {
total++
switch r {
case '\n', '\r', '\t':
printable++
continue
default:
}
if unicode.IsPrint(r) {
printable++
}
}
// 允许少量不可见字符,避免把正常文本误判为二进制。
return printable*100 >= total*90
}

View File

@@ -0,0 +1,44 @@
package db
import "testing"
func TestNormalizeQueryValueWithDBType_BitBytes(t *testing.T) {
v := normalizeQueryValueWithDBType([]byte{0x00}, "BIT")
if v != int64(0) {
t.Fatalf("BIT 0x00 期望为 0实际=%v(%T)", v, v)
}
v = normalizeQueryValueWithDBType([]byte{0x01}, "bit")
if v != int64(1) {
t.Fatalf("BIT 0x01 期望为 1实际=%v(%T)", v, v)
}
v = normalizeQueryValueWithDBType([]byte{0x01, 0x02}, "BIT VARYING")
if v != int64(258) {
t.Fatalf("BIT 0x0102 期望为 258实际=%v(%T)", v, v)
}
}
func TestNormalizeQueryValueWithDBType_BitLargeAsString(t *testing.T) {
v := normalizeQueryValueWithDBType([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, "BIT")
if s, ok := v.(string); !ok || s != "18446744073709551615" {
t.Fatalf("BIT 0xffffffffffffffff 期望为 string(18446744073709551615),实际=%v(%T)", v, v)
}
}
func TestNormalizeQueryValueWithDBType_ByteFallbacks(t *testing.T) {
v := normalizeQueryValueWithDBType([]byte("abc"), "")
if v != "abc" {
t.Fatalf("文本 []byte 期望返回 string实际=%v(%T)", v, v)
}
v = normalizeQueryValueWithDBType([]byte{0x00}, "")
if v != int64(0) {
t.Fatalf("未知类型 0x00 期望返回 0实际=%v(%T)", v, v)
}
v = normalizeQueryValueWithDBType([]byte{0xff}, "")
if v != "0xff" {
t.Fatalf("未知类型 0xff 期望返回 0xff实际=%v(%T)", v, v)
}
}

46
internal/db/scan_rows.go Normal file
View File

@@ -0,0 +1,46 @@
package db
import (
"database/sql"
)
func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) {
columns, err := rows.Columns()
if err != nil {
return nil, nil, err
}
colTypes, err := rows.ColumnTypes()
if err != nil || len(colTypes) != len(columns) {
colTypes = nil
}
resultData := make([]map[string]interface{}, 0)
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
entry := make(map[string]interface{}, len(columns))
for i, col := range columns {
dbTypeName := ""
if colTypes != nil && i < len(colTypes) && colTypes[i] != nil {
dbTypeName = colTypes[i].DatabaseTypeName()
}
entry[col] = normalizeQueryValueWithDBType(values[i], dbTypeName)
}
resultData = append(resultData, entry)
}
if err := rows.Err(); err != nil {
return resultData, columns, err
}
return resultData, columns, nil
}

View File

@@ -1,8 +1,10 @@
package db
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"GoNavi-Wails/internal/connection"
@@ -17,14 +19,14 @@ type SQLiteDB struct {
}
func (s *SQLiteDB) Connect(config connection.ConnectionConfig) error {
dsn := config.Host
dsn := config.Host
db, err := sql.Open("sqlite", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
}
s.conn = db
s.pingTimeout = getConnectTimeout(config)
// Force verification
if err := s.Ping(); err != nil {
return fmt.Errorf("连接建立后验证失败:%w", err)
@@ -52,6 +54,20 @@ func (s *SQLiteDB) Ping() error {
return s.conn.PingContext(ctx)
}
func (s *SQLiteDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if s.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
}
rows, err := s.conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (s *SQLiteDB) Query(query string) ([]map[string]interface{}, []string, error) {
if s.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
@@ -62,41 +78,18 @@ func (s *SQLiteDB) Query(query string) ([]map[string]interface{}, []string, erro
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
columns, err := rows.Columns()
func (s *SQLiteDB) ExecContext(ctx context.Context, query string) (int64, error) {
if s.conn == nil {
return 0, fmt.Errorf("connection not open")
}
res, err := s.conn.ExecContext(ctx, query)
if err != nil {
return nil, nil, err
return 0, err
}
var resultData []map[string]interface{}
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
entry := make(map[string]interface{})
for i, col := range columns {
var v interface{}
val := values[i]
b, ok := val.([]byte)
if ok {
v = string(b)
} else {
v = val
}
entry[col] = v
}
resultData = append(resultData, entry)
}
return resultData, columns, nil
return res.RowsAffected()
}
func (s *SQLiteDB) Exec(query string) (int64, error) {
@@ -120,7 +113,7 @@ func (s *SQLiteDB) GetTables(dbName string) ([]string, error) {
if err != nil {
return nil, err
}
var tables []string
for _, row := range data {
if val, ok := row["name"]; ok {
@@ -145,21 +138,443 @@ func (s *SQLiteDB) GetCreateStatement(dbName, tableName string) (string, error)
}
func (s *SQLiteDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
return []connection.ColumnDefinition{}, nil
table := strings.TrimSpace(tableName)
if table == "" {
return nil, fmt.Errorf("table name required")
}
esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") }
// cid, name, type, notnull, dflt_value, pk
data, _, err := s.Query(fmt.Sprintf("PRAGMA table_info('%s')", esc(table)))
if err != nil {
return nil, err
}
parseInt := func(v interface{}) int {
switch val := v.(type) {
case int:
return val
case int64:
return int(val)
case float64:
return int(val)
case string:
var n int
_, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n)
return n
default:
var n int
_, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n)
return n
}
}
getStr := func(row map[string]interface{}, key string) string {
if v, ok := row[key]; ok && v != nil {
return fmt.Sprintf("%v", v)
}
if v, ok := row[strings.ToUpper(key)]; ok && v != nil {
return fmt.Sprintf("%v", v)
}
return ""
}
var columns []connection.ColumnDefinition
for _, row := range data {
notnull := 0
if v, ok := row["notnull"]; ok && v != nil {
notnull = parseInt(v)
} else if v, ok := row["NOTNULL"]; ok && v != nil {
notnull = parseInt(v)
}
pk := 0
if v, ok := row["pk"]; ok && v != nil {
pk = parseInt(v)
} else if v, ok := row["PK"]; ok && v != nil {
pk = parseInt(v)
}
nullable := "YES"
if notnull == 1 {
nullable = "NO"
}
key := ""
if pk == 1 {
key = "PRI"
}
col := connection.ColumnDefinition{
Name: getStr(row, "name"),
Type: getStr(row, "type"),
Nullable: nullable,
Key: key,
Extra: "",
Comment: "",
}
if v, ok := row["dflt_value"]; ok && v != nil {
def := fmt.Sprintf("%v", v)
col.Default = &def
} else if v, ok := row["DFLT_VALUE"]; ok && v != nil {
def := fmt.Sprintf("%v", v)
col.Default = &def
}
columns = append(columns, col)
}
return columns, nil
}
func (s *SQLiteDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
return []connection.IndexDefinition{}, nil
table := strings.TrimSpace(tableName)
if table == "" {
return nil, fmt.Errorf("table name required")
}
esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") }
parseInt := func(v interface{}) int {
switch val := v.(type) {
case int:
return val
case int64:
return int(val)
case float64:
return int(val)
case string:
var n int
_, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n)
return n
default:
var n int
_, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n)
return n
}
}
data, _, err := s.Query(fmt.Sprintf("PRAGMA index_list('%s')", esc(table)))
if err != nil {
return nil, err
}
var indexes []connection.IndexDefinition
for _, row := range data {
indexName := ""
if v, ok := row["name"]; ok && v != nil {
indexName = fmt.Sprintf("%v", v)
} else if v, ok := row["NAME"]; ok && v != nil {
indexName = fmt.Sprintf("%v", v)
}
if strings.TrimSpace(indexName) == "" {
continue
}
unique := 0
if v, ok := row["unique"]; ok && v != nil {
unique = parseInt(v)
} else if v, ok := row["UNIQUE"]; ok && v != nil {
unique = parseInt(v)
}
nonUnique := 1
if unique == 1 {
nonUnique = 0
}
cols, _, err := s.Query(fmt.Sprintf("PRAGMA index_info('%s')", esc(indexName)))
if err != nil {
// skip broken index
continue
}
for _, c := range cols {
colName := ""
if v, ok := c["name"]; ok && v != nil {
colName = fmt.Sprintf("%v", v)
} else if v, ok := c["NAME"]; ok && v != nil {
colName = fmt.Sprintf("%v", v)
}
if strings.TrimSpace(colName) == "" {
continue
}
seq := 0
if v, ok := c["seqno"]; ok && v != nil {
seq = parseInt(v) + 1
} else if v, ok := c["SEQNO"]; ok && v != nil {
seq = parseInt(v) + 1
}
indexes = append(indexes, connection.IndexDefinition{
Name: indexName,
ColumnName: colName,
NonUnique: nonUnique,
SeqInIndex: seq,
IndexType: "BTREE",
})
}
}
return indexes, nil
}
func (s *SQLiteDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
return []connection.ForeignKeyDefinition{}, nil
table := strings.TrimSpace(tableName)
if table == "" {
return nil, fmt.Errorf("table name required")
}
esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") }
data, _, err := s.Query(fmt.Sprintf("PRAGMA foreign_key_list('%s')", esc(table)))
if err != nil {
return nil, err
}
parseInt := func(v interface{}) int {
switch val := v.(type) {
case int:
return val
case int64:
return int(val)
case float64:
return int(val)
case string:
var n int
_, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n)
return n
default:
var n int
_, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n)
return n
}
}
var fks []connection.ForeignKeyDefinition
for _, row := range data {
id := 0
if v, ok := row["id"]; ok && v != nil {
id = parseInt(v)
} else if v, ok := row["ID"]; ok && v != nil {
id = parseInt(v)
}
refTable := ""
if v, ok := row["table"]; ok && v != nil {
refTable = fmt.Sprintf("%v", v)
} else if v, ok := row["TABLE"]; ok && v != nil {
refTable = fmt.Sprintf("%v", v)
}
fromCol := ""
if v, ok := row["from"]; ok && v != nil {
fromCol = fmt.Sprintf("%v", v)
} else if v, ok := row["FROM"]; ok && v != nil {
fromCol = fmt.Sprintf("%v", v)
}
toCol := ""
if v, ok := row["to"]; ok && v != nil {
toCol = fmt.Sprintf("%v", v)
} else if v, ok := row["TO"]; ok && v != nil {
toCol = fmt.Sprintf("%v", v)
}
name := fmt.Sprintf("fk_%s_%d", table, id)
fks = append(fks, connection.ForeignKeyDefinition{
Name: name,
ColumnName: fromCol,
RefTableName: refTable,
RefColumnName: toCol,
ConstraintName: name,
})
}
return fks, nil
}
func (s *SQLiteDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
return []connection.TriggerDefinition{}, nil
table := strings.TrimSpace(tableName)
if table == "" {
return nil, fmt.Errorf("table name required")
}
esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") }
data, _, err := s.Query(fmt.Sprintf("SELECT name AS trigger_name, sql AS statement FROM sqlite_master WHERE type='trigger' AND tbl_name='%s' ORDER BY name", esc(table)))
if err != nil {
return nil, err
}
var triggers []connection.TriggerDefinition
for _, row := range data {
name := fmt.Sprintf("%v", row["trigger_name"])
stmt := ""
if v, ok := row["statement"]; ok && v != nil {
stmt = fmt.Sprintf("%v", v)
}
upper := strings.ToUpper(stmt)
timing := ""
switch {
case strings.Contains(upper, " BEFORE "):
timing = "BEFORE"
case strings.Contains(upper, " AFTER "):
timing = "AFTER"
case strings.Contains(upper, " INSTEAD OF "):
timing = "INSTEAD OF"
}
event := ""
switch {
case strings.Contains(upper, " INSERT "):
event = "INSERT"
case strings.Contains(upper, " UPDATE "):
event = "UPDATE"
case strings.Contains(upper, " DELETE "):
event = "DELETE"
}
triggers = append(triggers, connection.TriggerDefinition{
Name: name,
Timing: timing,
Event: event,
Statement: stmt,
})
}
return triggers, nil
}
func (s *SQLiteDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
if s.conn == nil {
return fmt.Errorf("connection not open")
}
tx, err := s.conn.Begin()
if err != nil {
return err
}
defer tx.Rollback()
quoteIdent := func(name string) string {
n := strings.TrimSpace(name)
n = strings.Trim(n, "\"")
n = strings.ReplaceAll(n, "\"", "\"\"")
if n == "" {
return "\"\""
}
return `"` + n + `"`
}
schema := ""
table := strings.TrimSpace(tableName)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
schema = strings.TrimSpace(parts[0])
table = strings.TrimSpace(parts[1])
}
qualifiedTable := ""
if schema != "" {
qualifiedTable = fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table))
} else {
qualifiedTable = quoteIdent(table)
}
// 1. Deletes
for _, pk := range changes.Deletes {
var wheres []string
var args []interface{}
for k, v := range pk {
wheres = append(wheres, fmt.Sprintf("%s = ?", quoteIdent(k)))
args = append(args, v)
}
if len(wheres) == 0 {
continue
}
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("delete error: %v", err)
}
}
// 2. Updates
for _, update := range changes.Updates {
var sets []string
var args []interface{}
for k, v := range update.Values {
sets = append(sets, fmt.Sprintf("%s = ?", quoteIdent(k)))
args = append(args, v)
}
if len(sets) == 0 {
continue
}
var wheres []string
for k, v := range update.Keys {
wheres = append(wheres, fmt.Sprintf("%s = ?", quoteIdent(k)))
args = append(args, v)
}
if len(wheres) == 0 {
return fmt.Errorf("update requires keys")
}
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("update error: %v", err)
}
}
// 3. Inserts
for _, row := range changes.Inserts {
var cols []string
var placeholders []string
var args []interface{}
for k, v := range row {
cols = append(cols, quoteIdent(k))
placeholders = append(placeholders, "?")
args = append(args, v)
}
if len(cols) == 0 {
continue
}
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("insert error: %v", err)
}
}
return tx.Commit()
}
func (s *SQLiteDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
return []connection.ColumnDefinitionWithTable{}, nil
tables, err := s.GetTables(dbName)
if err != nil {
return nil, err
}
var cols []connection.ColumnDefinitionWithTable
for _, table := range tables {
// Skip internal tables
if strings.HasPrefix(strings.ToLower(table), "sqlite_") {
continue
}
columns, err := s.GetColumns("", table)
if err != nil {
continue
}
for _, col := range columns {
cols = append(cols, connection.ColumnDefinitionWithTable{
TableName: table,
Name: col.Name,
Type: col.Type,
})
}
}
return cols, nil
}

90
internal/redis/redis.go Normal file
View File

@@ -0,0 +1,90 @@
package redis
import "GoNavi-Wails/internal/connection"
// RedisValue represents a Redis value with its type and metadata
type RedisValue struct {
Type string `json:"type"` // string, hash, list, set, zset
TTL int64 `json:"ttl"` // TTL in seconds, -1 means no expiry, -2 means key doesn't exist
Value interface{} `json:"value"` // The actual value
Length int64 `json:"length"` // Length/size of the value
}
// RedisDBInfo represents information about a Redis database
type RedisDBInfo struct {
Index int `json:"index"` // Database index (0-15)
Keys int64 `json:"keys"` // Number of keys in this database
}
// RedisKeyInfo represents information about a Redis key
type RedisKeyInfo struct {
Key string `json:"key"`
Type string `json:"type"`
TTL int64 `json:"ttl"`
}
// RedisScanResult represents the result of a SCAN operation
type RedisScanResult struct {
Keys []RedisKeyInfo `json:"keys"`
Cursor uint64 `json:"cursor"`
}
// RedisClient defines the interface for Redis operations
type RedisClient interface {
// Connection management
Connect(config connection.ConnectionConfig) error
Close() error
Ping() error
// Key operations
ScanKeys(pattern string, cursor uint64, count int64) (*RedisScanResult, error)
GetKeyType(key string) (string, error)
GetTTL(key string) (int64, error)
SetTTL(key string, ttl int64) error
DeleteKeys(keys []string) (int64, error)
RenameKey(oldKey, newKey string) error
KeyExists(key string) (bool, error)
// Value operations
GetValue(key string) (*RedisValue, error)
// String operations
GetString(key string) (string, error)
SetString(key, value string, ttl int64) error
// Hash operations
GetHash(key string) (map[string]string, error)
SetHashField(key, field, value string) error
DeleteHashField(key string, fields ...string) error
// List operations
GetList(key string, start, stop int64) ([]string, error)
ListPush(key string, values ...string) error
ListSet(key string, index int64, value string) error
// Set operations
GetSet(key string) ([]string, error)
SetAdd(key string, members ...string) error
SetRemove(key string, members ...string) error
// Sorted Set operations
GetZSet(key string, start, stop int64) ([]ZSetMember, error)
ZSetAdd(key string, members ...ZSetMember) error
ZSetRemove(key string, members ...string) error
// Command execution
ExecuteCommand(args []string) (interface{}, error)
// Server information
GetServerInfo() (map[string]string, error)
GetDatabases() ([]RedisDBInfo, error)
SelectDB(index int) error
GetCurrentDB() int
FlushDB() error
}
// ZSetMember represents a member in a sorted set
type ZSetMember struct {
Member string `json:"member"`
Score float64 `json:"score"`
}

View File

@@ -0,0 +1,711 @@
package redis
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"github.com/redis/go-redis/v9"
)
// RedisClientImpl implements RedisClient using go-redis
type RedisClientImpl struct {
client *redis.Client
config connection.ConnectionConfig
currentDB int
forwarder *ssh.LocalForwarder
}
// NewRedisClient creates a new Redis client instance
func NewRedisClient() RedisClient {
return &RedisClientImpl{}
}
// Connect establishes a connection to Redis
func (r *RedisClientImpl) Connect(config connection.ConnectionConfig) error {
r.config = config
r.currentDB = config.RedisDB
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
// Handle SSH tunnel if enabled
if config.UseSSH {
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
if err != nil {
return fmt.Errorf("创建 SSH 隧道失败: %w", err)
}
r.forwarder = forwarder
addr = forwarder.LocalAddr
logger.Infof("Redis 通过 SSH 隧道连接: %s -> %s:%d", addr, config.Host, config.Port)
}
opts := &redis.Options{
Addr: addr,
Password: config.Password,
DB: config.RedisDB,
DialTimeout: time.Duration(config.Timeout) * time.Second,
ReadTimeout: time.Duration(config.Timeout) * time.Second,
WriteTimeout: time.Duration(config.Timeout) * time.Second,
}
if opts.DialTimeout == 0 {
opts.DialTimeout = 30 * time.Second
opts.ReadTimeout = 30 * time.Second
opts.WriteTimeout = 30 * time.Second
}
r.client = redis.NewClient(opts)
// Test connection
ctx, cancel := context.WithTimeout(context.Background(), opts.DialTimeout)
defer cancel()
if err := r.client.Ping(ctx).Err(); err != nil {
r.client.Close()
r.client = nil
return fmt.Errorf("Redis 连接失败: %w", err)
}
logger.Infof("Redis 连接成功: %s DB=%d", addr, config.RedisDB)
return nil
}
// Close closes the Redis connection
func (r *RedisClientImpl) Close() error {
if r.client != nil {
err := r.client.Close()
r.client = nil
return err
}
return nil
}
// Ping tests the connection
func (r *RedisClientImpl) Ping() error {
if r.client == nil {
return fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
return r.client.Ping(ctx).Err()
}
// ScanKeys scans keys matching a pattern
func (r *RedisClientImpl) ScanKeys(pattern string, cursor uint64, count int64) (*RedisScanResult, error) {
if r.client == nil {
return nil, fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if pattern == "" {
pattern = "*"
}
if count <= 0 {
count = 100
}
keys, nextCursor, err := r.client.Scan(ctx, cursor, pattern, count).Result()
if err != nil {
return nil, err
}
result := &RedisScanResult{
Keys: make([]RedisKeyInfo, 0, len(keys)),
Cursor: nextCursor,
}
// Get type and TTL for each key
pipe := r.client.Pipeline()
typeResults := make([]*redis.StatusCmd, len(keys))
ttlResults := make([]*redis.DurationCmd, len(keys))
for i, key := range keys {
typeResults[i] = pipe.Type(ctx, key)
ttlResults[i] = pipe.TTL(ctx, key)
}
_, err = pipe.Exec(ctx)
if err != nil && err != redis.Nil {
// Fallback: get info one by one
for _, key := range keys {
keyType, _ := r.GetKeyType(key)
ttl, _ := r.GetTTL(key)
result.Keys = append(result.Keys, RedisKeyInfo{
Key: key,
Type: keyType,
TTL: ttl,
})
}
return result, nil
}
for i, key := range keys {
keyType := typeResults[i].Val()
ttl := int64(ttlResults[i].Val().Seconds())
if ttlResults[i].Val() == -1 {
ttl = -1
} else if ttlResults[i].Val() == -2 {
ttl = -2
}
result.Keys = append(result.Keys, RedisKeyInfo{
Key: key,
Type: keyType,
TTL: ttl,
})
}
return result, nil
}
// GetKeyType returns the type of a key
func (r *RedisClientImpl) GetKeyType(key string) (string, error) {
if r.client == nil {
return "", fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
return r.client.Type(ctx, key).Result()
}
// GetTTL returns the TTL of a key in seconds
func (r *RedisClientImpl) GetTTL(key string) (int64, error) {
if r.client == nil {
return 0, fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ttl, err := r.client.TTL(ctx, key).Result()
if err != nil {
return 0, err
}
if ttl == -1 {
return -1, nil // No expiry
} else if ttl == -2 {
return -2, nil // Key doesn't exist
}
return int64(ttl.Seconds()), nil
}
// SetTTL sets the TTL of a key
func (r *RedisClientImpl) SetTTL(key string, ttl int64) error {
if r.client == nil {
return fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if ttl < 0 {
// Remove expiry
return r.client.Persist(ctx, key).Err()
}
return r.client.Expire(ctx, key, time.Duration(ttl)*time.Second).Err()
}
// DeleteKeys deletes one or more keys
func (r *RedisClientImpl) DeleteKeys(keys []string) (int64, error) {
if r.client == nil {
return 0, fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
return r.client.Del(ctx, keys...).Result()
}
// RenameKey renames a key
func (r *RedisClientImpl) RenameKey(oldKey, newKey string) error {
if r.client == nil {
return fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
return r.client.Rename(ctx, oldKey, newKey).Err()
}
// KeyExists checks if a key exists
func (r *RedisClientImpl) KeyExists(key string) (bool, error) {
if r.client == nil {
return false, fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
n, err := r.client.Exists(ctx, key).Result()
return n > 0, err
}
// GetValue gets the value of a key with automatic type detection
func (r *RedisClientImpl) GetValue(key string) (*RedisValue, error) {
if r.client == nil {
return nil, fmt.Errorf("Redis 客户端未连接")
}
keyType, err := r.GetKeyType(key)
if err != nil {
return nil, err
}
ttl, _ := r.GetTTL(key)
result := &RedisValue{
Type: keyType,
TTL: ttl,
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
switch keyType {
case "string":
val, err := r.client.Get(ctx, key).Result()
if err != nil {
return nil, err
}
result.Value = val
result.Length = int64(len(val))
case "hash":
val, err := r.client.HGetAll(ctx, key).Result()
if err != nil {
return nil, err
}
result.Value = val
result.Length = int64(len(val))
case "list":
length, err := r.client.LLen(ctx, key).Result()
if err != nil {
return nil, err
}
// Get first 1000 items
limit := int64(1000)
if length < limit {
limit = length
}
val, err := r.client.LRange(ctx, key, 0, limit-1).Result()
if err != nil {
return nil, err
}
result.Value = val
result.Length = length
case "set":
length, err := r.client.SCard(ctx, key).Result()
if err != nil {
return nil, err
}
// Get members using SMembers (limited by Redis server)
members, err := r.client.SMembers(ctx, key).Result()
if err != nil {
return nil, err
}
result.Value = members
result.Length = length
case "zset":
length, err := r.client.ZCard(ctx, key).Result()
if err != nil {
return nil, err
}
// Get first 1000 members with scores
limit := int64(1000)
if length < limit {
limit = length
}
val, err := r.client.ZRangeWithScores(ctx, key, 0, limit-1).Result()
if err != nil {
return nil, err
}
members := make([]ZSetMember, len(val))
for i, z := range val {
members[i] = ZSetMember{
Member: z.Member.(string),
Score: z.Score,
}
}
result.Value = members
result.Length = length
default:
return nil, fmt.Errorf("不支持的 Redis 数据类型: %s", keyType)
}
return result, nil
}
// GetString gets a string value
func (r *RedisClientImpl) GetString(key string) (string, error) {
if r.client == nil {
return "", fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
return r.client.Get(ctx, key).Result()
}
// SetString sets a string value with optional TTL
func (r *RedisClientImpl) SetString(key, value string, ttl int64) error {
if r.client == nil {
return fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
var expiration time.Duration
if ttl > 0 {
expiration = time.Duration(ttl) * time.Second
}
return r.client.Set(ctx, key, value, expiration).Err()
}
// GetHash gets all fields of a hash
func (r *RedisClientImpl) GetHash(key string) (map[string]string, error) {
if r.client == nil {
return nil, fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
return r.client.HGetAll(ctx, key).Result()
}
// SetHashField sets a field in a hash
func (r *RedisClientImpl) SetHashField(key, field, value string) error {
if r.client == nil {
return fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
return r.client.HSet(ctx, key, field, value).Err()
}
// DeleteHashField deletes fields from a hash
func (r *RedisClientImpl) DeleteHashField(key string, fields ...string) error {
if r.client == nil {
return fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
return r.client.HDel(ctx, key, fields...).Err()
}
// GetList gets a range of elements from a list
func (r *RedisClientImpl) GetList(key string, start, stop int64) ([]string, error) {
if r.client == nil {
return nil, fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
return r.client.LRange(ctx, key, start, stop).Result()
}
// ListPush pushes values to the end of a list
func (r *RedisClientImpl) ListPush(key string, values ...string) error {
if r.client == nil {
return fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
args := make([]interface{}, len(values))
for i, v := range values {
args[i] = v
}
return r.client.RPush(ctx, key, args...).Err()
}
// ListSet sets the value at an index in a list
func (r *RedisClientImpl) ListSet(key string, index int64, value string) error {
if r.client == nil {
return fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
return r.client.LSet(ctx, key, index, value).Err()
}
// GetSet gets all members of a set
func (r *RedisClientImpl) GetSet(key string) ([]string, error) {
if r.client == nil {
return nil, fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
return r.client.SMembers(ctx, key).Result()
}
// SetAdd adds members to a set
func (r *RedisClientImpl) SetAdd(key string, members ...string) error {
if r.client == nil {
return fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
args := make([]interface{}, len(members))
for i, m := range members {
args[i] = m
}
return r.client.SAdd(ctx, key, args...).Err()
}
// SetRemove removes members from a set
func (r *RedisClientImpl) SetRemove(key string, members ...string) error {
if r.client == nil {
return fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
args := make([]interface{}, len(members))
for i, m := range members {
args[i] = m
}
return r.client.SRem(ctx, key, args...).Err()
}
// GetZSet gets members with scores from a sorted set
func (r *RedisClientImpl) GetZSet(key string, start, stop int64) ([]ZSetMember, error) {
if r.client == nil {
return nil, fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
val, err := r.client.ZRangeWithScores(ctx, key, start, stop).Result()
if err != nil {
return nil, err
}
members := make([]ZSetMember, len(val))
for i, z := range val {
members[i] = ZSetMember{
Member: z.Member.(string),
Score: z.Score,
}
}
return members, nil
}
// ZSetAdd adds members to a sorted set
func (r *RedisClientImpl) ZSetAdd(key string, members ...ZSetMember) error {
if r.client == nil {
return fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
zMembers := make([]redis.Z, len(members))
for i, m := range members {
zMembers[i] = redis.Z{
Score: m.Score,
Member: m.Member,
}
}
return r.client.ZAdd(ctx, key, zMembers...).Err()
}
// ZSetRemove removes members from a sorted set
func (r *RedisClientImpl) ZSetRemove(key string, members ...string) error {
if r.client == nil {
return fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
args := make([]interface{}, len(members))
for i, m := range members {
args[i] = m
}
return r.client.ZRem(ctx, key, args...).Err()
}
// ExecuteCommand executes a raw Redis command
func (r *RedisClientImpl) ExecuteCommand(args []string) (interface{}, error) {
if r.client == nil {
return nil, fmt.Errorf("Redis 客户端未连接")
}
if len(args) == 0 {
return nil, fmt.Errorf("命令不能为空")
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Convert to []interface{}
cmdArgs := make([]interface{}, len(args))
for i, arg := range args {
cmdArgs[i] = arg
}
result, err := r.client.Do(ctx, cmdArgs...).Result()
if err != nil {
return nil, err
}
return formatCommandResult(result), nil
}
// formatCommandResult formats the command result for display
func formatCommandResult(result interface{}) interface{} {
switch v := result.(type) {
case []interface{}:
formatted := make([]interface{}, len(v))
for i, item := range v {
formatted[i] = formatCommandResult(item)
}
return formatted
case []byte:
return string(v)
default:
return v
}
}
// GetServerInfo returns server information
func (r *RedisClientImpl) GetServerInfo() (map[string]string, error) {
if r.client == nil {
return nil, fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
info, err := r.client.Info(ctx).Result()
if err != nil {
return nil, err
}
result := make(map[string]string)
lines := strings.Split(info, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
}
parts := strings.SplitN(line, ":", 2)
if len(parts) == 2 {
result[parts[0]] = parts[1]
}
}
return result, nil
}
// GetDatabases returns information about all databases
func (r *RedisClientImpl) GetDatabases() ([]RedisDBInfo, error) {
if r.client == nil {
return nil, fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// Get keyspace info
info, err := r.client.Info(ctx, "keyspace").Result()
if err != nil {
return nil, err
}
// Parse keyspace info
dbMap := make(map[int]int64)
lines := strings.Split(info, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if strings.HasPrefix(line, "db") {
// Format: db0:keys=123,expires=0,avg_ttl=0
parts := strings.SplitN(line, ":", 2)
if len(parts) != 2 {
continue
}
dbIndex, err := strconv.Atoi(strings.TrimPrefix(parts[0], "db"))
if err != nil {
continue
}
// Parse keys count
kvPairs := strings.Split(parts[1], ",")
for _, kv := range kvPairs {
if strings.HasPrefix(kv, "keys=") {
keys, _ := strconv.ParseInt(strings.TrimPrefix(kv, "keys="), 10, 64)
dbMap[dbIndex] = keys
break
}
}
}
}
// Return all 16 databases (0-15)
result := make([]RedisDBInfo, 16)
for i := 0; i < 16; i++ {
result[i] = RedisDBInfo{
Index: i,
Keys: dbMap[i], // Will be 0 if not in map
}
}
return result, nil
}
// SelectDB selects a database
func (r *RedisClientImpl) SelectDB(index int) error {
if r.client == nil {
return fmt.Errorf("Redis 客户端未连接")
}
if index < 0 || index > 15 {
return fmt.Errorf("数据库索引必须在 0-15 之间")
}
// Create new client with different DB
addr := fmt.Sprintf("%s:%d", r.config.Host, r.config.Port)
if r.forwarder != nil {
addr = r.forwarder.LocalAddr
}
opts := &redis.Options{
Addr: addr,
Password: r.config.Password,
DB: index,
DialTimeout: time.Duration(r.config.Timeout) * time.Second,
ReadTimeout: time.Duration(r.config.Timeout) * time.Second,
WriteTimeout: time.Duration(r.config.Timeout) * time.Second,
}
if opts.DialTimeout == 0 {
opts.DialTimeout = 30 * time.Second
opts.ReadTimeout = 30 * time.Second
opts.WriteTimeout = 30 * time.Second
}
newClient := redis.NewClient(opts)
ctx, cancel := context.WithTimeout(context.Background(), opts.DialTimeout)
defer cancel()
if err := newClient.Ping(ctx).Err(); err != nil {
newClient.Close()
return fmt.Errorf("切换数据库失败: %w", err)
}
// Close old client and replace
r.client.Close()
r.client = newClient
r.currentDB = index
logger.Infof("Redis 切换到数据库: db%d", index)
return nil
}
// GetCurrentDB returns the current database index
func (r *RedisClientImpl) GetCurrentDB() int {
return r.currentDB
}
// FlushDB flushes the current database
func (r *RedisClientImpl) FlushDB() error {
if r.client == nil {
return fmt.Errorf("Redis 客户端未连接")
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
return r.client.FlushDB(ctx).Err()
}

View File

@@ -3,8 +3,10 @@ package ssh
import (
"context"
"fmt"
"io"
"net"
"os"
"sync"
"time"
"GoNavi-Wails/internal/connection"
@@ -110,3 +112,264 @@ func RegisterSSHNetwork(sshConfig connection.SSHConfig) (string, error) {
return netName, nil
}
// sshClientCache stores SSH clients to avoid creating multiple connections
var (
sshClientCache = make(map[string]*ssh.Client)
sshClientCacheMu sync.RWMutex
localForwarders = make(map[string]*LocalForwarder)
forwarderMu sync.RWMutex
)
// LocalForwarder represents a local port forwarder through SSH
type LocalForwarder struct {
LocalAddr string
RemoteAddr string
SSHClient *ssh.Client
listener net.Listener
closeChan chan struct{}
closeOnce sync.Once // 防止重复关闭
closed bool // 关闭状态标记
closedMu sync.RWMutex
}
// NewLocalForwarder creates a new local port forwarder
// It listens on a random local port and forwards all connections through SSH tunnel
func NewLocalForwarder(sshConfig connection.SSHConfig, remoteHost string, remotePort int) (*LocalForwarder, error) {
client, err := GetOrCreateSSHClient(sshConfig)
if err != nil {
return nil, fmt.Errorf("建立 SSH 连接失败:%w", err)
}
// Listen on localhost with a random port
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, fmt.Errorf("创建本地监听器失败:%w", err)
}
localAddr := listener.Addr().String()
remoteAddr := fmt.Sprintf("%s:%d", remoteHost, remotePort)
forwarder := &LocalForwarder{
LocalAddr: localAddr,
RemoteAddr: remoteAddr,
SSHClient: client,
listener: listener,
closeChan: make(chan struct{}),
}
// Start forwarding in background
go forwarder.forward()
logger.Infof("已创建 SSH 端口转发:本地 %s -> 远程 %s", localAddr, remoteAddr)
return forwarder, nil
}
// forward handles the port forwarding
func (f *LocalForwarder) forward() {
for {
localConn, err := f.listener.Accept()
if err != nil {
// Check if we're shutting down
select {
case <-f.closeChan:
return
default:
logger.Warnf("接受本地连接失败:%v", err)
// listener可能已关闭,退出循环
return
}
}
go f.handleConnection(localConn)
}
}
// handleConnection handles a single connection
func (f *LocalForwarder) handleConnection(localConn net.Conn) {
defer localConn.Close()
// Connect to remote through SSH with timeout
remoteConn, err := f.SSHClient.Dial("tcp", f.RemoteAddr)
if err != nil {
logger.Warnf("通过 SSH 连接到远程 %s 失败:%v", f.RemoteAddr, err)
return
}
defer remoteConn.Close()
// Bidirectional copy with error channel
errc := make(chan error, 2)
// Copy from local to remote
go func() {
_, err := io.Copy(remoteConn, localConn)
if err != nil {
logger.Warnf("本地->远程数据复制错误:%v", err)
}
errc <- err
}()
// Copy from remote to local
go func() {
_, err := io.Copy(localConn, remoteConn)
if err != nil {
logger.Warnf("远程->本地数据复制错误:%v", err)
}
errc <- err
}()
// Wait for BOTH goroutines to complete
<-errc
<-errc
}
// Close closes the forwarder (thread-safe, can be called multiple times)
func (f *LocalForwarder) Close() error {
var err error
f.closeOnce.Do(func() {
f.closedMu.Lock()
f.closed = true
f.closedMu.Unlock()
close(f.closeChan)
err = f.listener.Close()
if err != nil {
logger.Warnf("关闭端口转发监听器失败:%v", err)
}
})
return err
}
// IsClosed returns whether the forwarder is closed
func (f *LocalForwarder) IsClosed() bool {
f.closedMu.RLock()
defer f.closedMu.RUnlock()
return f.closed
}
// GetOrCreateLocalForwarder returns a cached forwarder or creates a new one
func GetOrCreateLocalForwarder(sshConfig connection.SSHConfig, remoteHost string, remotePort int) (*LocalForwarder, error) {
key := fmt.Sprintf("%s:%d:%s->%s:%d",
sshConfig.Host, sshConfig.Port, sshConfig.User,
remoteHost, remotePort)
forwarderMu.RLock()
forwarder, exists := localForwarders[key]
forwarderMu.RUnlock()
// Check if exists and is still valid
if exists && forwarder != nil && !forwarder.IsClosed() {
logger.Infof("复用已有端口转发:%s", key)
return forwarder, nil
}
// Remove stale forwarder from cache
if exists {
forwarderMu.Lock()
delete(localForwarders, key)
forwarderMu.Unlock()
}
forwarder, err := NewLocalForwarder(sshConfig, remoteHost, remotePort)
if err != nil {
return nil, err
}
forwarderMu.Lock()
localForwarders[key] = forwarder
forwarderMu.Unlock()
return forwarder, nil
}
// CloseAllForwarders closes all local forwarders
func CloseAllForwarders() {
forwarderMu.Lock()
defer forwarderMu.Unlock()
for key, forwarder := range localForwarders {
if forwarder != nil {
_ = forwarder.Close()
logger.Infof("已关闭端口转发:%s", key)
}
}
localForwarders = make(map[string]*LocalForwarder)
}
// getSSHClientCacheKey generates a unique cache key for SSH config
func getSSHClientCacheKey(config connection.SSHConfig) string {
return fmt.Sprintf("%s:%d:%s", config.Host, config.Port, config.User)
}
// GetOrCreateSSHClient returns a cached SSH client or creates a new one
func GetOrCreateSSHClient(config connection.SSHConfig) (*ssh.Client, error) {
key := getSSHClientCacheKey(config)
sshClientCacheMu.RLock()
client, exists := sshClientCache[key]
sshClientCacheMu.RUnlock()
if exists && client != nil {
// Test if connection is still alive by creating a test session
session, err := client.NewSession()
if err == nil {
session.Close()
logger.Infof("复用已有 SSH 连接:%s", key)
return client, nil
}
// Connection is dead, remove from cache
logger.Warnf("SSH 连接已断开,重新建立:%s (错误: %v)", key, err)
sshClientCacheMu.Lock()
delete(sshClientCache, key)
sshClientCacheMu.Unlock()
// Try to close the dead client
_ = client.Close()
}
// Create new SSH client
client, err := connectSSH(config)
if err != nil {
return nil, err
}
// Cache the client
sshClientCacheMu.Lock()
sshClientCache[key] = client
sshClientCacheMu.Unlock()
logger.Infof("已缓存 SSH 连接:%s", key)
return client, nil
}
// DialThroughSSH creates a connection through SSH tunnel
// This is a generic dialer that can be used by any database driver
func DialThroughSSH(config connection.SSHConfig, network, address string) (net.Conn, error) {
client, err := GetOrCreateSSHClient(config)
if err != nil {
return nil, fmt.Errorf("建立 SSH 连接失败:%w", err)
}
conn, err := client.Dial(network, address)
if err != nil {
return nil, fmt.Errorf("通过 SSH 隧道连接到 %s 失败:%w", address, err)
}
logger.Infof("已通过 SSH 隧道连接到:%s", address)
return conn, nil
}
// CloseAllSSHClients closes all cached SSH clients
func CloseAllSSHClients() {
sshClientCacheMu.Lock()
defer sshClientCacheMu.Unlock()
for key, client := range sshClientCache {
if client != nil {
_ = client.Close()
logger.Infof("已关闭 SSH 连接:%s", key)
}
}
sshClientCache = make(map[string]*ssh.Client)
}

198
internal/sync/analyze.go Normal file
View File

@@ -0,0 +1,198 @@
package sync
import (
"GoNavi-Wails/internal/db"
"GoNavi-Wails/internal/logger"
"fmt"
"strings"
)
type TableDiffSummary struct {
Table string `json:"table"`
PKColumn string `json:"pkColumn,omitempty"`
CanSync bool `json:"canSync"`
Inserts int `json:"inserts"`
Updates int `json:"updates"`
Deletes int `json:"deletes"`
Same int `json:"same"`
Message string `json:"message,omitempty"`
HasSchema bool `json:"hasSchema,omitempty"`
}
type SyncAnalyzeResult struct {
Success bool `json:"success"`
Message string `json:"message"`
Tables []TableDiffSummary `json:"tables"`
}
func (s *SyncEngine) Analyze(config SyncConfig) SyncAnalyzeResult {
result := SyncAnalyzeResult{Success: true, Tables: []TableDiffSummary{}}
contentRaw := strings.ToLower(strings.TrimSpace(config.Content))
syncSchema := false
syncData := true
switch contentRaw {
case "", "data":
syncData = true
case "schema":
syncSchema = true
syncData = false
case "both":
syncSchema = true
syncData = true
default:
s.appendLog(config.JobID, nil, "warn", fmt.Sprintf("未知同步内容 %q已自动使用仅同步数据", config.Content))
syncData = true
}
totalTables := len(config.Tables)
s.progress(config.JobID, 0, totalTables, "", "差异分析开始")
sourceDB, err := db.NewDatabase(config.SourceConfig.Type)
if err != nil {
logger.Error(err, "初始化源数据库驱动失败:类型=%s", config.SourceConfig.Type)
return SyncAnalyzeResult{Success: false, Message: "初始化源数据库驱动失败: " + err.Error()}
}
targetDB, err := db.NewDatabase(config.TargetConfig.Type)
if err != nil {
logger.Error(err, "初始化目标数据库驱动失败:类型=%s", config.TargetConfig.Type)
return SyncAnalyzeResult{Success: false, Message: "初始化目标数据库驱动失败: " + err.Error()}
}
// Connect Source
if err := sourceDB.Connect(config.SourceConfig); err != nil {
logger.Error(err, "源数据库连接失败:%s", formatConnSummaryForSync(config.SourceConfig))
return SyncAnalyzeResult{Success: false, Message: "源数据库连接失败: " + err.Error()}
}
defer sourceDB.Close()
// Connect Target
if err := targetDB.Connect(config.TargetConfig); err != nil {
logger.Error(err, "目标数据库连接失败:%s", formatConnSummaryForSync(config.TargetConfig))
return SyncAnalyzeResult{Success: false, Message: "目标数据库连接失败: " + err.Error()}
}
defer targetDB.Close()
for i, tableName := range config.Tables {
func() {
s.progress(config.JobID, i, totalTables, tableName, fmt.Sprintf("分析表(%d/%d)", i+1, totalTables))
summary := TableDiffSummary{
Table: tableName,
CanSync: false,
Inserts: 0,
Updates: 0,
Deletes: 0,
Same: 0,
Message: "",
HasSchema: syncSchema,
}
sourceSchema, sourceTable := normalizeSchemaAndTable(config.SourceConfig.Type, config.SourceConfig.Database, tableName)
targetSchema, targetTable := normalizeSchemaAndTable(config.TargetConfig.Type, config.TargetConfig.Database, tableName)
sourceQueryTable := qualifiedNameForQuery(config.SourceConfig.Type, sourceSchema, sourceTable, tableName)
targetQueryTable := qualifiedNameForQuery(config.TargetConfig.Type, targetSchema, targetTable, tableName)
cols, err := sourceDB.GetColumns(sourceSchema, sourceTable)
if err != nil {
summary.Message = "获取源表字段失败: " + err.Error()
result.Tables = append(result.Tables, summary)
return
}
if !syncData {
summary.CanSync = true
summary.Message = "仅同步结构,未执行数据差异分析"
result.Tables = append(result.Tables, summary)
return
}
pkCols := make([]string, 0, 2)
for _, c := range cols {
if c.Key == "PRI" || c.Key == "PK" {
pkCols = append(pkCols, c.Name)
}
}
if len(pkCols) == 0 {
summary.Message = "无主键,不支持数据对比/同步"
result.Tables = append(result.Tables, summary)
return
}
if len(pkCols) > 1 {
summary.Message = fmt.Sprintf("复合主键(%s暂不支持数据对比/同步", strings.Join(pkCols, ","))
result.Tables = append(result.Tables, summary)
return
}
summary.PKColumn = pkCols[0]
// Query data for diff
sourceRows, _, err := sourceDB.Query(fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(config.SourceConfig.Type, sourceQueryTable)))
if err != nil {
summary.Message = "读取源表失败: " + err.Error()
result.Tables = append(result.Tables, summary)
return
}
targetRows, _, err := targetDB.Query(fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(config.TargetConfig.Type, targetQueryTable)))
if err != nil {
summary.Message = "读取目标表失败: " + err.Error()
result.Tables = append(result.Tables, summary)
return
}
pkCol := summary.PKColumn
targetMap := make(map[string]map[string]interface{}, len(targetRows))
for _, row := range targetRows {
if row[pkCol] == nil {
continue
}
pkVal := strings.TrimSpace(fmt.Sprintf("%v", row[pkCol]))
if pkVal == "" || pkVal == "<nil>" {
continue
}
targetMap[pkVal] = row
}
sourcePKSet := make(map[string]struct{}, len(sourceRows))
for _, sRow := range sourceRows {
if sRow[pkCol] == nil {
continue
}
pkVal := strings.TrimSpace(fmt.Sprintf("%v", sRow[pkCol]))
if pkVal == "" || pkVal == "<nil>" {
continue
}
sourcePKSet[pkVal] = struct{}{}
if tRow, exists := targetMap[pkVal]; exists {
changed := false
for k, v := range sRow {
if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", tRow[k]) {
changed = true
break
}
}
if changed {
summary.Updates++
} else {
summary.Same++
}
} else {
summary.Inserts++
}
}
for pkVal := range targetMap {
if _, ok := sourcePKSet[pkVal]; !ok {
summary.Deletes++
}
}
summary.CanSync = true
result.Tables = append(result.Tables, summary)
}()
}
s.progress(config.JobID, totalTables, totalTables, "", "差异分析完成")
result.Message = fmt.Sprintf("已完成 %d 张表的差异分析", len(result.Tables))
return result
}

164
internal/sync/preview.go Normal file
View File

@@ -0,0 +1,164 @@
package sync
import (
"GoNavi-Wails/internal/db"
"fmt"
"strings"
)
type PreviewRow struct {
PK string `json:"pk"`
Row map[string]interface{} `json:"row"`
}
type PreviewUpdateRow struct {
PK string `json:"pk"`
ChangedColumns []string `json:"changedColumns"`
Source map[string]interface{} `json:"source"`
Target map[string]interface{} `json:"target"`
}
type TableDiffPreview struct {
Table string `json:"table"`
PKColumn string `json:"pkColumn"`
TotalInserts int `json:"totalInserts"`
TotalUpdates int `json:"totalUpdates"`
TotalDeletes int `json:"totalDeletes"`
Inserts []PreviewRow `json:"inserts"`
Updates []PreviewUpdateRow `json:"updates"`
Deletes []PreviewRow `json:"deletes"`
}
func (s *SyncEngine) Preview(config SyncConfig, tableName string, limit int) (TableDiffPreview, error) {
if limit <= 0 {
limit = 200
}
if limit > 500 {
limit = 500
}
sourceDB, err := db.NewDatabase(config.SourceConfig.Type)
if err != nil {
return TableDiffPreview{}, fmt.Errorf("初始化源数据库驱动失败: %w", err)
}
targetDB, err := db.NewDatabase(config.TargetConfig.Type)
if err != nil {
return TableDiffPreview{}, fmt.Errorf("初始化目标数据库驱动失败: %w", err)
}
if err := sourceDB.Connect(config.SourceConfig); err != nil {
return TableDiffPreview{}, fmt.Errorf("源数据库连接失败: %w", err)
}
defer sourceDB.Close()
if err := targetDB.Connect(config.TargetConfig); err != nil {
return TableDiffPreview{}, fmt.Errorf("目标数据库连接失败: %w", err)
}
defer targetDB.Close()
sourceSchema, sourceTable := normalizeSchemaAndTable(config.SourceConfig.Type, config.SourceConfig.Database, tableName)
targetSchema, targetTable := normalizeSchemaAndTable(config.TargetConfig.Type, config.TargetConfig.Database, tableName)
sourceQueryTable := qualifiedNameForQuery(config.SourceConfig.Type, sourceSchema, sourceTable, tableName)
targetQueryTable := qualifiedNameForQuery(config.TargetConfig.Type, targetSchema, targetTable, tableName)
cols, err := sourceDB.GetColumns(sourceSchema, sourceTable)
if err != nil {
return TableDiffPreview{}, fmt.Errorf("获取源表字段失败: %w", err)
}
pkCols := make([]string, 0, 2)
for _, c := range cols {
if c.Key == "PRI" || c.Key == "PK" {
pkCols = append(pkCols, c.Name)
}
}
if len(pkCols) == 0 {
return TableDiffPreview{}, fmt.Errorf("无主键,不支持数据预览")
}
if len(pkCols) > 1 {
return TableDiffPreview{}, fmt.Errorf("复合主键(%s暂不支持数据预览", strings.Join(pkCols, ","))
}
pkCol := pkCols[0]
sourceRows, _, err := sourceDB.Query(fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(config.SourceConfig.Type, sourceQueryTable)))
if err != nil {
return TableDiffPreview{}, fmt.Errorf("读取源表失败: %w", err)
}
targetRows, _, err := targetDB.Query(fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(config.TargetConfig.Type, targetQueryTable)))
if err != nil {
return TableDiffPreview{}, fmt.Errorf("读取目标表失败: %w", err)
}
targetMap := make(map[string]map[string]interface{}, len(targetRows))
for _, row := range targetRows {
if row[pkCol] == nil {
continue
}
pkVal := strings.TrimSpace(fmt.Sprintf("%v", row[pkCol]))
if pkVal == "" || pkVal == "<nil>" {
continue
}
targetMap[pkVal] = row
}
out := TableDiffPreview{
Table: tableName,
PKColumn: pkCol,
TotalInserts: 0,
TotalUpdates: 0,
TotalDeletes: 0,
Inserts: make([]PreviewRow, 0),
Updates: make([]PreviewUpdateRow, 0),
Deletes: make([]PreviewRow, 0),
}
sourcePKSet := make(map[string]struct{}, len(sourceRows))
for _, sRow := range sourceRows {
if sRow[pkCol] == nil {
continue
}
pkVal := strings.TrimSpace(fmt.Sprintf("%v", sRow[pkCol]))
if pkVal == "" || pkVal == "<nil>" {
continue
}
sourcePKSet[pkVal] = struct{}{}
if tRow, exists := targetMap[pkVal]; exists {
changedColumns := make([]string, 0)
for k, v := range sRow {
if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", tRow[k]) {
changedColumns = append(changedColumns, k)
}
}
if len(changedColumns) > 0 {
out.TotalUpdates++
if len(out.Updates) < limit {
out.Updates = append(out.Updates, PreviewUpdateRow{
PK: pkVal,
ChangedColumns: changedColumns,
Source: sRow,
Target: tRow,
})
}
}
continue
}
out.TotalInserts++
if len(out.Inserts) < limit {
out.Inserts = append(out.Inserts, PreviewRow{PK: pkVal, Row: sRow})
}
}
for pkVal, row := range targetMap {
if _, ok := sourcePKSet[pkVal]; ok {
continue
}
out.TotalDeletes++
if len(out.Deletes) < limit {
out.Deletes = append(out.Deletes, PreviewRow{PK: pkVal, Row: row})
}
}
return out, nil
}

View File

@@ -0,0 +1,58 @@
package sync
import (
"GoNavi-Wails/internal/connection"
"fmt"
)
func filterRowsByPKSelection(pkCol string, rows []map[string]interface{}, enabled bool, selectedPKs []string) []map[string]interface{} {
if !enabled {
return nil
}
if len(rows) == 0 {
return rows
}
if len(selectedPKs) == 0 {
return rows
}
set := make(map[string]struct{}, len(selectedPKs))
for _, pk := range selectedPKs {
set[pk] = struct{}{}
}
out := make([]map[string]interface{}, 0, len(rows))
for _, row := range rows {
pkStr := fmt.Sprintf("%v", row[pkCol])
if _, ok := set[pkStr]; ok {
out = append(out, row)
}
}
return out
}
func filterUpdatesByPKSelection(pkCol string, updates []connection.UpdateRow, enabled bool, selectedPKs []string) []connection.UpdateRow {
if !enabled {
return nil
}
if len(updates) == 0 {
return updates
}
if len(selectedPKs) == 0 {
return updates
}
set := make(map[string]struct{}, len(selectedPKs))
for _, pk := range selectedPKs {
set[pk] = struct{}{}
}
out := make([]connection.UpdateRow, 0, len(updates))
for _, u := range updates {
pkStr := fmt.Sprintf("%v", u.Keys[pkCol])
if _, ok := set[pkStr]; ok {
out = append(out, u)
}
}
return out
}

View File

@@ -0,0 +1,97 @@
package sync
import (
"GoNavi-Wails/internal/connection"
"strings"
)
func collectRequiredColumns(inserts []map[string]interface{}, updates []connection.UpdateRow) map[string]string {
// key: lower(columnName), value: original columnName
required := make(map[string]string)
for _, row := range inserts {
for k := range row {
key := strings.ToLower(strings.TrimSpace(k))
if key == "" {
continue
}
if _, exists := required[key]; !exists {
required[key] = k
}
}
}
for _, u := range updates {
for k := range u.Values {
key := strings.ToLower(strings.TrimSpace(k))
if key == "" {
continue
}
if _, exists := required[key]; !exists {
required[key] = k
}
}
}
return required
}
func filterInsertRows(inserts []map[string]interface{}, allowedLower map[string]struct{}) []map[string]interface{} {
if len(inserts) == 0 || len(allowedLower) == 0 {
return inserts
}
out := make([]map[string]interface{}, 0, len(inserts))
for _, row := range inserts {
if len(row) == 0 {
out = append(out, row)
continue
}
n := make(map[string]interface{}, len(row))
for k, v := range row {
if _, ok := allowedLower[strings.ToLower(strings.TrimSpace(k))]; ok {
n[k] = v
}
}
out = append(out, n)
}
return out
}
func filterUpdateRows(updates []connection.UpdateRow, allowedLower map[string]struct{}) []connection.UpdateRow {
if len(updates) == 0 || len(allowedLower) == 0 {
return updates
}
out := make([]connection.UpdateRow, 0, len(updates))
for _, u := range updates {
if len(u.Values) == 0 {
continue
}
values := make(map[string]interface{}, len(u.Values))
for k, v := range u.Values {
if _, ok := allowedLower[strings.ToLower(strings.TrimSpace(k))]; ok {
values[k] = v
}
}
if len(values) == 0 {
continue
}
out = append(out, connection.UpdateRow{
Keys: u.Keys,
Values: values,
})
}
return out
}
func sanitizeMySQLColumnType(t string) string {
tt := strings.TrimSpace(t)
if tt == "" {
return "TEXT"
}
// 基础防护:避免把元数据中异常内容拼进 SQL。
if strings.ContainsAny(tt, "`;\n\r") {
return "TEXT"
}
return tt
}

View File

@@ -0,0 +1,101 @@
package sync
import (
"GoNavi-Wails/internal/db"
"fmt"
"strings"
)
func (s *SyncEngine) syncTableSchema(config SyncConfig, res *SyncResult, sourceDB db.Database, targetDB db.Database, tableName string) error {
targetType := strings.ToLower(strings.TrimSpace(config.TargetConfig.Type))
if targetType != "mysql" {
s.appendLog(config.JobID, res, "warn", fmt.Sprintf("目标数据库类型=%s 暂不支持结构同步,已跳过表 %s", config.TargetConfig.Type, tableName))
return nil
}
sourceSchema, sourceTable := normalizeSchemaAndTable(config.SourceConfig.Type, config.SourceConfig.Database, tableName)
targetSchema, targetTable := normalizeSchemaAndTable(config.TargetConfig.Type, config.TargetConfig.Database, tableName)
targetQueryTable := qualifiedNameForQuery(config.TargetConfig.Type, targetSchema, targetTable, tableName)
// 1) 获取源表字段
sourceCols, err := sourceDB.GetColumns(sourceSchema, sourceTable)
if err != nil {
return fmt.Errorf("获取源表字段失败: %w", err)
}
// 2) 确保目标表存在
targetCols, err := targetDB.GetColumns(targetSchema, targetTable)
if err != nil {
sourceType := strings.ToLower(strings.TrimSpace(config.SourceConfig.Type))
if sourceType != "mysql" {
return fmt.Errorf("目标表不存在且源类型=%s 暂不支持自动建表: %w", config.SourceConfig.Type, err)
}
s.appendLog(config.JobID, res, "warn", fmt.Sprintf("目标表 %s 不存在,开始尝试创建表结构", tableName))
createSQL, errCreate := sourceDB.GetCreateStatement(sourceSchema, sourceTable)
if errCreate != nil || strings.TrimSpace(createSQL) == "" {
if errCreate == nil {
errCreate = fmt.Errorf("建表语句为空")
}
return fmt.Errorf("获取源表建表语句失败: %w", errCreate)
}
if _, errExec := targetDB.Exec(createSQL); errExec != nil {
return fmt.Errorf("创建目标表失败: %w", errExec)
}
s.appendLog(config.JobID, res, "info", fmt.Sprintf("目标表创建成功:%s", tableName))
targetCols, err = targetDB.GetColumns(targetSchema, targetTable)
if err != nil {
return fmt.Errorf("创建目标表后获取字段失败: %w", err)
}
}
targetColSet := make(map[string]struct{}, len(targetCols))
for _, c := range targetCols {
name := strings.ToLower(strings.TrimSpace(c.Name))
if name == "" {
continue
}
targetColSet[name] = struct{}{}
}
// 3) 补齐目标缺失字段(安全策略:新增字段统一允许 NULL
missing := make([]string, 0)
sourceType := strings.ToLower(strings.TrimSpace(config.SourceConfig.Type))
for _, c := range sourceCols {
colName := strings.TrimSpace(c.Name)
if colName == "" {
continue
}
lower := strings.ToLower(colName)
if _, ok := targetColSet[lower]; ok {
continue
}
missing = append(missing, colName)
colType := "TEXT"
if sourceType == "mysql" {
colType = sanitizeMySQLColumnType(c.Type)
}
alterSQL := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s NULL",
quoteQualifiedIdentByType("mysql", targetQueryTable),
quoteIdentByType("mysql", colName),
colType,
)
if _, err := targetDB.Exec(alterSQL); err != nil {
s.appendLog(config.JobID, res, "error", fmt.Sprintf(" -> 补字段失败:表=%s 字段=%s 错误=%v", tableName, colName, err))
continue
}
s.appendLog(config.JobID, res, "info", fmt.Sprintf(" -> 已补齐字段:表=%s 字段=%s 类型=%s", tableName, colName, colType))
}
if len(missing) == 0 {
s.appendLog(config.JobID, res, "info", fmt.Sprintf("表结构一致:%s", tableName))
} else {
s.appendLog(config.JobID, res, "info", fmt.Sprintf("表结构同步完成:%s新增字段 %d 个)", tableName, len(missing)))
}
return nil
}

View File

@@ -0,0 +1,109 @@
package sync
import "strings"
func normalizeSyncMode(mode string) string {
m := strings.ToLower(strings.TrimSpace(mode))
switch m {
case "", "insert_update":
return "insert_update"
case "insert_only":
return "insert_only"
case "full_overwrite":
return "full_overwrite"
default:
return "insert_update"
}
}
func quoteIdentByType(dbType string, ident string) string {
if ident == "" {
return ident
}
switch dbType {
case "mysql":
return "`" + strings.ReplaceAll(ident, "`", "``") + "`"
default:
return `"` + strings.ReplaceAll(ident, `"`, `""`) + `"`
}
}
func quoteQualifiedIdentByType(dbType string, ident string) string {
raw := strings.TrimSpace(ident)
if raw == "" {
return raw
}
parts := strings.Split(raw, ".")
if len(parts) <= 1 {
return quoteIdentByType(dbType, raw)
}
quotedParts := make([]string, 0, len(parts))
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
quotedParts = append(quotedParts, quoteIdentByType(dbType, part))
}
if len(quotedParts) == 0 {
return quoteIdentByType(dbType, raw)
}
return strings.Join(quotedParts, ".")
}
func normalizeSchemaAndTable(dbType string, dbName string, tableName string) (string, string) {
rawTable := strings.TrimSpace(tableName)
rawDB := strings.TrimSpace(dbName)
if rawTable == "" {
return rawDB, rawTable
}
if parts := strings.SplitN(rawTable, ".", 2); len(parts) == 2 {
schema := strings.TrimSpace(parts[0])
table := strings.TrimSpace(parts[1])
if schema != "" && table != "" {
return schema, table
}
}
switch strings.ToLower(strings.TrimSpace(dbType)) {
case "postgres", "kingbase":
return "public", rawTable
default:
return rawDB, rawTable
}
}
func qualifiedNameForQuery(dbType string, schema string, table string, original string) string {
raw := strings.TrimSpace(original)
if raw == "" {
return raw
}
if strings.Contains(raw, ".") {
return raw
}
switch strings.ToLower(strings.TrimSpace(dbType)) {
case "postgres", "kingbase":
s := strings.TrimSpace(schema)
if s == "" {
s = "public"
}
if table == "" {
return raw
}
return s + "." + table
case "mysql":
s := strings.TrimSpace(schema)
if s == "" || table == "" {
return table
}
return s + "." + table
default:
return table
}
}

View File

@@ -5,15 +5,21 @@ import (
"GoNavi-Wails/internal/db"
"GoNavi-Wails/internal/logger"
"fmt"
"sort"
"strings"
"time"
)
// SyncConfig defines the parameters for a synchronization task
type SyncConfig struct {
SourceConfig connection.ConnectionConfig `json:"sourceConfig"`
TargetConfig connection.ConnectionConfig `json:"targetConfig"`
Tables []string `json:"tables"` // Tables to sync
Mode string `json:"mode"` // "insert_update", "full_overwrite"
SourceConfig connection.ConnectionConfig `json:"sourceConfig"`
TargetConfig connection.ConnectionConfig `json:"targetConfig"`
Tables []string `json:"tables"` // Tables to sync
Content string `json:"content,omitempty"` // "data", "schema", "both"
Mode string `json:"mode"` // "insert_update", "insert_only", "full_overwrite"
JobID string `json:"jobId,omitempty"`
AutoAddColumns bool `json:"autoAddColumns,omitempty"` // 自动补齐缺失字段(当前仅 MySQL 目标支持)
TableOptions map[string]TableOptions `json:"tableOptions,omitempty"`
}
// SyncResult holds the result of the sync operation
@@ -28,21 +34,55 @@ type SyncResult struct {
}
type SyncEngine struct {
reporter Reporter
}
func NewSyncEngine() *SyncEngine {
return &SyncEngine{}
func NewSyncEngine(reporter Reporter) *SyncEngine {
return &SyncEngine{reporter: reporter}
}
// CompareAndSync performs the synchronization
func (s *SyncEngine) RunSync(config SyncConfig) SyncResult {
result := SyncResult{Success: true, Logs: []string{}}
logger.Infof("开始数据同步:源=%s 目标=%s 表数量=%d", formatConnSummaryForSync(config.SourceConfig), formatConnSummaryForSync(config.TargetConfig), len(config.Tables))
totalTables := len(config.Tables)
s.progress(config.JobID, 0, totalTables, "", "开始同步")
contentRaw := strings.ToLower(strings.TrimSpace(config.Content))
syncSchema := false
syncData := true
switch contentRaw {
case "", "data":
syncData = true
case "schema":
syncSchema = true
syncData = false
case "both":
syncSchema = true
syncData = true
default:
s.appendLog(config.JobID, &result, "warn", fmt.Sprintf("未知同步内容 %q已自动使用仅同步数据", config.Content))
syncData = true
}
modeRaw := strings.ToLower(strings.TrimSpace(config.Mode))
if modeRaw != "" && modeRaw != "insert_update" && modeRaw != "insert_only" && modeRaw != "full_overwrite" {
s.appendLog(config.JobID, &result, "warn", fmt.Sprintf("未知同步模式 %q已自动使用 insert_update", config.Mode))
}
defaultMode := normalizeSyncMode(config.Mode)
contentLabel := "仅同步数据"
if syncSchema && syncData {
contentLabel = "同步结构+数据"
} else if syncSchema {
contentLabel = "仅同步结构"
}
s.appendLog(config.JobID, &result, "info", fmt.Sprintf("同步内容:%s模式%s自动补字段%v", contentLabel, defaultMode, config.AutoAddColumns))
sourceDB, err := db.NewDatabase(config.SourceConfig.Type)
if err != nil {
logger.Error(err, "初始化源数据库驱动失败:类型=%s", config.SourceConfig.Type)
return s.fail(result, "初始化源数据库驱动失败: "+err.Error())
return s.fail(config.JobID, totalTables, result, "初始化源数据库驱动失败: "+err.Error())
}
if config.SourceConfig.Type == "custom" {
// Custom DB setup would go here if needed
@@ -51,133 +91,402 @@ func (s *SyncEngine) RunSync(config SyncConfig) SyncResult {
targetDB, err := db.NewDatabase(config.TargetConfig.Type)
if err != nil {
logger.Error(err, "初始化目标数据库驱动失败:类型=%s", config.TargetConfig.Type)
return s.fail(result, "初始化目标数据库驱动失败: "+err.Error())
return s.fail(config.JobID, totalTables, result, "初始化目标数据库驱动失败: "+err.Error())
}
// Connect Source
result.Logs = append(result.Logs, fmt.Sprintf("正在连接源数据库: %s...", config.SourceConfig.Host))
s.appendLog(config.JobID, &result, "info", fmt.Sprintf("正在连接源数据库: %s...", config.SourceConfig.Host))
s.progress(config.JobID, 0, totalTables, "", "连接源数据库")
if err := sourceDB.Connect(config.SourceConfig); err != nil {
logger.Error(err, "源数据库连接失败:%s", formatConnSummaryForSync(config.SourceConfig))
return s.fail(result, "源数据库连接失败: "+err.Error())
return s.fail(config.JobID, totalTables, result, "源数据库连接失败: "+err.Error())
}
defer sourceDB.Close()
// Connect Target
result.Logs = append(result.Logs, fmt.Sprintf("正在连接目标数据库: %s...", config.TargetConfig.Host))
s.appendLog(config.JobID, &result, "info", fmt.Sprintf("正在连接目标数据库: %s...", config.TargetConfig.Host))
s.progress(config.JobID, 0, totalTables, "", "连接目标数据库")
if err := targetDB.Connect(config.TargetConfig); err != nil {
logger.Error(err, "目标数据库连接失败:%s", formatConnSummaryForSync(config.TargetConfig))
return s.fail(result, "目标数据库连接失败: "+err.Error())
return s.fail(config.JobID, totalTables, result, "目标数据库连接失败: "+err.Error())
}
defer targetDB.Close()
// Iterate Tables
for _, tableName := range config.Tables {
result.Logs = append(result.Logs, fmt.Sprintf("正在同步表: %s", tableName))
for i, tableName := range config.Tables {
func() {
tableMode := defaultMode
s.appendLog(config.JobID, &result, "info", fmt.Sprintf("正在同步表: %s", tableName))
s.progress(config.JobID, i, totalTables, tableName, fmt.Sprintf("同步表(%d/%d)", i+1, totalTables))
defer s.progress(config.JobID, i+1, totalTables, tableName, "表处理完成")
// 1. Get Columns & PKs (Naive approach: assume same schema)
cols, err := sourceDB.GetColumns(config.SourceConfig.Database, tableName)
if err != nil {
logger.Error(err, "获取源表列信息失败:表=%s", tableName)
result.Logs = append(result.Logs, fmt.Sprintf("获取表 %s 的列信息失败: %v", tableName, err))
continue
}
pkCol := ""
for _, col := range cols {
if col.Key == "PRI" || col.Key == "PK" {
pkCol = col.Name
break
if syncSchema {
s.progress(config.JobID, i, totalTables, tableName, "同步表结构")
if err := s.syncTableSchema(config, &result, sourceDB, targetDB, tableName); err != nil {
s.appendLog(config.JobID, &result, "error", fmt.Sprintf("表结构同步失败:表=%s 错误=%v", tableName, err))
return
}
}
if !syncData {
result.TablesSynced++
return
}
}
if pkCol == "" {
result.Logs = append(result.Logs, fmt.Sprintf("跳过表 %s: 未找到主键 (同步需要主键)", tableName))
continue
}
sourceSchema, sourceTable := normalizeSchemaAndTable(config.SourceConfig.Type, config.SourceConfig.Database, tableName)
targetSchema, targetTable := normalizeSchemaAndTable(config.TargetConfig.Type, config.TargetConfig.Database, tableName)
sourceQueryTable := qualifiedNameForQuery(config.SourceConfig.Type, sourceSchema, sourceTable, tableName)
targetQueryTable := qualifiedNameForQuery(config.TargetConfig.Type, targetSchema, targetTable, tableName)
// 2. Fetch Data (MEMORY INTENSIVE - PROTOTYPE ONLY)
// TODO: Implement paging/streaming
sourceRows, _, err := sourceDB.Query(fmt.Sprintf("SELECT * FROM %s", tableName))
if err != nil {
logger.Error(err, "读取源表失败:表=%s", tableName)
result.Logs = append(result.Logs, fmt.Sprintf("读取源表 %s 失败: %v", tableName, err))
continue
}
// 1. Get Columns & PKs
cols, err := sourceDB.GetColumns(sourceSchema, sourceTable)
if err != nil {
logger.Error(err, "获取源表列信息失败:表=%s", tableName)
s.appendLog(config.JobID, &result, "error", fmt.Sprintf("获取表 %s 的列信息失败: %v", tableName, err))
return
}
sourceColsByLower := make(map[string]connection.ColumnDefinition, len(cols))
for _, col := range cols {
if strings.TrimSpace(col.Name) == "" {
continue
}
sourceColsByLower[strings.ToLower(strings.TrimSpace(col.Name))] = col
}
targetRows, _, err := targetDB.Query(fmt.Sprintf("SELECT * FROM %s", tableName))
if err != nil {
logger.Error(err, "读取目标表失败:表=%s", tableName)
// Table might not exist in target?
// Check if error is "table not found" -> Try to Create?
// For now, assume table exists.
result.Logs = append(result.Logs, fmt.Sprintf("读取目标表 %s 失败: %v", tableName, err))
continue
}
pkCols := make([]string, 0, 2)
for _, col := range cols {
if col.Key == "PRI" || col.Key == "PK" {
pkCols = append(pkCols, col.Name)
}
}
// 3. Compare (In-Memory Hash Map)
targetMap := make(map[string]map[string]interface{})
for _, row := range targetRows {
pkVal := fmt.Sprintf("%v", row[pkCol])
targetMap[pkVal] = row
}
if len(pkCols) == 0 {
s.appendLog(config.JobID, &result, "warn", fmt.Sprintf("表 %s 未找到主键,已跳过数据同步(避免产生重复数据)", tableName))
return
}
if len(pkCols) > 1 {
s.appendLog(config.JobID, &result, "warn", fmt.Sprintf("表 %s 为复合主键(%s当前暂不支持数据同步", tableName, strings.Join(pkCols, ",")))
return
}
pkCol := pkCols[0]
var inserts []map[string]interface{}
var updates []connection.UpdateRow
// var deletes []map[string]interface{} // Not implemented in "insert_update" mode usually
for _, sRow := range sourceRows {
pkVal := fmt.Sprintf("%v", sRow[pkCol])
if tRow, exists := targetMap[pkVal]; exists {
// Update? Compare values
// Simplified: Compare string representations or iterate keys
// For prototype: assume update if exists
// Optimization: Check diff
changes := make(map[string]interface{})
for k, v := range sRow {
if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", tRow[k]) {
changes[k] = v
opts := TableOptions{Insert: true, Update: true, Delete: false}
if config.TableOptions != nil {
if t, ok := config.TableOptions[tableName]; ok {
opts = t
// 默认防护:如用户未设置任意一个字段,保持 insert/update 默认 true、delete 默认 false
if !t.Insert && !t.Update && !t.Delete {
opts = t
}
}
if len(changes) > 0 {
updates = append(updates, connection.UpdateRow{
Keys: map[string]interface{}{pkCol: pkVal},
Values: changes,
})
}
} else {
// Insert
inserts = append(inserts, sRow)
}
}
if !opts.Insert && !opts.Update && !opts.Delete {
s.appendLog(config.JobID, &result, "info", fmt.Sprintf("表 %s 未勾选任何操作,已跳过", tableName))
return
}
// 4. Apply Changes
changeSet := connection.ChangeSet{
Inserts: inserts,
Updates: updates,
}
// 2. Fetch Data (MEMORY INTENSIVE - PROTOTYPE ONLY)
// TODO: Implement paging/streaming
s.progress(config.JobID, i, totalTables, tableName, "读取源表数据")
sourceRows, _, err := sourceDB.Query(fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(config.SourceConfig.Type, sourceQueryTable)))
if err != nil {
logger.Error(err, "读取源表失败:表=%s", tableName)
s.appendLog(config.JobID, &result, "error", fmt.Sprintf("读取源表 %s 失败: %v", tableName, err))
return
}
if len(inserts) > 0 || len(updates) > 0 {
result.Logs = append(result.Logs, fmt.Sprintf(" -> 需插入: %d 行, 需更新: %d 行", len(inserts), len(updates)))
var inserts []map[string]interface{}
var updates []connection.UpdateRow
// We need a BatchApplier interface or assume Database implements ApplyChanges
if applier, ok := targetDB.(db.BatchApplier); ok {
if err := applier.ApplyChanges(tableName, changeSet); err != nil {
result.Logs = append(result.Logs, fmt.Sprintf(" -> 应用变更失败: %v", err))
if tableMode == "insert_update" {
s.progress(config.JobID, i, totalTables, tableName, "读取目标表数据")
targetRows, _, err := targetDB.Query(fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(config.TargetConfig.Type, targetQueryTable)))
if err != nil {
logger.Error(err, "读取目标表失败:表=%s", tableName)
s.appendLog(config.JobID, &result, "error", fmt.Sprintf("读取目标表 %s 失败: %v", tableName, err))
return
}
// 3. Compare (In-Memory Hash Map)
s.progress(config.JobID, i, totalTables, tableName, "对比差异")
targetMap := make(map[string]map[string]interface{})
for _, row := range targetRows {
if row[pkCol] == nil {
continue
}
pkVal := fmt.Sprintf("%v", row[pkCol])
if strings.TrimSpace(pkVal) == "" || pkVal == "<nil>" {
continue
}
targetMap[pkVal] = row
}
sourcePKSet := make(map[string]struct{}, len(sourceRows))
for _, sRow := range sourceRows {
if sRow[pkCol] == nil {
continue
}
pkVal := fmt.Sprintf("%v", sRow[pkCol])
if strings.TrimSpace(pkVal) == "" || pkVal == "<nil>" {
continue
}
sourcePKSet[pkVal] = struct{}{}
if tRow, exists := targetMap[pkVal]; exists {
changes := make(map[string]interface{})
for k, v := range sRow {
if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", tRow[k]) {
changes[k] = v
}
}
if len(changes) > 0 {
updates = append(updates, connection.UpdateRow{
Keys: map[string]interface{}{pkCol: sRow[pkCol]},
Values: changes,
})
}
} else {
inserts = append(inserts, sRow)
}
}
var deletes []map[string]interface{}
if opts.Delete {
for pkStr, row := range targetMap {
if _, ok := sourcePKSet[pkStr]; ok {
continue
}
deletes = append(deletes, map[string]interface{}{pkCol: row[pkCol]})
}
}
// apply operation selection
inserts = filterRowsByPKSelection(pkCol, inserts, opts.Insert, opts.SelectedInsertPKs)
updates = filterUpdatesByPKSelection(pkCol, updates, opts.Update, opts.SelectedUpdatePKs)
deletes = filterRowsByPKSelection(pkCol, deletes, opts.Delete, opts.SelectedDeletePKs)
changeSet := connection.ChangeSet{
Inserts: inserts,
Updates: updates,
Deletes: deletes,
}
// 4. Align schema (target missing columns)
s.progress(config.JobID, i, totalTables, tableName, "检查字段一致性")
requiredCols := collectRequiredColumns(changeSet.Inserts, changeSet.Updates)
targetCols, err := targetDB.GetColumns(targetSchema, targetTable)
if err != nil {
s.appendLog(config.JobID, &result, "warn", fmt.Sprintf(" -> 获取目标表字段失败,已跳过字段一致性检查: %v", err))
} else {
result.RowsInserted += len(inserts)
result.RowsUpdated += len(updates)
targetColSet := make(map[string]struct{}, len(targetCols))
for _, c := range targetCols {
name := strings.ToLower(strings.TrimSpace(c.Name))
if name == "" {
continue
}
targetColSet[name] = struct{}{}
}
missing := make([]string, 0)
for lower, original := range requiredCols {
if _, ok := targetColSet[lower]; !ok {
missing = append(missing, original)
}
}
sort.Strings(missing)
if len(missing) > 0 {
if config.AutoAddColumns && strings.ToLower(strings.TrimSpace(config.TargetConfig.Type)) == "mysql" {
s.appendLog(config.JobID, &result, "warn", fmt.Sprintf(" -> 目标表缺少字段 %d 个,开始自动补齐: %s", len(missing), strings.Join(missing, ", ")))
added := 0
for _, colName := range missing {
colLower := strings.ToLower(strings.TrimSpace(colName))
colType := "TEXT"
if strings.ToLower(strings.TrimSpace(config.SourceConfig.Type)) == "mysql" {
if srcCol, ok := sourceColsByLower[colLower]; ok {
colType = sanitizeMySQLColumnType(srcCol.Type)
}
}
alterSQL := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s NULL",
quoteQualifiedIdentByType("mysql", targetQueryTable),
quoteIdentByType("mysql", colName),
colType,
)
if _, err := targetDB.Exec(alterSQL); err != nil {
s.appendLog(config.JobID, &result, "error", fmt.Sprintf(" -> 自动补字段失败:字段=%s 错误=%v", colName, err))
continue
}
added++
}
s.appendLog(config.JobID, &result, "info", fmt.Sprintf(" -> 自动补字段完成:成功=%d 失败=%d", added, len(missing)-added))
// refresh columns
targetCols, err = targetDB.GetColumns(targetSchema, targetTable)
if err == nil {
targetColSet = make(map[string]struct{}, len(targetCols))
for _, c := range targetCols {
name := strings.ToLower(strings.TrimSpace(c.Name))
if name == "" {
continue
}
targetColSet[name] = struct{}{}
}
}
} else {
s.appendLog(config.JobID, &result, "warn", fmt.Sprintf(" -> 目标表缺少字段 %d 个(未开启自动补齐),将自动忽略:%s", len(missing), strings.Join(missing, ", ")))
}
// filter out still-missing columns to avoid apply failure
changeSet.Inserts = filterInsertRows(changeSet.Inserts, targetColSet)
changeSet.Updates = filterUpdateRows(changeSet.Updates, targetColSet)
}
}
// 5. Apply Changes
s.progress(config.JobID, i, totalTables, tableName, "应用变更")
if len(changeSet.Inserts) > 0 || len(changeSet.Updates) > 0 || len(changeSet.Deletes) > 0 {
s.appendLog(config.JobID, &result, "info", fmt.Sprintf(" -> 需插入: %d 行, 需更新: %d 行, 需删除: %d 行", len(changeSet.Inserts), len(changeSet.Updates), len(changeSet.Deletes)))
if applier, ok := targetDB.(db.BatchApplier); ok {
if err := applier.ApplyChanges(targetTable, changeSet); err != nil {
s.appendLog(config.JobID, &result, "error", fmt.Sprintf(" -> 应用变更失败: %v", err))
} else {
result.RowsInserted += len(changeSet.Inserts)
result.RowsUpdated += len(changeSet.Updates)
result.RowsDeleted += len(changeSet.Deletes)
}
} else {
s.appendLog(config.JobID, &result, "warn", " -> 目标驱动不支持应用数据变更 (ApplyChanges).")
}
} else {
s.appendLog(config.JobID, &result, "info", " -> 数据一致,无需变更.")
}
result.TablesSynced++
return
} else {
// insert_only / full_overwrite: do not compare target, just insert source rows
inserts = sourceRows
}
// full_overwrite: clear target table first
if tableMode == "full_overwrite" {
s.appendLog(config.JobID, &result, "warn", fmt.Sprintf(" -> 全量覆盖模式:即将清空目标表 %s", tableName))
s.progress(config.JobID, i, totalTables, tableName, "清空目标表")
clearSQL := ""
if strings.ToLower(strings.TrimSpace(config.TargetConfig.Type)) == "mysql" {
clearSQL = fmt.Sprintf("TRUNCATE TABLE %s", quoteQualifiedIdentByType(config.TargetConfig.Type, targetQueryTable))
} else {
clearSQL = fmt.Sprintf("DELETE FROM %s", quoteQualifiedIdentByType(config.TargetConfig.Type, targetQueryTable))
}
if _, err := targetDB.Exec(clearSQL); err != nil {
s.appendLog(config.JobID, &result, "error", fmt.Sprintf(" -> 清空目标表失败: %v", err))
return
}
}
// 4. Align schema (target missing columns)
s.progress(config.JobID, i, totalTables, tableName, "检查字段一致性")
requiredCols := collectRequiredColumns(inserts, updates)
targetCols, err := targetDB.GetColumns(targetSchema, targetTable)
if err != nil {
s.appendLog(config.JobID, &result, "warn", fmt.Sprintf(" -> 获取目标表字段失败,已跳过字段一致性检查: %v", err))
} else {
targetColSet := make(map[string]struct{}, len(targetCols))
for _, c := range targetCols {
name := strings.ToLower(strings.TrimSpace(c.Name))
if name == "" {
continue
}
targetColSet[name] = struct{}{}
}
missing := make([]string, 0)
for lower, original := range requiredCols {
if _, ok := targetColSet[lower]; !ok {
missing = append(missing, original)
}
}
sort.Strings(missing)
if len(missing) > 0 {
if config.AutoAddColumns && strings.ToLower(strings.TrimSpace(config.TargetConfig.Type)) == "mysql" {
s.appendLog(config.JobID, &result, "warn", fmt.Sprintf(" -> 目标表缺少字段 %d 个,开始自动补齐: %s", len(missing), strings.Join(missing, ", ")))
added := 0
for _, colName := range missing {
colLower := strings.ToLower(strings.TrimSpace(colName))
colType := "TEXT"
if strings.ToLower(strings.TrimSpace(config.SourceConfig.Type)) == "mysql" {
if srcCol, ok := sourceColsByLower[colLower]; ok {
colType = sanitizeMySQLColumnType(srcCol.Type)
}
}
alterSQL := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s NULL",
quoteQualifiedIdentByType("mysql", targetQueryTable),
quoteIdentByType("mysql", colName),
colType,
)
if _, err := targetDB.Exec(alterSQL); err != nil {
s.appendLog(config.JobID, &result, "error", fmt.Sprintf(" -> 自动补字段失败:字段=%s 错误=%v", colName, err))
continue
}
added++
}
s.appendLog(config.JobID, &result, "info", fmt.Sprintf(" -> 自动补字段完成:成功=%d 失败=%d", added, len(missing)-added))
// refresh columns
targetCols, err = targetDB.GetColumns(targetSchema, targetTable)
if err == nil {
targetColSet = make(map[string]struct{}, len(targetCols))
for _, c := range targetCols {
name := strings.ToLower(strings.TrimSpace(c.Name))
if name == "" {
continue
}
targetColSet[name] = struct{}{}
}
}
} else {
s.appendLog(config.JobID, &result, "warn", fmt.Sprintf(" -> 目标表缺少字段 %d 个(未开启自动补齐),将自动忽略:%s", len(missing), strings.Join(missing, ", ")))
}
// filter out still-missing columns to avoid apply failure
inserts = filterInsertRows(inserts, targetColSet)
updates = filterUpdateRows(updates, targetColSet)
}
}
// 5. Apply Changes
s.progress(config.JobID, i, totalTables, tableName, "应用变更")
changeSet := connection.ChangeSet{
Inserts: inserts,
Updates: updates,
}
if len(changeSet.Inserts) > 0 || len(changeSet.Updates) > 0 {
s.appendLog(config.JobID, &result, "info", fmt.Sprintf(" -> 需插入: %d 行, 需更新: %d 行", len(changeSet.Inserts), len(changeSet.Updates)))
if applier, ok := targetDB.(db.BatchApplier); ok {
if err := applier.ApplyChanges(targetTable, changeSet); err != nil {
s.appendLog(config.JobID, &result, "error", fmt.Sprintf(" -> 应用变更失败: %v", err))
} else {
result.RowsInserted += len(changeSet.Inserts)
result.RowsUpdated += len(changeSet.Updates)
}
} else {
s.appendLog(config.JobID, &result, "warn", " -> 目标驱动不支持应用数据变更 (ApplyChanges).")
}
} else {
result.Logs = append(result.Logs, " -> 目标驱动不支持应用数据变更 (ApplyChanges).")
s.appendLog(config.JobID, &result, "info", " -> 数据一致,无需变更.")
}
} else {
result.Logs = append(result.Logs, " -> 数据一致,无需变更.")
}
result.TablesSynced++
result.TablesSynced++
}()
}
s.progress(config.JobID, totalTables, totalTables, "", "同步完成")
return result
}
@@ -196,9 +505,52 @@ func formatConnSummaryForSync(config connection.ConnectionConfig) string {
config.Type, config.Host, config.Port, dbName, config.User, timeoutSeconds)
}
func (s *SyncEngine) fail(res SyncResult, msg string) SyncResult {
func (s *SyncEngine) appendLog(jobID string, res *SyncResult, level string, msg string) {
if res != nil {
res.Logs = append(res.Logs, msg)
}
if s.reporter.OnLog != nil && strings.TrimSpace(jobID) != "" {
s.reporter.OnLog(SyncLogEvent{
JobID: jobID,
Level: level,
Message: msg,
Ts: time.Now().UnixMilli(),
})
}
}
func (s *SyncEngine) progress(jobID string, current, total int, table string, stage string) {
if s.reporter.OnProgress == nil || strings.TrimSpace(jobID) == "" {
return
}
percent := 0
if total <= 0 {
if current > 0 {
percent = 100
}
} else {
if current < 0 {
current = 0
}
if current > total {
current = total
}
percent = (current * 100) / total
}
s.reporter.OnProgress(SyncProgressEvent{
JobID: jobID,
Percent: percent,
Current: current,
Total: total,
Table: table,
Stage: stage,
})
}
func (s *SyncEngine) fail(jobID string, totalTables int, res SyncResult, msg string) SyncResult {
res.Success = false
res.Message = msg
res.Logs = append(res.Logs, "致命错误: "+msg)
s.appendLog(jobID, &res, "error", "致命错误: "+msg)
s.progress(jobID, res.TablesSynced, totalTables, "", "同步失败")
return res
}

View File

@@ -0,0 +1,30 @@
package sync
const (
EventSyncStart = "sync:start"
EventSyncProgress = "sync:progress"
EventSyncLog = "sync:log"
EventSyncDone = "sync:done"
)
type SyncLogEvent struct {
JobID string `json:"jobId"`
Level string `json:"level"` // info/warn/error
Message string `json:"message"`
Ts int64 `json:"ts"` // Unix milli
}
type SyncProgressEvent struct {
JobID string `json:"jobId"`
Percent int `json:"percent"`
Current int `json:"current"` // 已完成表数
Total int `json:"total"` // 总表数
Table string `json:"table,omitempty"`
Stage string `json:"stage,omitempty"`
}
type Reporter struct {
OnLog func(event SyncLogEvent)
OnProgress func(event SyncProgressEvent)
}

View File

@@ -0,0 +1,13 @@
package sync
// TableOptions controls which operations to apply per table, and optional row selection.
// 注意:如未指定 Selected*PKs则表示“同步全部该类型差异数据”如指定为空数组则同样表示全部。
type TableOptions struct {
Insert bool `json:"insert,omitempty"`
Update bool `json:"update,omitempty"`
Delete bool `json:"delete,omitempty"`
SelectedInsertPKs []string `json:"selectedInsertPks,omitempty"`
SelectedUpdatePKs []string `json:"selectedUpdatePks,omitempty"`
SelectedDeletePKs []string `json:"selectedDeletePks,omitempty"`
}

52
logo.svg Normal file
View File

@@ -0,0 +1,52 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 512 512">
<defs>
<!-- Background: Soft Light Grey -->
<linearGradient id="bgSoft" x1="0%" y1="0%" x2="0%" y2="100%">
<stop offset="0%" style="stop-color:#f5f7fa;stop-opacity:1" />
<stop offset="100%" style="stop-color:#c3cfe2;stop-opacity:1" />
</linearGradient>
<!-- Hexagon: Solid Tech Pink -->
<linearGradient id="solidPink" x1="0%" y1="0%" x2="100%" y2="100%">
<stop offset="0%" style="stop-color:#FF5F6D;stop-opacity:1" />
<stop offset="100%" style="stop-color:#FFC371;stop-opacity:1" />
</linearGradient>
<!-- N: Solid Tech Blue/Cyan -->
<linearGradient id="solidCyan" x1="0%" y1="0%" x2="100%" y2="100%">
<stop offset="0%" style="stop-color:#00c6ff;stop-opacity:1" />
<stop offset="100%" style="stop-color:#0072ff;stop-opacity:1" />
</linearGradient>
<filter id="hardShadow" x="-20%" y="-20%" width="140%" height="140%">
<feGaussianBlur in="SourceAlpha" stdDeviation="4"/>
<feOffset dx="4" dy="4" result="offsetblur"/>
<feComponentTransfer>
<feFuncA type="linear" slope="0.2"/>
</feComponentTransfer>
<feMerge>
<feMergeNode/>
<feMergeNode in="SourceGraphic"/>
</feMerge>
</filter>
</defs>
<!-- Background -->
<rect x="32" y="32" width="448" height="448" rx="100" fill="url(#bgSoft)" />
<!-- Main Content Centered -->
<g transform="translate(106, 106) scale(0.6)" filter="url(#hardShadow)">
<!-- Hex G -->
<path d="M 250 0 L 466 125 L 466 375 L 250 500 L 34 375 L 34 125 Z"
fill="none" stroke="url(#solidPink)" stroke-width="45" stroke-linejoin="round"/>
<!-- G Crossbar -->
<path d="M 466 300 L 330 300" stroke="url(#solidPink)" stroke-width="45" stroke-linecap="round"/>
<!-- Inner N -->
<path d="M 160 350 L 160 150 L 340 350 L 340 150"
fill="none" stroke="url(#solidCyan)" stroke-width="50" stroke-linecap="round" stroke-linejoin="round"/>
</g>
</svg>

After

Width:  |  Height:  |  Size: 2.0 KiB