diff --git a/docs/HighGo_Optional_Code_Changes.md b/docs/HighGo_Optional_Code_Changes.md new file mode 100644 index 0000000..2fcd44e --- /dev/null +++ b/docs/HighGo_Optional_Code_Changes.md @@ -0,0 +1,164 @@ +# HighGo 可选代码优化建议 + +## 一、sslmode 配置优化 + +### 当前状态 + +**文件**:`internal/db/highgo_impl.go:43` + +**当前代码**: +```go +q.Set("sslmode", "disable") +``` + +### 建议修改 + +根据瀚高官方文档,sslmode 的默认值应该是 `require`。建议修改为: + +```go +q.Set("sslmode", "require") +``` + +### 修改原因 + +1. **符合官方规范**:瀚高官方文档明确指出默认 sslmode 为 `require` +2. **安全性提升**:启用 SSL 加密可以保护数据传输安全 +3. **生产环境最佳实践**:生产环境应该启用 SSL 连接 + +### 是否需要修改? + +**不一定需要修改**,取决于您的实际环境: + +#### 保持 `disable` 的场景: +- ✅ 开发/测试环境 +- ✅ HighGo 服务器未配置 SSL 证书 +- ✅ 内网环境,不需要加密传输 +- ✅ 快速测试连接功能 + +#### 修改为 `require` 的场景: +- ✅ 生产环境 +- ✅ HighGo 服务器已配置 SSL 证书 +- ✅ 跨网络连接,需要加密保护 +- ✅ 符合安全合规要求 + +### 如何修改 + +如果您决定修改,可以使用以下命令: + +**方式 1:直接修改(固定为 require)** +```go +// 文件:internal/db/highgo_impl.go 第 43 行 +q.Set("sslmode", "require") +``` + +**方式 2:可配置(推荐)** + +如果希望让用户可以选择 sslmode,可以修改为: + +```go +// 在 getDSN 方法中 +sslmode := "disable" // 默认值 +if config.SSLMode != "" { + sslmode = config.SSLMode +} +q.Set("sslmode", sslmode) +``` + +然后在 `internal/connection/connection.go` 的 `ConnectionConfig` 结构体中添加字段: + +```go +type ConnectionConfig struct { + // ... 现有字段 + SSLMode string `json:"sslMode,omitempty"` // SSL 模式:disable, require, verify-ca, verify-full +} +``` + +前端 UI 也需要相应添加 sslmode 选择控件。 + +### 测试建议 + +修改后请务必测试: + +1. **SSL 启用测试**: + - 连接配置了 SSL 的 HighGo 服务器 + - 验证连接成功 + +2. **SSL 禁用测试**: + - 连接未配置 SSL 的 HighGo 服务器 + - 验证是否会报错(如果设置为 `require` 会报错) + +3. **兼容性测试**: + - 测试现有的 HighGo 连接配置是否仍然可用 + +## 二、其他可选优化 + +### 1. 默认端口提示优化 + +**文件**:`frontend/src/components/ConnectionModal.tsx` + +**当前状态**:HighGo 的默认端口已正确设置为 5866 + +**建议**:无需修改,已符合官方规范 + +### 2. 默认数据库名称 + +**文件**:`internal/db/highgo_impl.go:33` + +**当前代码**: +```go +if dbname == "" { + dbname = "highgo" // HighGo default database +} +``` + +**建议**:无需修改,已符合官方规范(默认数据库为 `highgo`) + +### 3. 默认用户名 + +**当前状态**:未在代码中硬编码默认用户名 + +**瀚高官方默认**:`sysdba` + +**建议**: +- 可以在前端 UI 的 HighGo 连接表单中,将用户名输入框的 placeholder 设置为 `sysdba` +- 但不建议硬编码默认值,让用户自行输入更安全 + +## 三、总结 + +### 必须修改的项目 +- ✅ **无**(当前代码已基本符合规范) + +### 建议修改的项目 +1. **sslmode 配置**(根据实际环境决定) + - 开发环境:保持 `disable` + - 生产环境:修改为 `require` + +### 可选优化的项目 +1. 将 sslmode 改为可配置(需要修改前后端) +2. 前端 UI 添加 sslmode 选择控件 +3. 用户名输入框添加 `sysdba` 提示 + +## 四、修改优先级 + +**优先级 1(高)**: +- 集成瀚高 SM3 驱动(参考 `HighGo_SM3_Integration_Guide.md`) + +**优先级 2(中)**: +- 根据部署环境调整 sslmode 配置 + +**优先级 3(低)**: +- 将 sslmode 改为可配置 +- UI 优化(placeholder 提示等) + +## 五、下一步行动 + +建议按以下顺序执行: + +1. **先集成 SM3 驱动**(参考集成指南) +2. **测试基本连接功能**(使用 sslmode=disable) +3. **如果生产环境需要 SSL**,再修改 sslmode 配置 +4. **验证所有功能正常**后,考虑可选优化项 + +--- + +**注意**:所有代码修改都应该在集成 SM3 驱动并验证基本功能正常后再进行。 diff --git a/docs/HighGo_SM3_Integration_Guide.md b/docs/HighGo_SM3_Integration_Guide.md new file mode 100644 index 0000000..c25034e --- /dev/null +++ b/docs/HighGo_SM3_Integration_Guide.md @@ -0,0 +1,179 @@ +# HighGo SM3 国密驱动集成指南 + +## 一、背景说明 + +HighGo(瀚高)数据库需要使用支持 SM3 国密认证的 PostgreSQL 驱动。瀚高官方提供了基于 `lib/pq` 的安全增强版本。 + +## 二、集成步骤 + +### 步骤 1:下载瀚高 pq 驱动 + +1. 访问百度网盘链接: + ``` + https://pan.baidu.com/s/1xuz6uJz0utRgKWecXhpOiA?pwd=o0tj + ``` + +2. 下载驱动源码压缩包 + +### 步骤 2:放置驱动源码 + +1. 在项目根目录创建 vendor 目录(如果不存在): + ```bash + mkdir -p vendor/highgo-pq + ``` + +2. 解压下载的驱动源码到 `vendor/highgo-pq/` 目录 + +3. 确保目录结构如下: + ``` + GoNavi/ + ├── vendor/ + │ └── highgo-pq/ + │ ├── go.mod + │ ├── conn.go + │ ├── ... (其他 pq 驱动源文件) + ``` + +### 步骤 3:修改 go.mod + +在 `go.mod` 文件末尾添加 replace 指令: + +```go +replace github.com/lib/pq => ./vendor/highgo-pq +``` + +完整示例: +```go +module GoNavi-Wails + +go 1.24.3 + +require ( + // ... 现有依赖 + github.com/lib/pq v1.11.1 + // ... 其他依赖 +) + +// 在文件末尾添加 +replace github.com/lib/pq => ./vendor/highgo-pq +``` + +### 步骤 4:更新 HighGo 连接配置(可选) + +根据瀚高官方文档,建议修改 `internal/db/highgo_impl.go:43` 的 sslmode: + +**当前代码**: +```go +q.Set("sslmode", "disable") +``` + +**建议修改为**(瀚高默认): +```go +q.Set("sslmode", "require") +``` + +> ⚠️ 注意:如果您的 HighGo 服务器未配置 SSL,保持 `disable` 即可。 + +### 步骤 5:验证集成 + +1. 清理依赖缓存: + ```bash + go clean -modcache + ``` + +2. 重新下载依赖: + ```bash + go mod download + ``` + +3. 编译项目: + ```bash + go build ./... + ``` + +4. 测试 HighGo 连接: + - 启动应用 + - 创建 HighGo 连接 + - 测试连接是否成功 + +## 三、重要说明 + +### ⚠️ 影响范围 + +使用 `go.mod replace` 会**全局替换** `github.com/lib/pq` 驱动,这意味着: + +1. **PostgreSQL 连接也会使用瀚高驱动** +2. **需要验证瀚高驱动对标准 PostgreSQL 的兼容性** + +### 兼容性验证 + +集成后,请务必测试: + +1. ✅ HighGo 数据库连接(SM3 认证) +2. ✅ 标准 PostgreSQL 连接(确保仍然可用) + +如果标准 PostgreSQL 连接失败,说明瀚高驱动不完全兼容,需要考虑其他方案。 + +### 回滚方案 + +如果集成后出现问题,可以快速回滚: + +1. 删除 `go.mod` 中的 replace 指令 +2. 删除 `vendor/highgo-pq/` 目录 +3. 运行 `go mod tidy` +4. 重新编译 + +## 四、瀚高驱动特性 + +根据官方文档: + +- **包路径**:`github.com/lib/pq`(与标准版相同) +- **驱动名**:`postgres`(与标准版相同) +- **SM3 支持**:自动启用国密认证 +- **默认端口**:5866 +- **默认数据库**:`highgo` +- **默认用户**:`sysdba` +- **sslmode 默认**:`require` + +## 五、故障排查 + +### 问题 1:编译失败 + +**现象**:`go build` 报错找不到 `github.com/lib/pq` + +**解决**: +1. 检查 `vendor/highgo-pq/` 目录是否存在 +2. 检查 `go.mod` 中 replace 路径是否正确 +3. 运行 `go mod download` + +### 问题 2:HighGo 连接失败 + +**现象**:连接 HighGo 时报认证错误 + +**解决**: +1. 确认瀚高驱动已正确替换(检查 `go.mod`) +2. 确认 HighGo 服务器支持 SM3 认证 +3. 检查用户名、密码、端口是否正确 + +### 问题 3:PostgreSQL 连接失败 + +**现象**:集成后标准 PostgreSQL 无法连接 + +**解决**: +1. 这说明瀚高驱动不完全兼容标准 PostgreSQL +2. 需要考虑条件编译或其他隔离方案 +3. 临时回滚:删除 replace 指令 + +## 六、后续优化建议 + +如果发现瀚高驱动与标准 PostgreSQL 不兼容,可以考虑: + +1. **条件编译**:使用 Go build tags 分别编译两个版本 +2. **动态驱动注册**:如果瀚高驱动支持自定义驱动名 +3. **联系瀚高技术支持**:咨询官方兼容性方案 + +## 七、参考资料 + +- 瀚高官方文档:https://www.highgo.com/document/zh-cn/application/pq%E6%8E%A5%E5%8F%A3.html +- 瀚高驱动下载:https://pan.baidu.com/s/1xuz6uJz0utRgKWecXhpOiA?pwd=o0tj +- 标准 lib/pq:https://github.com/lib/pq diff --git a/frontend/package.json.md5 b/frontend/package.json.md5 index 0f8f4fe..a7661c0 100755 --- a/frontend/package.json.md5 +++ b/frontend/package.json.md5 @@ -1 +1 @@ -5b8157374dae5f9340e31b2d0bd2c00e \ No newline at end of file +d0f9366af59a6367ad3c7e2d4185ead4 \ No newline at end of file diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 28858ee..4a43c03 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -10,7 +10,7 @@ import DataSyncModal from './components/DataSyncModal'; import LogPanel from './components/LogPanel'; import { useStore } from './store'; import { SavedConnection } from './types'; -import { blurToFilter, normalizeBlurForPlatform, normalizeOpacityForPlatform } from './utils/appearance'; +import { blurToFilter, normalizeBlurForPlatform, normalizeOpacityForPlatform, isWindowsPlatform } from './utils/appearance'; import './App.css'; const { Sider, Content } = Layout; @@ -814,19 +814,27 @@ function App() {
高斯模糊 (Blur)
-
- setAppearance({ blur: v })} - style={{ flex: 1 }} - /> - {appearance.blur}px -
-
- * 仅控制应用内覆盖层的模糊效果 -
+ {isWindowsPlatform() ? ( +
+ Windows 使用系统 Acrylic 效果,模糊程度由系统控制 +
+ ) : ( + <> +
+ setAppearance({ blur: v })} + style={{ flex: 1 }} + /> + {appearance.blur}px +
+
+ * 仅控制应用内覆盖层的模糊效果 +
+ + )}
diff --git a/frontend/src/components/ConnectionModal.tsx b/frontend/src/components/ConnectionModal.tsx index 26f3152..4b6f9ba 100644 --- a/frontend/src/components/ConnectionModal.tsx +++ b/frontend/src/components/ConnectionModal.tsx @@ -14,6 +14,7 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal const [useSSH, setUseSSH] = useState(false); const [dbType, setDbType] = useState('mysql'); const [step, setStep] = useState(1); // 1: Select Type, 2: Configure + const [activeGroup, setActiveGroup] = useState(0); // Active category index in step 1 const [testResult, setTestResult] = useState<{ type: 'success' | 'error', message: string } | null>(null); const [dbList, setDbList] = useState([]); const [redisDbList, setRedisDbList] = useState([]); // Redis databases 0-15 @@ -62,6 +63,7 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal form.resetFields(); setUseSSH(false); setDbType('mysql'); + setActiveGroup(0); } } }, [open, initialValues]); @@ -195,6 +197,11 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal case 'oracle': defaultPort = 1521; break; case 'dameng': defaultPort = 5236; break; case 'kingbase': defaultPort = 54321; break; + case 'sqlserver': defaultPort = 1433; break; + case 'mongodb': defaultPort = 27017; break; + case 'highgo': defaultPort = 5866; break; + case 'mariadb': defaultPort = 3306; break; + case 'vastbase': defaultPort = 5432; break; default: defaultPort = 3306; } if (type !== 'sqlite' && type !== 'custom') { @@ -208,32 +215,75 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal const isCustom = dbType === 'custom'; const isRedis = dbType === 'redis'; - const dbTypes = [ - { key: 'mysql', name: 'MySQL', icon: }, - { key: 'postgres', name: 'PostgreSQL', icon: }, - { key: 'redis', name: 'Redis', icon: }, - { key: 'sqlite', name: 'SQLite', icon: }, - { key: 'oracle', name: 'Oracle', icon: }, - { key: 'dameng', name: 'Dameng (达梦)', icon: }, - { key: 'kingbase', name: 'Kingbase (人大金仓)', icon: }, - { key: 'custom', name: 'Custom (自定义)', icon: }, + const dbTypeGroups = [ + { label: '关系型数据库', items: [ + { key: 'mysql', name: 'MySQL', icon: }, + { key: 'mariadb', name: 'MariaDB', icon: }, + { key: 'postgres', name: 'PostgreSQL', icon: }, + { key: 'sqlserver', name: 'SQL Server', icon: }, + { key: 'sqlite', name: 'SQLite', icon: }, + { key: 'oracle', name: 'Oracle', icon: }, + ]}, + { label: '国产数据库', items: [ + { key: 'dameng', name: 'Dameng (达梦)', icon: }, + { key: 'kingbase', name: 'Kingbase (人大金仓)', icon: }, + { key: 'highgo', name: 'HighGo (瀚高)', icon: }, + { key: 'vastbase', name: 'Vastbase (海量)', icon: }, + ]}, + { label: 'NoSQL', items: [ + { key: 'mongodb', name: 'MongoDB', icon: }, + { key: 'redis', name: 'Redis', icon: }, + ]}, + { label: '其他', items: [ + { key: 'custom', name: 'Custom (自定义)', icon: }, + ]}, ]; + const dbTypes = dbTypeGroups.flatMap(g => g.items); + const renderStep1 = () => ( - - {dbTypes.map(item => ( - - handleTypeSelect(item.key)} - style={{ textAlign: 'center', cursor: 'pointer' }} +
+ {/* 左侧分类导航 */} +
+ {dbTypeGroups.map((group, idx) => ( +
setActiveGroup(idx)} + style={{ + padding: '10px 12px', + cursor: 'pointer', + borderRadius: 6, + marginBottom: 4, + background: activeGroup === idx ? '#e6f4ff' : 'transparent', + color: activeGroup === idx ? '#1677ff' : undefined, + fontWeight: activeGroup === idx ? 500 : 400, + transition: 'all 0.2s', + fontSize: 13, + }} > -
{item.icon}
- {item.name} - - - ))} - + {group.label} +
+ ))} +
+ {/* 右侧数据源卡片 */} +
+ + {dbTypeGroups[activeGroup]?.items.map(item => ( + + handleTypeSelect(item.key)} + style={{ textAlign: 'center', cursor: 'pointer', height: 100 }} + styles={{ body: { padding: '16px 8px', display: 'flex', flexDirection: 'column', alignItems: 'center', justifyContent: 'center', height: '100%' } }} + > +
{item.icon}
+ {item.name} +
+ + ))} +
+
+
); const renderStep2 = () => ( @@ -401,15 +451,16 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal }; return ( - {step === 1 ? renderStep1() : renderStep2()} diff --git a/frontend/src/utils/appearance.ts b/frontend/src/utils/appearance.ts index f140bf5..10d48b5 100644 --- a/frontend/src/utils/appearance.ts +++ b/frontend/src/utils/appearance.ts @@ -2,10 +2,10 @@ const DEFAULT_OPACITY = 1.0; const MIN_OPACITY = 0.1; const MAX_OPACITY = 1.0; -// macOS 端进一步增强通透感:同滑块值下更低等效不透明度、降低过重模糊。 -const MAC_OPACITY_FACTOR = 0.20; +// 平台透明度映射因子:值越大,滑块变化越平滑(1.0 = 线性映射) +const MAC_OPACITY_FACTOR = 0.60; const MAC_BLUR_FACTOR = 1.00; -const WINDOWS_OPACITY_FACTOR = 0.20; +const WINDOWS_OPACITY_FACTOR = 0.70; const WINDOWS_BLUR_FACTOR = 1.00; const clamp = (value: number, min: number, max: number) => Math.min(max, Math.max(min, value)); diff --git a/go.mod b/go.mod index 3fe8b39..0ef3911 100644 --- a/go.mod +++ b/go.mod @@ -7,9 +7,11 @@ require ( gitee.com/chunanyong/dm v1.8.22 github.com/go-sql-driver/mysql v1.9.3 github.com/lib/pq v1.11.1 + github.com/microsoft/go-mssqldb v1.9.6 github.com/redis/go-redis/v9 v9.17.3 github.com/sijms/go-ora/v2 v2.9.0 github.com/wailsapp/wails/v2 v2.11.0 + go.mongodb.org/mongo-driver/v2 v2.5.0 golang.org/x/crypto v0.47.0 modernc.org/sqlite v1.44.3 ) @@ -22,10 +24,13 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/godbus/dbus/v5 v5.1.0 // indirect + github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect + github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/uuid v1.6.0 // indirect github.com/gorilla/websocket v1.5.3 // indirect github.com/jchv/go-winloader v0.0.0-20210711035445-715c2860da7e // indirect + github.com/klauspost/compress v1.17.6 // indirect github.com/labstack/echo/v4 v4.13.3 // indirect github.com/labstack/gommon v0.4.2 // indirect github.com/leaanthony/go-ansi-parser v1.6.1 // indirect @@ -40,13 +45,19 @@ require ( github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/samber/lo v1.49.1 // indirect + github.com/shopspring/decimal v1.4.0 // indirect github.com/tkrajina/go-reflector v0.5.8 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasttemplate v1.2.2 // indirect github.com/wailsapp/go-webview2 v1.0.22 // indirect github.com/wailsapp/mimetype v1.4.1 // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.2.0 // indirect + github.com/xdg-go/stringprep v1.0.4 // indirect + github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect golang.org/x/net v0.48.0 // indirect + golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.40.0 // indirect golang.org/x/text v0.33.0 // indirect modernc.org/libc v1.67.6 // indirect diff --git a/go.sum b/go.sum index c1d9449..2ece0ff 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,18 @@ gitea.com/kingbase/gokb v0.0.0-20201021123113-29bd62a876c3 h1:QjslQNaH5Nuap5i4ni gitea.com/kingbase/gokb v0.0.0-20201021123113-29bd62a876c3/go.mod h1:7lH5A1jzCXD9Nl16DzaBUOfDAT8NPrDmZwKu1p5wf94= gitee.com/chunanyong/dm v1.8.22 h1:H7fsrnUIvEA0jlDWew7vwELry1ff+tLMIu2Fk2cIBSg= gitee.com/chunanyong/dm v1.8.22/go.mod h1:EPRJnuPFgbyOFgJ0TRYCTGzhq+ZT4wdyaj/GW/LLcNg= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 h1:Gt0j3wceWMwPmiazCa8MzMA0MfhmPIz0Qp0FJ6qcM0U= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0/go.mod h1:Ot/6aikWnKWi4l9QB7qVSwa8iMphQNqkWALMoNT3rzM= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 h1:FPKJS1T+clwv+OLGt13a8UjqeRuh0O4SJ3lUriThc+4= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1/go.mod h1:j2chePtV91HrC22tGoRX3sGY42uF13WzmmV80/OdVAA= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1 h1:Wgf5rZba3YZqeTNJPtvqZoBu1sBN/L4sry+u2U3Y75w= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1/go.mod h1:xxCBG/f/4Vbmh2XQJBsOmNdxWUY5j/s27jujKPbQf14= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1 h1:bFWuoEKg+gImo7pvkiQEFAc8ocibADgXeiLAxWhWmkI= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1/go.mod h1:Vih/3yc6yac2JzU4hzpaDupBJP0Flaia9rXXrU8xyww= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= github.com/bep/debounce v1.2.1 h1:v67fRdBA9UQu2NhLFXrSg0Brw7CexQekrBwDMM8bzeY= github.com/bep/debounce v1.2.1/go.mod h1:H8yggRPQKLUhUoqrJC1bO2xNya7vanpDl7xR3ISbCJ0= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= @@ -24,9 +36,17 @@ github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1 github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= +github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -37,6 +57,10 @@ github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/jchv/go-winloader v0.0.0-20210711035445-715c2860da7e h1:Q3+PugElBCf4PFpxhErSzU3/PY5sFL5Z6rfv4AbGAck= github.com/jchv/go-winloader v0.0.0-20210711035445-715c2860da7e/go.mod h1:alcuEEnZsY1WQsagKhZDsoPCRoOijYqhZvPwLG0kzVs= +github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI= +github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/labstack/echo/v4 v4.13.3 h1:pwhpCPrTl5qry5HRdM5FwdXnhXSLSY+WE+YQSeCaafY= github.com/labstack/echo/v4 v4.13.3/go.mod h1:o90YNEeQWjDozo584l7AwhJMHN0bOC4tAfg+Xox9q5g= github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= @@ -61,6 +85,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/microsoft/go-mssqldb v1.9.6 h1:1MNQg5UiSsokiPz3++K2KPx4moKrwIqly1wv+RyCKTw= +github.com/microsoft/go-mssqldb v1.9.6/go.mod h1:yYMPDufyoF2vVuVCUGtZARr06DKFIhMrluTcgWlXpr4= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= @@ -78,6 +104,8 @@ github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/samber/lo v1.49.1 h1:4BIFyVfuQSEpluc7Fua+j1NolZHiEHEpaSEKdsH0tew= github.com/samber/lo v1.49.1/go.mod h1:dO6KHFzUKXgP8LDhU0oI8d2hekjXnGOu0DB8Jecxd6o= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/sijms/go-ora/v2 v2.9.0 h1:+iQbUeTeCOFMb5BsOMgUhV8KWyrv9yjKpcK4x7+MFrg= github.com/sijms/go-ora/v2 v2.9.0/go.mod h1:QgFInVi3ZWyqAiJwzBQA+nbKYKH77tdp1PYoCqhR2dU= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= @@ -94,35 +122,66 @@ github.com/wailsapp/mimetype v1.4.1 h1:pQN9ycO7uo4vsUUuPeHEYoUkLVkaRntMnHJxVwYhw github.com/wailsapp/mimetype v1.4.1/go.mod h1:9aV5k31bBOv5z6u+QP8TltzvNGJPmNJD4XlAL3U+j3o= github.com/wailsapp/wails/v2 v2.11.0 h1:seLacV8pqupq32IjS4Y7V8ucab0WZwtK6VvUVxSBtqQ= github.com/wailsapp/wails/v2 v2.11.0/go.mod h1:jrf0ZaM6+GBc1wRmXsM8cIvzlg0karYin3erahI4+0k= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs= +github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= +github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= +github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE= +go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210505024714-0287a6fb4125/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20200810151505-1b9f1253b3ed/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= diff --git a/internal/app/db_context.go b/internal/app/db_context.go index 684c7a3..1b87b01 100644 --- a/internal/app/db_context.go +++ b/internal/app/db_context.go @@ -14,8 +14,8 @@ func normalizeRunConfig(config connection.ConnectionConfig, dbName string) conne } switch strings.ToLower(strings.TrimSpace(config.Type)) { - case "mysql", "postgres", "kingbase": - // 这些类型的 dbName 表示“数据库”,需要写入连接配置以选择目标库。 + case "mysql", "mariadb", "postgres", "kingbase", "highgo", "vastbase", "sqlserver", "mongodb": + // 这些类型的 dbName 表示"数据库",需要写入连接配置以选择目标库。 runConfig.Database = name case "dameng": // 达梦使用 schema 参数,沿用现有行为:dbName 表示 schema。 @@ -45,9 +45,12 @@ func normalizeSchemaAndTable(config connection.ConnectionConfig, dbName string, } switch strings.ToLower(strings.TrimSpace(config.Type)) { - case "postgres", "kingbase": - // PG/金仓:dbName 在 UI 里是“数据库”,schema 需从 tableName 或使用默认 public。 + case "postgres", "kingbase", "highgo", "vastbase": + // PG/金仓/瀚高/海量:dbName 在 UI 里是"数据库",schema 需从 tableName 或使用默认 public。 return "public", rawTable + case "sqlserver": + // SQL Server:dbName 表示数据库,schema 默认 dbo + return "dbo", rawTable default: // MySQL:dbName 表示数据库;Oracle/达梦:dbName 表示 schema/owner。 return rawDB, rawTable diff --git a/internal/app/methods_db.go b/internal/app/methods_db.go index 334abb8..498bb48 100644 --- a/internal/app/methods_db.go +++ b/internal/app/methods_db.go @@ -47,9 +47,12 @@ func (a *App) CreateDatabase(config connection.ConnectionConfig, dbName string) escapedDbName := strings.ReplaceAll(dbName, "`", "``") query := fmt.Sprintf("CREATE DATABASE `%s` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci", escapedDbName) - if runConfig.Type == "postgres" { + dbType := strings.ToLower(strings.TrimSpace(runConfig.Type)) + if dbType == "postgres" || dbType == "kingbase" || dbType == "highgo" || dbType == "vastbase" { escapedDbName = strings.ReplaceAll(dbName, `"`, `""`) query = fmt.Sprintf("CREATE DATABASE \"%s\"", escapedDbName) + } else if dbType == "mariadb" { + // MariaDB uses same syntax as MySQL } _, err = dbInst.Exec(query) @@ -95,7 +98,7 @@ func normalizeSchemaAndTableByType(dbType string, dbName string, tableName strin } switch dbType { - case "postgres", "kingbase": + case "postgres", "kingbase", "highgo", "vastbase": return "public", rawTable default: return rawDB, rawTable @@ -116,7 +119,7 @@ func buildRunConfigForDDL(config connection.ConnectionConfig, dbType string, dbN if strings.EqualFold(strings.TrimSpace(config.Type), "custom") { // custom 连接的 dbName 语义依赖 driver,尽量在常见驱动上对齐内置类型行为。 switch dbType { - case "mysql", "postgres", "kingbase", "dameng": + case "mysql", "mariadb", "postgres", "kingbase", "vastbase", "dameng": if strings.TrimSpace(dbName) != "" { runConfig.Database = strings.TrimSpace(dbName) } @@ -137,9 +140,9 @@ func (a *App) RenameDatabase(config connection.ConnectionConfig, oldName string, dbType := resolveDDLDBType(config) switch dbType { - case "mysql": - return connection.QueryResult{Success: false, Message: "MySQL 不支持直接重命名数据库,请新建库后迁移数据"} - case "postgres", "kingbase": + case "mysql", "mariadb": + return connection.QueryResult{Success: false, Message: "MySQL/MariaDB 不支持直接重命名数据库,请新建库后迁移数据"} + case "postgres", "kingbase", "highgo", "vastbase": if strings.EqualFold(strings.TrimSpace(config.Database), oldName) { return connection.QueryResult{Success: false, Message: "当前连接正在使用目标数据库,请先连接到其他数据库后再重命名"} } @@ -173,11 +176,11 @@ func (a *App) DropDatabase(config connection.ConnectionConfig, dbName string) co sql string ) switch dbType { - case "mysql": + case "mysql", "mariadb": runConfig = config runConfig.Database = "" sql = fmt.Sprintf("DROP DATABASE %s", quoteIdentByType(dbType, dbName)) - case "postgres", "kingbase": + case "postgres", "kingbase", "highgo", "vastbase": if strings.EqualFold(strings.TrimSpace(config.Database), dbName) { return connection.QueryResult{Success: false, Message: "当前连接正在使用目标数据库,请先连接到其他数据库后再删除"} } @@ -215,7 +218,7 @@ func (a *App) RenameTable(config connection.ConnectionConfig, dbName string, old dbType := resolveDDLDBType(config) switch dbType { - case "mysql", "postgres", "kingbase", "sqlite", "oracle", "dameng": + case "mysql", "mariadb", "postgres", "kingbase", "sqlite", "oracle", "dameng", "highgo", "vastbase", "sqlserver": default: return connection.QueryResult{Success: false, Message: fmt.Sprintf("当前数据源(%s)暂不支持重命名表", dbType)} } @@ -227,10 +230,19 @@ func (a *App) RenameTable(config connection.ConnectionConfig, dbName string, old oldQualifiedTable := quoteTableIdentByType(dbType, schemaName, pureOldTableName) newTableQuoted := quoteIdentByType(dbType, newTableName) - sql := fmt.Sprintf("ALTER TABLE %s RENAME TO %s", oldQualifiedTable, newTableQuoted) - if dbType == "mysql" { + var sql string + switch dbType { + case "mysql", "mariadb": newQualifiedTable := quoteTableIdentByType(dbType, schemaName, newTableName) sql = fmt.Sprintf("RENAME TABLE %s TO %s", oldQualifiedTable, newQualifiedTable) + case "sqlserver": + // SQL Server 使用 sp_rename,参数为 'schema.oldname', 'newname' + oldFullName := schemaName + "." + pureOldTableName + escapedOld := strings.ReplaceAll(oldFullName, "'", "''") + escapedNew := strings.ReplaceAll(newTableName, "'", "''") + sql = fmt.Sprintf("EXEC sp_rename '%s', '%s'", escapedOld, escapedNew) + default: + sql = fmt.Sprintf("ALTER TABLE %s RENAME TO %s", oldQualifiedTable, newTableQuoted) } runConfig := buildRunConfigForDDL(config, dbType, dbName) @@ -252,7 +264,7 @@ func (a *App) DropTable(config connection.ConnectionConfig, dbName string, table dbType := resolveDDLDBType(config) switch dbType { - case "mysql", "postgres", "kingbase", "sqlite", "oracle", "dameng": + case "mysql", "mariadb", "postgres", "kingbase", "sqlite", "oracle", "dameng", "highgo", "vastbase", "sqlserver": default: return connection.QueryResult{Success: false, Message: fmt.Sprintf("当前数据源(%s)暂不支持删除表", dbType)} } diff --git a/internal/app/methods_file.go b/internal/app/methods_file.go index 0cdac78..e96537b 100644 --- a/internal/app/methods_file.go +++ b/internal/app/methods_file.go @@ -408,8 +408,11 @@ func quoteIdentByType(dbType string, ident string) string { } switch dbType { - case "mysql": + case "mysql", "mariadb": return "`" + strings.ReplaceAll(ident, "`", "``") + "`" + case "sqlserver": + escaped := strings.ReplaceAll(ident, "]", "]]") + return "[" + escaped + "]" default: return `"` + strings.ReplaceAll(ident, `"`, `""`) + `"` } diff --git a/internal/app/sql_sanitize.go b/internal/app/sql_sanitize.go index 4e37ed5..99c5335 100644 --- a/internal/app/sql_sanitize.go +++ b/internal/app/sql_sanitize.go @@ -7,7 +7,7 @@ import ( func sanitizeSQLForPgLike(dbType string, query string) string { switch strings.ToLower(strings.TrimSpace(dbType)) { - case "postgres", "kingbase": + case "postgres", "kingbase", "highgo", "vastbase": // 有些情况下会出现多层重复引用(例如 """"schema"""" 或 ""schema"""),单次修复不一定收敛。 // 这里做有限次数的迭代,直到输出不再变化。 out := query diff --git a/internal/db/database.go b/internal/db/database.go index f2a3f10..9c03ccc 100644 --- a/internal/db/database.go +++ b/internal/db/database.go @@ -40,6 +40,16 @@ func NewDatabase(dbType string) (Database, error) { return &DamengDB{}, nil case "kingbase": return &KingbaseDB{}, nil + case "mongodb": + return &MongoDB{}, nil + case "sqlserver": + return &SqlServerDB{}, nil + case "highgo": + return &HighGoDB{}, nil + case "mariadb": + return &MariaDB{}, nil + case "vastbase": + return &VastbaseDB{}, nil case "custom": return &CustomDB{}, nil default: diff --git a/internal/db/highgo_impl.go b/internal/db/highgo_impl.go new file mode 100644 index 0000000..79f25ae --- /dev/null +++ b/internal/db/highgo_impl.go @@ -0,0 +1,628 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + "net" + "net/url" + "strconv" + "strings" + "time" + + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/logger" + "GoNavi-Wails/internal/ssh" + "GoNavi-Wails/internal/utils" + + _ "github.com/lib/pq" // HighGo is PostgreSQL compatible +) + +// HighGoDB implements Database interface for HighGo (瀚高) database +// HighGo is a PostgreSQL-compatible database, so we reuse PostgreSQL driver +type HighGoDB struct { + conn *sql.DB + pingTimeout time.Duration + forwarder *ssh.LocalForwarder +} + +func (h *HighGoDB) getDSN(config connection.ConnectionConfig) string { + // postgres://user:password@host:port/dbname?sslmode=disable + dbname := config.Database + if dbname == "" { + dbname = "highgo" // HighGo default database + } + + u := &url.URL{ + Scheme: "postgres", + Host: net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), + Path: "/" + dbname, + } + u.User = url.UserPassword(config.User, config.Password) + q := url.Values{} + q.Set("sslmode", "disable") + q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config))) + u.RawQuery = q.Encode() + + return u.String() +} + +func (h *HighGoDB) Connect(config connection.ConnectionConfig) error { + var dsn string + + if config.UseSSH { + logger.Infof("HighGo 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User) + + forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port) + if err != nil { + return fmt.Errorf("创建 SSH 隧道失败:%w", err) + } + h.forwarder = forwarder + + host, portStr, err := net.SplitHostPort(forwarder.LocalAddr) + if err != nil { + return fmt.Errorf("解析本地转发地址失败:%w", err) + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("解析本地端口失败:%w", err) + } + + localConfig := config + localConfig.Host = host + localConfig.Port = port + localConfig.UseSSH = false + + dsn = h.getDSN(localConfig) + logger.Infof("HighGo 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port) + } else { + dsn = h.getDSN(config) + } + + db, err := sql.Open("postgres", dsn) + if err != nil { + return fmt.Errorf("打开数据库连接失败:%w", err) + } + h.conn = db + h.pingTimeout = getConnectTimeout(config) + + if err := h.Ping(); err != nil { + return fmt.Errorf("连接建立后验证失败:%w", err) + } + return nil +} + +func (h *HighGoDB) Close() error { + if h.forwarder != nil { + if err := h.forwarder.Close(); err != nil { + logger.Warnf("关闭 HighGo SSH 端口转发失败:%v", err) + } + h.forwarder = nil + } + + if h.conn != nil { + return h.conn.Close() + } + return nil +} + +func (h *HighGoDB) Ping() error { + if h.conn == nil { + return fmt.Errorf("connection not open") + } + timeout := h.pingTimeout + if timeout <= 0 { + timeout = 5 * time.Second + } + ctx, cancel := utils.ContextWithTimeout(timeout) + defer cancel() + return h.conn.PingContext(ctx) +} + +func (h *HighGoDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + if h.conn == nil { + return nil, nil, fmt.Errorf("connection not open") + } + + rows, err := h.conn.QueryContext(ctx, query) + if err != nil { + return nil, nil, err + } + defer rows.Close() + + return scanRows(rows) +} + +func (h *HighGoDB) Query(query string) ([]map[string]interface{}, []string, error) { + if h.conn == nil { + return nil, nil, fmt.Errorf("connection not open") + } + + rows, err := h.conn.Query(query) + if err != nil { + return nil, nil, err + } + defer rows.Close() + return scanRows(rows) +} + +func (h *HighGoDB) ExecContext(ctx context.Context, query string) (int64, error) { + if h.conn == nil { + return 0, fmt.Errorf("connection not open") + } + res, err := h.conn.ExecContext(ctx, query) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (h *HighGoDB) Exec(query string) (int64, error) { + if h.conn == nil { + return 0, fmt.Errorf("connection not open") + } + res, err := h.conn.Exec(query) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (h *HighGoDB) GetDatabases() ([]string, error) { + data, _, err := h.Query("SELECT datname FROM pg_database WHERE datistemplate = false") + if err != nil { + return nil, err + } + var dbs []string + for _, row := range data { + if val, ok := row["datname"]; ok { + dbs = append(dbs, fmt.Sprintf("%v", val)) + } + } + return dbs, nil +} + +func (h *HighGoDB) GetTables(dbName string) ([]string, error) { + query := "SELECT schemaname, tablename FROM pg_catalog.pg_tables WHERE schemaname != 'information_schema' AND schemaname NOT LIKE 'pg_%' ORDER BY schemaname, tablename" + data, _, err := h.Query(query) + if err != nil { + return nil, err + } + + var tables []string + for _, row := range data { + schema, okSchema := row["schemaname"] + name, okName := row["tablename"] + if okSchema && okName { + tables = append(tables, fmt.Sprintf("%v.%v", schema, name)) + continue + } + if okName { + tables = append(tables, fmt.Sprintf("%v", name)) + } + } + return tables, nil +} + +func (h *HighGoDB) GetCreateStatement(dbName, tableName string) (string, error) { + return fmt.Sprintf("-- SHOW CREATE TABLE not fully supported for HighGo in this version.\n-- Table: %s", tableName), nil +} + +func (h *HighGoDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { + schema := strings.TrimSpace(dbName) + if schema == "" { + schema = "public" + } + table := strings.TrimSpace(tableName) + if table == "" { + return nil, fmt.Errorf("table name required") + } + + esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") } + + query := fmt.Sprintf(` +SELECT + a.attname AS column_name, + pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type, + CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable, + pg_get_expr(ad.adbin, ad.adrelid) AS column_default, + col_description(a.attrelid, a.attnum) AS comment, + CASE WHEN pk.attname IS NOT NULL THEN 'PRI' ELSE '' END AS column_key +FROM pg_class c +JOIN pg_namespace n ON n.oid = c.relnamespace +JOIN pg_attribute a ON a.attrelid = c.oid +LEFT JOIN pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum +LEFT JOIN ( + SELECT i.indrelid, a3.attname + FROM pg_index i + JOIN pg_attribute a3 ON a3.attrelid = i.indrelid AND a3.attnum = ANY(i.indkey) + WHERE i.indisprimary +) pk ON pk.indrelid = c.oid AND pk.attname = a.attname +WHERE c.relkind IN ('r', 'p') + AND n.nspname = '%s' + AND c.relname = '%s' + AND a.attnum > 0 + AND NOT a.attisdropped +ORDER BY a.attnum`, esc(schema), esc(table)) + + data, _, err := h.Query(query) + if err != nil { + return nil, err + } + + var columns []connection.ColumnDefinition + for _, row := range data { + col := connection.ColumnDefinition{ + Name: fmt.Sprintf("%v", row["column_name"]), + Type: fmt.Sprintf("%v", row["data_type"]), + Nullable: fmt.Sprintf("%v", row["is_nullable"]), + Key: fmt.Sprintf("%v", row["column_key"]), + Extra: "", + Comment: "", + } + + if v, ok := row["comment"]; ok && v != nil { + col.Comment = fmt.Sprintf("%v", v) + } + + if v, ok := row["column_default"]; ok && v != nil { + def := fmt.Sprintf("%v", v) + col.Default = &def + if strings.HasPrefix(strings.ToLower(strings.TrimSpace(def)), "nextval(") { + col.Extra = "auto_increment" + } + } + + columns = append(columns, col) + } + return columns, nil +} + +func (h *HighGoDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) { + schema := strings.TrimSpace(dbName) + if schema == "" { + schema = "public" + } + table := strings.TrimSpace(tableName) + if table == "" { + return nil, fmt.Errorf("table name required") + } + + esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") } + + query := fmt.Sprintf(` +SELECT + i.relname AS index_name, + a.attname AS column_name, + ix.indisunique AS is_unique, + x.ordinality AS seq_in_index, + am.amname AS index_type +FROM pg_class t +JOIN pg_namespace n ON n.oid = t.relnamespace +JOIN pg_index ix ON t.oid = ix.indrelid +JOIN pg_class i ON i.oid = ix.indexrelid +JOIN pg_am am ON i.relam = am.oid +JOIN unnest(ix.indkey) WITH ORDINALITY AS x(attnum, ordinality) ON TRUE +JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = x.attnum +WHERE t.relkind IN ('r', 'p') + AND t.relname = '%s' + AND n.nspname = '%s' +ORDER BY i.relname, x.ordinality`, esc(table), esc(schema)) + + data, _, err := h.Query(query) + if err != nil { + return nil, err + } + + parseBool := func(v interface{}) bool { + switch val := v.(type) { + case bool: + return val + case string: + s := strings.ToLower(strings.TrimSpace(val)) + return s == "t" || s == "true" || s == "1" || s == "y" || s == "yes" + default: + s := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v))) + return s == "t" || s == "true" || s == "1" || s == "y" || s == "yes" + } + } + + parseInt := func(v interface{}) int { + switch val := v.(type) { + case int: + return val + case int64: + return int(val) + case float64: + return int(val) + case string: + var n int + _, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n) + return n + default: + var n int + _, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n) + return n + } + } + + var indexes []connection.IndexDefinition + for _, row := range data { + isUnique := false + if v, ok := row["is_unique"]; ok && v != nil { + isUnique = parseBool(v) + } + + nonUnique := 1 + if isUnique { + nonUnique = 0 + } + + seq := 0 + if v, ok := row["seq_in_index"]; ok && v != nil { + seq = parseInt(v) + } + + indexType := "" + if v, ok := row["index_type"]; ok && v != nil { + indexType = strings.ToUpper(fmt.Sprintf("%v", v)) + } + if indexType == "" { + indexType = "BTREE" + } + + idx := connection.IndexDefinition{ + Name: fmt.Sprintf("%v", row["index_name"]), + ColumnName: fmt.Sprintf("%v", row["column_name"]), + NonUnique: nonUnique, + SeqInIndex: seq, + IndexType: indexType, + } + indexes = append(indexes, idx) + } + return indexes, nil +} + +func (h *HighGoDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) { + schema := strings.TrimSpace(dbName) + if schema == "" { + schema = "public" + } + table := strings.TrimSpace(tableName) + if table == "" { + return nil, fmt.Errorf("table name required") + } + + esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") } + + query := fmt.Sprintf(` +SELECT + tc.constraint_name AS constraint_name, + kcu.column_name AS column_name, + ccu.table_schema AS foreign_table_schema, + ccu.table_name AS foreign_table_name, + ccu.column_name AS foreign_column_name +FROM information_schema.table_constraints AS tc +JOIN information_schema.key_column_usage AS kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema +JOIN information_schema.constraint_column_usage AS ccu + ON ccu.constraint_name = tc.constraint_name + AND ccu.table_schema = tc.table_schema +WHERE tc.constraint_type = 'FOREIGN KEY' + AND tc.table_name = '%s' + AND tc.table_schema = '%s' +ORDER BY tc.constraint_name, kcu.ordinal_position`, esc(table), esc(schema)) + + data, _, err := h.Query(query) + if err != nil { + return nil, err + } + + var fks []connection.ForeignKeyDefinition + for _, row := range data { + refSchema := "" + if v, ok := row["foreign_table_schema"]; ok && v != nil { + refSchema = fmt.Sprintf("%v", v) + } + refTable := fmt.Sprintf("%v", row["foreign_table_name"]) + refTableName := refTable + if strings.TrimSpace(refSchema) != "" { + refTableName = fmt.Sprintf("%s.%s", refSchema, refTable) + } + + fk := connection.ForeignKeyDefinition{ + Name: fmt.Sprintf("%v", row["constraint_name"]), + ColumnName: fmt.Sprintf("%v", row["column_name"]), + RefTableName: refTableName, + RefColumnName: fmt.Sprintf("%v", row["foreign_column_name"]), + ConstraintName: fmt.Sprintf("%v", row["constraint_name"]), + } + fks = append(fks, fk) + } + return fks, nil +} + +func (h *HighGoDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) { + schema := strings.TrimSpace(dbName) + if schema == "" { + schema = "public" + } + table := strings.TrimSpace(tableName) + if table == "" { + return nil, fmt.Errorf("table name required") + } + + esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") } + + query := fmt.Sprintf(` +SELECT trigger_name, action_timing, event_manipulation, action_statement +FROM information_schema.triggers +WHERE event_object_table = '%s' + AND event_object_schema = '%s' +ORDER BY trigger_name, event_manipulation`, esc(table), esc(schema)) + + data, _, err := h.Query(query) + if err != nil { + return nil, err + } + + var triggers []connection.TriggerDefinition + for _, row := range data { + trig := connection.TriggerDefinition{ + Name: fmt.Sprintf("%v", row["trigger_name"]), + Timing: fmt.Sprintf("%v", row["action_timing"]), + Event: fmt.Sprintf("%v", row["event_manipulation"]), + Statement: fmt.Sprintf("%v", row["action_statement"]), + } + triggers = append(triggers, trig) + } + return triggers, nil +} + +func (h *HighGoDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { + query := ` +SELECT table_schema, table_name, column_name, data_type +FROM information_schema.columns +WHERE table_schema NOT IN ('pg_catalog', 'information_schema') + AND table_schema NOT LIKE 'pg_%' +ORDER BY table_schema, table_name, ordinal_position` + + data, _, err := h.Query(query) + if err != nil { + return nil, err + } + + var cols []connection.ColumnDefinitionWithTable + for _, row := range data { + schema := fmt.Sprintf("%v", row["table_schema"]) + table := fmt.Sprintf("%v", row["table_name"]) + tableName := table + if strings.TrimSpace(schema) != "" { + tableName = fmt.Sprintf("%s.%s", schema, table) + } + + col := connection.ColumnDefinitionWithTable{ + TableName: tableName, + Name: fmt.Sprintf("%v", row["column_name"]), + Type: fmt.Sprintf("%v", row["data_type"]), + } + cols = append(cols, col) + } + return cols, nil +} + +func (h *HighGoDB) ApplyChanges(tableName string, changes connection.ChangeSet) error { + if h.conn == nil { + return fmt.Errorf("connection not open") + } + + tx, err := h.conn.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + quoteIdent := func(name string) string { + n := strings.TrimSpace(name) + n = strings.Trim(n, "\"") + n = strings.ReplaceAll(n, "\"", "\"\"") + if n == "" { + return "\"\"" + } + return `"` + n + `"` + } + + schema := "" + table := strings.TrimSpace(tableName) + if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { + schema = strings.TrimSpace(parts[0]) + table = strings.TrimSpace(parts[1]) + } + + qualifiedTable := "" + if schema != "" { + qualifiedTable = fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table)) + } else { + qualifiedTable = quoteIdent(table) + } + + // 1. Deletes + for _, pk := range changes.Deletes { + var wheres []string + var args []interface{} + idx := 0 + for k, v := range pk { + idx++ + wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteIdent(k), idx)) + args = append(args, v) + } + if len(wheres) == 0 { + continue + } + query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND ")) + if _, err := tx.Exec(query, args...); err != nil { + return fmt.Errorf("delete error: %v", err) + } + } + + // 2. Updates + for _, update := range changes.Updates { + var sets []string + var args []interface{} + idx := 0 + + for k, v := range update.Values { + idx++ + sets = append(sets, fmt.Sprintf("%s = $%d", quoteIdent(k), idx)) + args = append(args, v) + } + + if len(sets) == 0 { + continue + } + + var wheres []string + for k, v := range update.Keys { + idx++ + wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteIdent(k), idx)) + args = append(args, v) + } + + if len(wheres) == 0 { + return fmt.Errorf("update requires keys") + } + + query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND ")) + if _, err := tx.Exec(query, args...); err != nil { + return fmt.Errorf("update error: %v", err) + } + } + + // 3. Inserts + for _, row := range changes.Inserts { + var cols []string + var placeholders []string + var args []interface{} + idx := 0 + + for k, v := range row { + idx++ + cols = append(cols, quoteIdent(k)) + placeholders = append(placeholders, fmt.Sprintf("$%d", idx)) + args = append(args, v) + } + + if len(cols) == 0 { + continue + } + + query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", ")) + if _, err := tx.Exec(query, args...); err != nil { + return fmt.Errorf("insert error: %v", err) + } + } + + return tx.Commit() +} diff --git a/internal/db/mariadb_impl.go b/internal/db/mariadb_impl.go new file mode 100644 index 0000000..5559b4e --- /dev/null +++ b/internal/db/mariadb_impl.go @@ -0,0 +1,409 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + "strings" + "time" + + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/logger" + "GoNavi-Wails/internal/ssh" + "GoNavi-Wails/internal/utils" + + _ "github.com/go-sql-driver/mysql" +) + +// MariaDB implements Database interface for MariaDB +// MariaDB is MySQL-compatible, so we reuse the MySQL driver +type MariaDB struct { + conn *sql.DB + pingTimeout time.Duration +} + +func (m *MariaDB) getDSN(config connection.ConnectionConfig) string { + database := config.Database + protocol := "tcp" + address := fmt.Sprintf("%s:%d", config.Host, config.Port) + + if config.UseSSH { + netName, err := ssh.RegisterSSHNetwork(config.SSH) + if err == nil { + protocol = netName + address = fmt.Sprintf("%s:%d", config.Host, config.Port) + } else { + logger.Warnf("注册 SSH 网络失败,将尝试直连:地址=%s:%d 用户=%s,原因:%v", config.Host, config.Port, config.User, err) + } + } + + timeout := getConnectTimeoutSeconds(config) + + return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds", + config.User, config.Password, protocol, address, database, timeout) +} + +func (m *MariaDB) Connect(config connection.ConnectionConfig) error { + dsn := m.getDSN(config) + db, err := sql.Open("mysql", dsn) + if err != nil { + return fmt.Errorf("打开数据库连接失败:%w", err) + } + m.conn = db + m.pingTimeout = getConnectTimeout(config) + + if err := m.Ping(); err != nil { + return fmt.Errorf("连接建立后验证失败:%w", err) + } + return nil +} + +func (m *MariaDB) Close() error { + if m.conn != nil { + return m.conn.Close() + } + return nil +} + +func (m *MariaDB) Ping() error { + if m.conn == nil { + return fmt.Errorf("connection not open") + } + timeout := m.pingTimeout + if timeout <= 0 { + timeout = 5 * time.Second + } + ctx, cancel := utils.ContextWithTimeout(timeout) + defer cancel() + return m.conn.PingContext(ctx) +} + +func (m *MariaDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + if m.conn == nil { + return nil, nil, fmt.Errorf("connection not open") + } + + rows, err := m.conn.QueryContext(ctx, query) + if err != nil { + return nil, nil, err + } + defer rows.Close() + + return scanRows(rows) +} + +func (m *MariaDB) Query(query string) ([]map[string]interface{}, []string, error) { + if m.conn == nil { + return nil, nil, fmt.Errorf("connection not open") + } + + rows, err := m.conn.Query(query) + if err != nil { + return nil, nil, err + } + defer rows.Close() + return scanRows(rows) +} + +func (m *MariaDB) ExecContext(ctx context.Context, query string) (int64, error) { + if m.conn == nil { + return 0, fmt.Errorf("connection not open") + } + res, err := m.conn.ExecContext(ctx, query) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (m *MariaDB) Exec(query string) (int64, error) { + if m.conn == nil { + return 0, fmt.Errorf("connection not open") + } + res, err := m.conn.Exec(query) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (m *MariaDB) GetDatabases() ([]string, error) { + data, _, err := m.Query("SHOW DATABASES") + if err != nil { + return nil, err + } + var dbs []string + for _, row := range data { + if val, ok := row["Database"]; ok { + dbs = append(dbs, fmt.Sprintf("%v", val)) + } else if val, ok := row["database"]; ok { + dbs = append(dbs, fmt.Sprintf("%v", val)) + } + } + return dbs, nil +} + +func (m *MariaDB) GetTables(dbName string) ([]string, error) { + query := "SHOW TABLES" + if dbName != "" { + query = fmt.Sprintf("SHOW TABLES FROM `%s`", dbName) + } + + data, _, err := m.Query(query) + if err != nil { + return nil, err + } + + var tables []string + for _, row := range data { + for _, v := range row { + tables = append(tables, fmt.Sprintf("%v", v)) + break + } + } + return tables, nil +} + +func (m *MariaDB) GetCreateStatement(dbName, tableName string) (string, error) { + query := fmt.Sprintf("SHOW CREATE TABLE `%s`.`%s`", dbName, tableName) + if dbName == "" { + query = fmt.Sprintf("SHOW CREATE TABLE `%s`", tableName) + } + + data, _, err := m.Query(query) + if err != nil { + return "", err + } + + if len(data) > 0 { + if val, ok := data[0]["Create Table"]; ok { + return fmt.Sprintf("%v", val), nil + } + } + return "", fmt.Errorf("create statement not found") +} + +func (m *MariaDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { + query := fmt.Sprintf("SHOW FULL COLUMNS FROM `%s`.`%s`", dbName, tableName) + if dbName == "" { + query = fmt.Sprintf("SHOW FULL COLUMNS FROM `%s`", tableName) + } + + data, _, err := m.Query(query) + if err != nil { + return nil, err + } + + var columns []connection.ColumnDefinition + for _, row := range data { + col := connection.ColumnDefinition{ + Name: fmt.Sprintf("%v", row["Field"]), + Type: fmt.Sprintf("%v", row["Type"]), + Nullable: fmt.Sprintf("%v", row["Null"]), + Key: fmt.Sprintf("%v", row["Key"]), + Extra: fmt.Sprintf("%v", row["Extra"]), + Comment: fmt.Sprintf("%v", row["Comment"]), + } + + if row["Default"] != nil { + d := fmt.Sprintf("%v", row["Default"]) + col.Default = &d + } + + columns = append(columns, col) + } + return columns, nil +} + +func (m *MariaDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) { + query := fmt.Sprintf("SHOW INDEX FROM `%s`.`%s`", dbName, tableName) + if dbName == "" { + query = fmt.Sprintf("SHOW INDEX FROM `%s`", tableName) + } + + data, _, err := m.Query(query) + if err != nil { + return nil, err + } + + var indexes []connection.IndexDefinition + for _, row := range data { + nonUnique := 0 + if val, ok := row["Non_unique"]; ok { + if f, ok := val.(float64); ok { + nonUnique = int(f) + } else if i, ok := val.(int64); ok { + nonUnique = int(i) + } + } + + seq := 0 + if val, ok := row["Seq_in_index"]; ok { + if f, ok := val.(float64); ok { + seq = int(f) + } else if i, ok := val.(int64); ok { + seq = int(i) + } + } + + idx := connection.IndexDefinition{ + Name: fmt.Sprintf("%v", row["Key_name"]), + ColumnName: fmt.Sprintf("%v", row["Column_name"]), + NonUnique: nonUnique, + SeqInIndex: seq, + IndexType: fmt.Sprintf("%v", row["Index_type"]), + } + indexes = append(indexes, idx) + } + return indexes, nil +} + +func (m *MariaDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) { + query := fmt.Sprintf(`SELECT CONSTRAINT_NAME, COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME + FROM information_schema.KEY_COLUMN_USAGE + WHERE TABLE_SCHEMA = '%s' AND TABLE_NAME = '%s' AND REFERENCED_TABLE_NAME IS NOT NULL`, dbName, tableName) + + data, _, err := m.Query(query) + if err != nil { + return nil, err + } + + var fks []connection.ForeignKeyDefinition + for _, row := range data { + fk := connection.ForeignKeyDefinition{ + Name: fmt.Sprintf("%v", row["CONSTRAINT_NAME"]), + ColumnName: fmt.Sprintf("%v", row["COLUMN_NAME"]), + RefTableName: fmt.Sprintf("%v", row["REFERENCED_TABLE_NAME"]), + RefColumnName: fmt.Sprintf("%v", row["REFERENCED_COLUMN_NAME"]), + ConstraintName: fmt.Sprintf("%v", row["CONSTRAINT_NAME"]), + } + fks = append(fks, fk) + } + return fks, nil +} + +func (m *MariaDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) { + query := fmt.Sprintf("SHOW TRIGGERS FROM `%s` WHERE `Table` = '%s'", dbName, tableName) + data, _, err := m.Query(query) + if err != nil { + return nil, err + } + + var triggers []connection.TriggerDefinition + for _, row := range data { + trig := connection.TriggerDefinition{ + Name: fmt.Sprintf("%v", row["Trigger"]), + Timing: fmt.Sprintf("%v", row["Timing"]), + Event: fmt.Sprintf("%v", row["Event"]), + Statement: fmt.Sprintf("%v", row["Statement"]), + } + triggers = append(triggers, trig) + } + return triggers, nil +} + +func (m *MariaDB) ApplyChanges(tableName string, changes connection.ChangeSet) error { + if m.conn == nil { + return fmt.Errorf("connection not open") + } + + tx, err := m.conn.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + // 1. Deletes + for _, pk := range changes.Deletes { + var wheres []string + var args []interface{} + for k, v := range pk { + wheres = append(wheres, fmt.Sprintf("`%s` = ?", k)) + args = append(args, normalizeMySQLDateTimeValue(v)) + } + if len(wheres) == 0 { + continue + } + query := fmt.Sprintf("DELETE FROM `%s` WHERE %s", tableName, strings.Join(wheres, " AND ")) + if _, err := tx.Exec(query, args...); err != nil { + return fmt.Errorf("delete error: %v", err) + } + } + + // 2. Updates + for _, update := range changes.Updates { + var sets []string + var args []interface{} + + for k, v := range update.Values { + sets = append(sets, fmt.Sprintf("`%s` = ?", k)) + args = append(args, normalizeMySQLDateTimeValue(v)) + } + + if len(sets) == 0 { + continue + } + + var wheres []string + for k, v := range update.Keys { + wheres = append(wheres, fmt.Sprintf("`%s` = ?", k)) + args = append(args, normalizeMySQLDateTimeValue(v)) + } + + if len(wheres) == 0 { + return fmt.Errorf("update requires keys") + } + + query := fmt.Sprintf("UPDATE `%s` SET %s WHERE %s", tableName, strings.Join(sets, ", "), strings.Join(wheres, " AND ")) + if _, err := tx.Exec(query, args...); err != nil { + return fmt.Errorf("update error: %v", err) + } + } + + // 3. Inserts + for _, row := range changes.Inserts { + var cols []string + var placeholders []string + var args []interface{} + + for k, v := range row { + cols = append(cols, fmt.Sprintf("`%s`", k)) + placeholders = append(placeholders, "?") + args = append(args, normalizeMySQLDateTimeValue(v)) + } + + if len(cols) == 0 { + continue + } + + query := fmt.Sprintf("INSERT INTO `%s` (%s) VALUES (%s)", tableName, strings.Join(cols, ", "), strings.Join(placeholders, ", ")) + if _, err := tx.Exec(query, args...); err != nil { + return fmt.Errorf("insert error: %v", err) + } + } + + return tx.Commit() +} + +func (m *MariaDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { + query := fmt.Sprintf("SELECT TABLE_NAME, COLUMN_NAME, COLUMN_TYPE FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = '%s'", dbName) + if dbName == "" { + return nil, fmt.Errorf("database name required for GetAllColumns") + } + + data, _, err := m.Query(query) + if err != nil { + return nil, err + } + + var cols []connection.ColumnDefinitionWithTable + for _, row := range data { + col := connection.ColumnDefinitionWithTable{ + TableName: fmt.Sprintf("%v", row["TABLE_NAME"]), + Name: fmt.Sprintf("%v", row["COLUMN_NAME"]), + Type: fmt.Sprintf("%v", row["COLUMN_TYPE"]), + } + cols = append(cols, col) + } + return cols, nil +} diff --git a/internal/db/mongodb_impl.go b/internal/db/mongodb_impl.go new file mode 100644 index 0000000..4f1f300 --- /dev/null +++ b/internal/db/mongodb_impl.go @@ -0,0 +1,407 @@ +package db + +import ( + "context" + "fmt" + "net" + "net/url" + "strconv" + "strings" + "time" + + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/logger" + "GoNavi-Wails/internal/ssh" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/mongo/readpref" +) + +type MongoDB struct { + client *mongo.Client + database string + pingTimeout time.Duration + forwarder *ssh.LocalForwarder +} + +func (m *MongoDB) getURI(config connection.ConnectionConfig) string { + // mongodb://user:password@host:port/database?authSource=admin + host := config.Host + port := config.Port + if port == 0 { + port = 27017 + } + + uri := fmt.Sprintf("mongodb://%s:%d", host, port) + + if config.User != "" { + encodedUser := url.QueryEscape(config.User) + if config.Password != "" { + encodedPass := url.QueryEscape(config.Password) + uri = fmt.Sprintf("mongodb://%s:%s@%s:%d", encodedUser, encodedPass, host, port) + } else { + uri = fmt.Sprintf("mongodb://%s@%s:%d", encodedUser, host, port) + } + } + + // Add connection options + params := []string{} + timeout := getConnectTimeoutSeconds(config) + params = append(params, fmt.Sprintf("connectTimeoutMS=%d", timeout*1000)) + params = append(params, fmt.Sprintf("serverSelectionTimeoutMS=%d", timeout*1000)) + + // authSource: 优先使用 config.Database,为空时默认 admin + authSource := "admin" + if config.Database != "" { + authSource = config.Database + } + params = append(params, fmt.Sprintf("authSource=%s", authSource)) + + if len(params) > 0 { + uri = uri + "/?" + strings.Join(params, "&") + } + + return uri +} + +func (m *MongoDB) Connect(config connection.ConnectionConfig) error { + var uri string + + if config.UseSSH { + logger.Infof("MongoDB 使用 SSH 连接:地址=%s:%d", config.Host, config.Port) + + forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port) + if err != nil { + return fmt.Errorf("创建 SSH 隧道失败:%w", err) + } + m.forwarder = forwarder + + host, portStr, err := net.SplitHostPort(forwarder.LocalAddr) + if err != nil { + return fmt.Errorf("解析本地转发地址失败:%w", err) + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("解析本地端口失败:%w", err) + } + + localConfig := config + localConfig.Host = host + localConfig.Port = port + localConfig.UseSSH = false + + uri = m.getURI(localConfig) + logger.Infof("MongoDB 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port) + } else { + uri = m.getURI(config) + } + + m.pingTimeout = getConnectTimeout(config) + m.database = config.Database + if m.database == "" { + m.database = "admin" + } + + clientOpts := options.Client().ApplyURI(uri) + client, err := mongo.Connect(clientOpts) + if err != nil { + return fmt.Errorf("MongoDB 连接失败:%w", err) + } + m.client = client + + if err := m.Ping(); err != nil { + return fmt.Errorf("MongoDB 连接验证失败:%w", err) + } + + return nil +} + +func (m *MongoDB) Close() error { + if m.forwarder != nil { + if err := m.forwarder.Close(); err != nil { + logger.Warnf("关闭 MongoDB SSH 端口转发失败:%v", err) + } + m.forwarder = nil + } + + if m.client != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return m.client.Disconnect(ctx) + } + return nil +} + +func (m *MongoDB) Ping() error { + if m.client == nil { + return fmt.Errorf("connection not open") + } + timeout := m.pingTimeout + if timeout <= 0 { + timeout = 5 * time.Second + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return m.client.Ping(ctx, readpref.Primary()) +} + +// Query executes a MongoDB command and returns results +// Supports JSON format commands like: {"find": "collection", "filter": {}} +func (m *MongoDB) Query(query string) ([]map[string]interface{}, []string, error) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + return m.queryWithContext(ctx, query) +} + +// QueryContext executes a MongoDB command with the given context for timeout control +func (m *MongoDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + return m.queryWithContext(ctx, query) +} + +func (m *MongoDB) queryWithContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + if m.client == nil { + return nil, nil, fmt.Errorf("connection not open") + } + + query = strings.TrimSpace(query) + if query == "" { + return nil, nil, fmt.Errorf("empty query") + } + + // Parse JSON command + var cmd bson.D + if err := bson.UnmarshalExtJSON([]byte(query), true, &cmd); err != nil { + return nil, nil, fmt.Errorf("invalid JSON command: %w", err) + } + + db := m.client.Database(m.database) + var result bson.M + if err := db.RunCommand(ctx, cmd).Decode(&result); err != nil { + return nil, nil, err + } + + // Convert result to standard format + data := []map[string]interface{}{{"result": result}} + columns := []string{"result"} + + // If result contains cursor with documents, extract them + if cursor, ok := result["cursor"].(bson.M); ok { + if batch, ok := cursor["firstBatch"].(bson.A); ok { + data = make([]map[string]interface{}, 0, len(batch)) + columnSet := make(map[string]bool) + for _, doc := range batch { + if docMap, ok := doc.(bson.M); ok { + row := make(map[string]interface{}) + for k, v := range docMap { + row[k] = v + columnSet[k] = true + } + data = append(data, row) + } + } + columns = make([]string, 0, len(columnSet)) + for k := range columnSet { + columns = append(columns, k) + } + } + } + + return data, columns, nil +} + +func (m *MongoDB) Exec(query string) (int64, error) { + _, _, err := m.Query(query) + if err != nil { + return 0, err + } + return 1, nil +} + +// ExecContext executes a MongoDB command with the given context for timeout control +func (m *MongoDB) ExecContext(ctx context.Context, query string) (int64, error) { + _, _, err := m.QueryContext(ctx, query) + if err != nil { + return 0, err + } + return 1, nil +} + +func (m *MongoDB) GetDatabases() ([]string, error) { + if m.client == nil { + return nil, fmt.Errorf("connection not open") + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + dbs, err := m.client.ListDatabaseNames(ctx, bson.M{}) + if err != nil { + return nil, err + } + return dbs, nil +} + +func (m *MongoDB) GetTables(dbName string) ([]string, error) { + if m.client == nil { + return nil, fmt.Errorf("connection not open") + } + + targetDB := dbName + if targetDB == "" { + targetDB = m.database + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + collections, err := m.client.Database(targetDB).ListCollectionNames(ctx, bson.M{}) + if err != nil { + return nil, err + } + return collections, nil +} + +func (m *MongoDB) GetCreateStatement(dbName, tableName string) (string, error) { + return fmt.Sprintf("// MongoDB collection: %s.%s\n// MongoDB is schemaless - no CREATE statement available", dbName, tableName), nil +} + +// GetColumns returns empty for MongoDB (schemaless) +func (m *MongoDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { + // MongoDB is schemaless, return empty + return []connection.ColumnDefinition{}, nil +} + +// GetAllColumns returns empty for MongoDB (schemaless) +func (m *MongoDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { + return []connection.ColumnDefinitionWithTable{}, nil +} + +// GetIndexes returns indexes for a MongoDB collection +func (m *MongoDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) { + if m.client == nil { + return nil, fmt.Errorf("connection not open") + } + + targetDB := dbName + if targetDB == "" { + targetDB = m.database + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + collection := m.client.Database(targetDB).Collection(tableName) + cursor, err := collection.Indexes().List(ctx) + if err != nil { + return nil, err + } + defer cursor.Close(ctx) + + var indexes []connection.IndexDefinition + for cursor.Next(ctx) { + var idx bson.M + if err := cursor.Decode(&idx); err != nil { + continue + } + + name := fmt.Sprintf("%v", idx["name"]) + unique := false + if u, ok := idx["unique"].(bool); ok { + unique = u + } + + // Extract key fields + if key, ok := idx["key"].(bson.M); ok { + seq := 1 + for field := range key { + nonUnique := 1 + if unique { + nonUnique = 0 + } + indexes = append(indexes, connection.IndexDefinition{ + Name: name, + ColumnName: field, + NonUnique: nonUnique, + SeqInIndex: seq, + IndexType: "BTREE", + }) + seq++ + } + } + } + + return indexes, nil +} + +func (m *MongoDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) { + // MongoDB doesn't have foreign keys + return []connection.ForeignKeyDefinition{}, nil +} + +func (m *MongoDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) { + // MongoDB doesn't have triggers in the traditional sense + return []connection.TriggerDefinition{}, nil +} + +// ApplyChanges implements batch changes for MongoDB +func (m *MongoDB) ApplyChanges(tableName string, changes connection.ChangeSet) error { + if m.client == nil { + return fmt.Errorf("connection not open") + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + collection := m.client.Database(m.database).Collection(tableName) + + // Process deletes + for _, pk := range changes.Deletes { + filter := bson.M{} + for k, v := range pk { + filter[k] = v + } + if len(filter) > 0 { + if _, err := collection.DeleteOne(ctx, filter); err != nil { + return fmt.Errorf("delete error: %v", err) + } + } + } + + // Process updates + for _, update := range changes.Updates { + filter := bson.M{} + for k, v := range update.Keys { + filter[k] = v + } + if len(filter) == 0 { + return fmt.Errorf("update requires keys") + } + + updateDoc := bson.M{"$set": bson.M{}} + for k, v := range update.Values { + updateDoc["$set"].(bson.M)[k] = v + } + + if _, err := collection.UpdateOne(ctx, filter, updateDoc); err != nil { + return fmt.Errorf("update error: %v", err) + } + } + + // Process inserts + for _, row := range changes.Inserts { + doc := bson.M{} + for k, v := range row { + doc[k] = v + } + if len(doc) > 0 { + if _, err := collection.InsertOne(ctx, doc); err != nil { + return fmt.Errorf("insert error: %v", err) + } + } + } + + return nil +} diff --git a/internal/db/sqlserver_impl.go b/internal/db/sqlserver_impl.go new file mode 100644 index 0000000..145f6e1 --- /dev/null +++ b/internal/db/sqlserver_impl.go @@ -0,0 +1,635 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + "net" + "net/url" + "strconv" + "strings" + "time" + + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/logger" + "GoNavi-Wails/internal/ssh" + "GoNavi-Wails/internal/utils" + + _ "github.com/microsoft/go-mssqldb" +) + +type SqlServerDB struct { + conn *sql.DB + pingTimeout time.Duration + forwarder *ssh.LocalForwarder +} + +// quoteBracket escapes ] in identifiers for safe use in SQL Server [bracket] notation +func quoteBracket(name string) string { + return strings.ReplaceAll(name, "]", "]]") +} + +func (s *SqlServerDB) getDSN(config connection.ConnectionConfig) string { + // sqlserver://user:password@host:port?database=dbname + dbname := config.Database + if dbname == "" { + dbname = "master" + } + + u := &url.URL{ + Scheme: "sqlserver", + Host: net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), + } + u.User = url.UserPassword(config.User, config.Password) + + q := url.Values{} + q.Set("database", dbname) + q.Set("connection timeout", strconv.Itoa(getConnectTimeoutSeconds(config))) + q.Set("encrypt", "disable") + q.Set("TrustServerCertificate", "true") + u.RawQuery = q.Encode() + + return u.String() +} + +func (s *SqlServerDB) Connect(config connection.ConnectionConfig) error { + var dsn string + + if config.UseSSH { + logger.Infof("SQL Server 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User) + + forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port) + if err != nil { + return fmt.Errorf("创建 SSH 隧道失败:%w", err) + } + s.forwarder = forwarder + + host, portStr, err := net.SplitHostPort(forwarder.LocalAddr) + if err != nil { + return fmt.Errorf("解析本地转发地址失败:%w", err) + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("解析本地端口失败:%w", err) + } + + localConfig := config + localConfig.Host = host + localConfig.Port = port + localConfig.UseSSH = false + + dsn = s.getDSN(localConfig) + logger.Infof("SQL Server 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port) + } else { + dsn = s.getDSN(config) + } + + db, err := sql.Open("sqlserver", dsn) + if err != nil { + return fmt.Errorf("打开数据库连接失败:%w", err) + } + s.conn = db + s.pingTimeout = getConnectTimeout(config) + + if err := s.Ping(); err != nil { + return fmt.Errorf("连接建立后验证失败:%w", err) + } + return nil +} + +func (s *SqlServerDB) Close() error { + if s.forwarder != nil { + if err := s.forwarder.Close(); err != nil { + logger.Warnf("关闭 SQL Server SSH 端口转发失败:%v", err) + } + s.forwarder = nil + } + + if s.conn != nil { + return s.conn.Close() + } + return nil +} + +func (s *SqlServerDB) Ping() error { + if s.conn == nil { + return fmt.Errorf("connection not open") + } + timeout := s.pingTimeout + if timeout <= 0 { + timeout = 5 * time.Second + } + ctx, cancel := utils.ContextWithTimeout(timeout) + defer cancel() + return s.conn.PingContext(ctx) +} + +func (s *SqlServerDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + if s.conn == nil { + return nil, nil, fmt.Errorf("connection not open") + } + + rows, err := s.conn.QueryContext(ctx, query) + if err != nil { + return nil, nil, err + } + defer rows.Close() + + return scanRows(rows) +} + +func (s *SqlServerDB) Query(query string) ([]map[string]interface{}, []string, error) { + if s.conn == nil { + return nil, nil, fmt.Errorf("connection not open") + } + + rows, err := s.conn.Query(query) + if err != nil { + return nil, nil, err + } + defer rows.Close() + return scanRows(rows) +} + +func (s *SqlServerDB) ExecContext(ctx context.Context, query string) (int64, error) { + if s.conn == nil { + return 0, fmt.Errorf("connection not open") + } + res, err := s.conn.ExecContext(ctx, query) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (s *SqlServerDB) Exec(query string) (int64, error) { + if s.conn == nil { + return 0, fmt.Errorf("connection not open") + } + res, err := s.conn.Exec(query) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (s *SqlServerDB) GetDatabases() ([]string, error) { + query := "SELECT name FROM sys.databases WHERE state_desc = 'ONLINE' ORDER BY name" + data, _, err := s.Query(query) + if err != nil { + return nil, err + } + var dbs []string + for _, row := range data { + if val, ok := row["name"]; ok { + dbs = append(dbs, fmt.Sprintf("%v", val)) + } + } + return dbs, nil +} + +func (s *SqlServerDB) GetTables(dbName string) ([]string, error) { + // SQL Server uses schema.table format, default schema is dbo + safeDB := quoteBracket(dbName) + query := fmt.Sprintf(` +SELECT s.name AS schema_name, t.name AS table_name +FROM [%s].sys.tables t +JOIN [%s].sys.schemas s ON t.schema_id = s.schema_id +WHERE t.type = 'U' +ORDER BY s.name, t.name`, safeDB, safeDB) + + data, _, err := s.Query(query) + if err != nil { + return nil, err + } + + var tables []string + for _, row := range data { + schema, okSchema := row["schema_name"] + name, okName := row["table_name"] + if okSchema && okName { + tables = append(tables, fmt.Sprintf("%v.%v", schema, name)) + continue + } + if okName { + tables = append(tables, fmt.Sprintf("%v", name)) + } + } + return tables, nil +} + +func (s *SqlServerDB) GetCreateStatement(dbName, tableName string) (string, error) { + return fmt.Sprintf("-- SHOW CREATE TABLE not supported for SQL Server in this version.\n-- Table: %s.%s", dbName, tableName), nil +} + +func (s *SqlServerDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { + schema := "dbo" + table := strings.TrimSpace(tableName) + + if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { + schema = strings.TrimSpace(parts[0]) + table = strings.TrimSpace(parts[1]) + } + + if table == "" { + return nil, fmt.Errorf("table name required") + } + + esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") } + safeDB := quoteBracket(dbName) + + query := fmt.Sprintf(` +SELECT + c.name AS column_name, + t.name + CASE + WHEN t.name IN ('varchar', 'nvarchar', 'char', 'nchar') THEN '(' + CASE WHEN c.max_length = -1 THEN 'MAX' ELSE CAST(CASE WHEN t.name IN ('nvarchar', 'nchar') THEN c.max_length / 2 ELSE c.max_length END AS VARCHAR) END + ')' + WHEN t.name IN ('decimal', 'numeric') THEN '(' + CAST(c.precision AS VARCHAR) + ',' + CAST(c.scale AS VARCHAR) + ')' + ELSE '' + END AS data_type, + CASE WHEN c.is_nullable = 1 THEN 'YES' ELSE 'NO' END AS is_nullable, + dc.definition AS column_default, + ep.value AS comment, + CASE WHEN pk.column_id IS NOT NULL THEN 'PRI' ELSE '' END AS column_key, + CASE WHEN c.is_identity = 1 THEN 'auto_increment' ELSE '' END AS extra +FROM [%s].sys.columns c +JOIN [%s].sys.types t ON c.user_type_id = t.user_type_id +JOIN [%s].sys.tables tb ON c.object_id = tb.object_id +JOIN [%s].sys.schemas s ON tb.schema_id = s.schema_id +LEFT JOIN [%s].sys.default_constraints dc ON c.default_object_id = dc.object_id +LEFT JOIN [%s].sys.extended_properties ep ON ep.major_id = c.object_id AND ep.minor_id = c.column_id AND ep.name = 'MS_Description' +LEFT JOIN ( + SELECT ic.object_id, ic.column_id + FROM [%s].sys.index_columns ic + JOIN [%s].sys.indexes i ON ic.object_id = i.object_id AND ic.index_id = i.index_id + WHERE i.is_primary_key = 1 +) pk ON pk.object_id = c.object_id AND pk.column_id = c.column_id +WHERE s.name = '%s' AND tb.name = '%s' +ORDER BY c.column_id`, + safeDB, safeDB, safeDB, safeDB, safeDB, safeDB, safeDB, safeDB, + esc(schema), esc(table)) + + data, _, err := s.Query(query) + if err != nil { + return nil, err + } + + var columns []connection.ColumnDefinition + for _, row := range data { + col := connection.ColumnDefinition{ + Name: fmt.Sprintf("%v", row["column_name"]), + Type: fmt.Sprintf("%v", row["data_type"]), + Nullable: fmt.Sprintf("%v", row["is_nullable"]), + Key: fmt.Sprintf("%v", row["column_key"]), + Extra: fmt.Sprintf("%v", row["extra"]), + Comment: "", + } + + if v, ok := row["comment"]; ok && v != nil { + col.Comment = fmt.Sprintf("%v", v) + } + + if v, ok := row["column_default"]; ok && v != nil { + def := fmt.Sprintf("%v", v) + col.Default = &def + } + + columns = append(columns, col) + } + return columns, nil +} + +func (s *SqlServerDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { + safeDB := quoteBracket(dbName) + query := fmt.Sprintf(` +SELECT s.name AS schema_name, t.name AS table_name, c.name AS column_name, tp.name AS data_type +FROM [%s].sys.columns c +JOIN [%s].sys.tables t ON c.object_id = t.object_id +JOIN [%s].sys.schemas s ON t.schema_id = s.schema_id +JOIN [%s].sys.types tp ON c.user_type_id = tp.user_type_id +WHERE t.type = 'U' +ORDER BY s.name, t.name, c.column_id`, safeDB, safeDB, safeDB, safeDB) + + data, _, err := s.Query(query) + if err != nil { + return nil, err + } + + var cols []connection.ColumnDefinitionWithTable + for _, row := range data { + schema := fmt.Sprintf("%v", row["schema_name"]) + table := fmt.Sprintf("%v", row["table_name"]) + tableName := fmt.Sprintf("%s.%s", schema, table) + + col := connection.ColumnDefinitionWithTable{ + TableName: tableName, + Name: fmt.Sprintf("%v", row["column_name"]), + Type: fmt.Sprintf("%v", row["data_type"]), + } + cols = append(cols, col) + } + return cols, nil +} + +func (s *SqlServerDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) { + schema := "dbo" + table := strings.TrimSpace(tableName) + + if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { + schema = strings.TrimSpace(parts[0]) + table = strings.TrimSpace(parts[1]) + } + + if table == "" { + return nil, fmt.Errorf("table name required") + } + + esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") } + safeDB := quoteBracket(dbName) + + query := fmt.Sprintf(` +SELECT + i.name AS index_name, + c.name AS column_name, + i.is_unique, + ic.key_ordinal AS seq_in_index, + i.type_desc AS index_type +FROM [%s].sys.indexes i +JOIN [%s].sys.index_columns ic ON i.object_id = ic.object_id AND i.index_id = ic.index_id +JOIN [%s].sys.columns c ON ic.object_id = c.object_id AND ic.column_id = c.column_id +JOIN [%s].sys.tables t ON i.object_id = t.object_id +JOIN [%s].sys.schemas s ON t.schema_id = s.schema_id +WHERE s.name = '%s' AND t.name = '%s' AND i.name IS NOT NULL +ORDER BY i.name, ic.key_ordinal`, + safeDB, safeDB, safeDB, safeDB, safeDB, esc(schema), esc(table)) + + data, _, err := s.Query(query) + if err != nil { + return nil, err + } + + var indexes []connection.IndexDefinition + for _, row := range data { + isUnique := false + if v, ok := row["is_unique"]; ok && v != nil { + switch val := v.(type) { + case bool: + isUnique = val + case int64: + isUnique = val == 1 + } + } + + nonUnique := 1 + if isUnique { + nonUnique = 0 + } + + seq := 0 + if v, ok := row["seq_in_index"]; ok && v != nil { + switch val := v.(type) { + case int: + seq = val + case int64: + seq = int(val) + } + } + + indexType := "NONCLUSTERED" + if v, ok := row["index_type"]; ok && v != nil { + indexType = strings.ToUpper(fmt.Sprintf("%v", v)) + } + + idx := connection.IndexDefinition{ + Name: fmt.Sprintf("%v", row["index_name"]), + ColumnName: fmt.Sprintf("%v", row["column_name"]), + NonUnique: nonUnique, + SeqInIndex: seq, + IndexType: indexType, + } + indexes = append(indexes, idx) + } + return indexes, nil +} + +func (s *SqlServerDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) { + schema := "dbo" + table := strings.TrimSpace(tableName) + + if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { + schema = strings.TrimSpace(parts[0]) + table = strings.TrimSpace(parts[1]) + } + + if table == "" { + return nil, fmt.Errorf("table name required") + } + + esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") } + safeDB := quoteBracket(dbName) + + query := fmt.Sprintf(` +SELECT + fk.name AS constraint_name, + c.name AS column_name, + rs.name AS foreign_schema, + rt.name AS foreign_table, + rc.name AS foreign_column +FROM [%s].sys.foreign_keys fk +JOIN [%s].sys.foreign_key_columns fkc ON fk.object_id = fkc.constraint_object_id +JOIN [%s].sys.columns c ON fkc.parent_object_id = c.object_id AND fkc.parent_column_id = c.column_id +JOIN [%s].sys.tables t ON fk.parent_object_id = t.object_id +JOIN [%s].sys.schemas s ON t.schema_id = s.schema_id +JOIN [%s].sys.tables rt ON fk.referenced_object_id = rt.object_id +JOIN [%s].sys.schemas rs ON rt.schema_id = rs.schema_id +JOIN [%s].sys.columns rc ON fkc.referenced_object_id = rc.object_id AND fkc.referenced_column_id = rc.column_id +WHERE s.name = '%s' AND t.name = '%s' +ORDER BY fk.name`, + safeDB, safeDB, safeDB, safeDB, safeDB, safeDB, safeDB, safeDB, esc(schema), esc(table)) + + data, _, err := s.Query(query) + if err != nil { + return nil, err + } + + var fks []connection.ForeignKeyDefinition + for _, row := range data { + refSchema := fmt.Sprintf("%v", row["foreign_schema"]) + refTable := fmt.Sprintf("%v", row["foreign_table"]) + refTableName := fmt.Sprintf("%s.%s", refSchema, refTable) + + fk := connection.ForeignKeyDefinition{ + Name: fmt.Sprintf("%v", row["constraint_name"]), + ColumnName: fmt.Sprintf("%v", row["column_name"]), + RefTableName: refTableName, + RefColumnName: fmt.Sprintf("%v", row["foreign_column"]), + ConstraintName: fmt.Sprintf("%v", row["constraint_name"]), + } + fks = append(fks, fk) + } + return fks, nil +} + +func (s *SqlServerDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) { + schema := "dbo" + table := strings.TrimSpace(tableName) + + if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { + schema = strings.TrimSpace(parts[0]) + table = strings.TrimSpace(parts[1]) + } + + if table == "" { + return nil, fmt.Errorf("table name required") + } + + esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") } + safeDB := quoteBracket(dbName) + + query := fmt.Sprintf(` +SELECT + tr.name AS trigger_name, + CASE WHEN tr.is_instead_of_trigger = 1 THEN 'INSTEAD OF' ELSE 'AFTER' END AS timing, + STUFF(( + SELECT ', ' + te.type_desc + FROM [%s].sys.trigger_events te + WHERE te.object_id = tr.object_id + FOR XML PATH('') + ), 1, 2, '') AS event, + OBJECT_DEFINITION(tr.object_id) AS statement +FROM [%s].sys.triggers tr +JOIN [%s].sys.tables t ON tr.parent_id = t.object_id +JOIN [%s].sys.schemas s ON t.schema_id = s.schema_id +WHERE s.name = '%s' AND t.name = '%s' +ORDER BY tr.name`, + safeDB, safeDB, safeDB, safeDB, esc(schema), esc(table)) + + data, _, err := s.Query(query) + if err != nil { + return nil, err + } + + var triggers []connection.TriggerDefinition + for _, row := range data { + trig := connection.TriggerDefinition{ + Name: fmt.Sprintf("%v", row["trigger_name"]), + Timing: fmt.Sprintf("%v", row["timing"]), + Event: fmt.Sprintf("%v", row["event"]), + Statement: "", + } + if v, ok := row["statement"]; ok && v != nil { + trig.Statement = fmt.Sprintf("%v", v) + } + triggers = append(triggers, trig) + } + return triggers, nil +} + +func (s *SqlServerDB) ApplyChanges(tableName string, changes connection.ChangeSet) error { + if s.conn == nil { + return fmt.Errorf("connection not open") + } + + tx, err := s.conn.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + quoteIdent := func(name string) string { + n := strings.TrimSpace(name) + n = strings.Trim(n, "[]") + n = strings.ReplaceAll(n, "]", "]]") + if n == "" { + return "[]" + } + return "[" + n + "]" + } + + schema := "dbo" + table := strings.TrimSpace(tableName) + if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { + schema = strings.TrimSpace(parts[0]) + table = strings.TrimSpace(parts[1]) + } + + qualifiedTable := fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table)) + + // 1. Deletes + for _, pk := range changes.Deletes { + var wheres []string + var args []interface{} + idx := 0 + for k, v := range pk { + idx++ + wheres = append(wheres, fmt.Sprintf("%s = @p%d", quoteIdent(k), idx)) + args = append(args, sql.Named(fmt.Sprintf("p%d", idx), v)) + } + if len(wheres) == 0 { + continue + } + query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND ")) + if _, err := tx.Exec(query, args...); err != nil { + return fmt.Errorf("delete error: %v", err) + } + } + + // 2. Updates + for _, update := range changes.Updates { + var sets []string + var args []interface{} + idx := 0 + + for k, v := range update.Values { + idx++ + sets = append(sets, fmt.Sprintf("%s = @p%d", quoteIdent(k), idx)) + args = append(args, sql.Named(fmt.Sprintf("p%d", idx), v)) + } + + if len(sets) == 0 { + continue + } + + var wheres []string + for k, v := range update.Keys { + idx++ + wheres = append(wheres, fmt.Sprintf("%s = @p%d", quoteIdent(k), idx)) + args = append(args, sql.Named(fmt.Sprintf("p%d", idx), v)) + } + + if len(wheres) == 0 { + return fmt.Errorf("update requires keys") + } + + query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND ")) + if _, err := tx.Exec(query, args...); err != nil { + return fmt.Errorf("update error: %v", err) + } + } + + // 3. Inserts + for _, row := range changes.Inserts { + var cols []string + var placeholders []string + var args []interface{} + idx := 0 + + for k, v := range row { + idx++ + cols = append(cols, quoteIdent(k)) + placeholders = append(placeholders, fmt.Sprintf("@p%d", idx)) + args = append(args, sql.Named(fmt.Sprintf("p%d", idx), v)) + } + + if len(cols) == 0 { + continue + } + + query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", ")) + if _, err := tx.Exec(query, args...); err != nil { + return fmt.Errorf("insert error: %v", err) + } + } + + return tx.Commit() +} diff --git a/internal/db/vastbase_impl.go b/internal/db/vastbase_impl.go new file mode 100644 index 0000000..8f7b9c6 --- /dev/null +++ b/internal/db/vastbase_impl.go @@ -0,0 +1,627 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + "net" + "net/url" + "strconv" + "strings" + "time" + + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/logger" + "GoNavi-Wails/internal/ssh" + "GoNavi-Wails/internal/utils" + + _ "github.com/lib/pq" // Vastbase is PostgreSQL compatible +) + +// VastbaseDB implements Database interface for Vastbase (海量) database +// Vastbase is a PostgreSQL-compatible database, so we reuse PostgreSQL driver +type VastbaseDB struct { + conn *sql.DB + pingTimeout time.Duration + forwarder *ssh.LocalForwarder +} + +func (v *VastbaseDB) getDSN(config connection.ConnectionConfig) string { + dbname := config.Database + if dbname == "" { + dbname = "vastbase" // Vastbase default database + } + + u := &url.URL{ + Scheme: "postgres", + Host: net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), + Path: "/" + dbname, + } + u.User = url.UserPassword(config.User, config.Password) + q := url.Values{} + q.Set("sslmode", "disable") + q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config))) + u.RawQuery = q.Encode() + + return u.String() +} + +func (v *VastbaseDB) Connect(config connection.ConnectionConfig) error { + var dsn string + + if config.UseSSH { + logger.Infof("Vastbase 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User) + + forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port) + if err != nil { + return fmt.Errorf("创建 SSH 隧道失败:%w", err) + } + v.forwarder = forwarder + + host, portStr, err := net.SplitHostPort(forwarder.LocalAddr) + if err != nil { + return fmt.Errorf("解析本地转发地址失败:%w", err) + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("解析本地端口失败:%w", err) + } + + localConfig := config + localConfig.Host = host + localConfig.Port = port + localConfig.UseSSH = false + + dsn = v.getDSN(localConfig) + logger.Infof("Vastbase 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port) + } else { + dsn = v.getDSN(config) + } + + db, err := sql.Open("postgres", dsn) + if err != nil { + return fmt.Errorf("打开数据库连接失败:%w", err) + } + v.conn = db + v.pingTimeout = getConnectTimeout(config) + + if err := v.Ping(); err != nil { + return fmt.Errorf("连接建立后验证失败:%w", err) + } + return nil +} + +func (v *VastbaseDB) Close() error { + if v.forwarder != nil { + if err := v.forwarder.Close(); err != nil { + logger.Warnf("关闭 Vastbase SSH 端口转发失败:%v", err) + } + v.forwarder = nil + } + + if v.conn != nil { + return v.conn.Close() + } + return nil +} + +func (v *VastbaseDB) Ping() error { + if v.conn == nil { + return fmt.Errorf("connection not open") + } + timeout := v.pingTimeout + if timeout <= 0 { + timeout = 5 * time.Second + } + ctx, cancel := utils.ContextWithTimeout(timeout) + defer cancel() + return v.conn.PingContext(ctx) +} + +func (v *VastbaseDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + if v.conn == nil { + return nil, nil, fmt.Errorf("connection not open") + } + + rows, err := v.conn.QueryContext(ctx, query) + if err != nil { + return nil, nil, err + } + defer rows.Close() + + return scanRows(rows) +} + +func (v *VastbaseDB) Query(query string) ([]map[string]interface{}, []string, error) { + if v.conn == nil { + return nil, nil, fmt.Errorf("connection not open") + } + + rows, err := v.conn.Query(query) + if err != nil { + return nil, nil, err + } + defer rows.Close() + return scanRows(rows) +} + +func (v *VastbaseDB) ExecContext(ctx context.Context, query string) (int64, error) { + if v.conn == nil { + return 0, fmt.Errorf("connection not open") + } + res, err := v.conn.ExecContext(ctx, query) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (v *VastbaseDB) Exec(query string) (int64, error) { + if v.conn == nil { + return 0, fmt.Errorf("connection not open") + } + res, err := v.conn.Exec(query) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (v *VastbaseDB) GetDatabases() ([]string, error) { + data, _, err := v.Query("SELECT datname FROM pg_database WHERE datistemplate = false") + if err != nil { + return nil, err + } + var dbs []string + for _, row := range data { + if val, ok := row["datname"]; ok { + dbs = append(dbs, fmt.Sprintf("%v", val)) + } + } + return dbs, nil +} + +func (v *VastbaseDB) GetTables(dbName string) ([]string, error) { + query := "SELECT schemaname, tablename FROM pg_catalog.pg_tables WHERE schemaname != 'information_schema' AND schemaname NOT LIKE 'pg_%' ORDER BY schemaname, tablename" + data, _, err := v.Query(query) + if err != nil { + return nil, err + } + + var tables []string + for _, row := range data { + schema, okSchema := row["schemaname"] + name, okName := row["tablename"] + if okSchema && okName { + tables = append(tables, fmt.Sprintf("%v.%v", schema, name)) + continue + } + if okName { + tables = append(tables, fmt.Sprintf("%v", name)) + } + } + return tables, nil +} + +func (v *VastbaseDB) GetCreateStatement(dbName, tableName string) (string, error) { + return fmt.Sprintf("-- SHOW CREATE TABLE not fully supported for Vastbase in this version.\n-- Table: %s", tableName), nil +} + +func (v *VastbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { + schema := strings.TrimSpace(dbName) + if schema == "" { + schema = "public" + } + table := strings.TrimSpace(tableName) + if table == "" { + return nil, fmt.Errorf("table name required") + } + + esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") } + + query := fmt.Sprintf(` +SELECT + a.attname AS column_name, + pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type, + CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable, + pg_get_expr(ad.adbin, ad.adrelid) AS column_default, + col_description(a.attrelid, a.attnum) AS comment, + CASE WHEN pk.attname IS NOT NULL THEN 'PRI' ELSE '' END AS column_key +FROM pg_class c +JOIN pg_namespace n ON n.oid = c.relnamespace +JOIN pg_attribute a ON a.attrelid = c.oid +LEFT JOIN pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum +LEFT JOIN ( + SELECT i.indrelid, a3.attname + FROM pg_index i + JOIN pg_attribute a3 ON a3.attrelid = i.indrelid AND a3.attnum = ANY(i.indkey) + WHERE i.indisprimary +) pk ON pk.indrelid = c.oid AND pk.attname = a.attname +WHERE c.relkind IN ('r', 'p') + AND n.nspname = '%s' + AND c.relname = '%s' + AND a.attnum > 0 + AND NOT a.attisdropped +ORDER BY a.attnum`, esc(schema), esc(table)) + + data, _, err := v.Query(query) + if err != nil { + return nil, err + } + + var columns []connection.ColumnDefinition + for _, row := range data { + col := connection.ColumnDefinition{ + Name: fmt.Sprintf("%v", row["column_name"]), + Type: fmt.Sprintf("%v", row["data_type"]), + Nullable: fmt.Sprintf("%v", row["is_nullable"]), + Key: fmt.Sprintf("%v", row["column_key"]), + Extra: "", + Comment: "", + } + + if val, ok := row["comment"]; ok && val != nil { + col.Comment = fmt.Sprintf("%v", val) + } + + if val, ok := row["column_default"]; ok && val != nil { + def := fmt.Sprintf("%v", val) + col.Default = &def + if strings.HasPrefix(strings.ToLower(strings.TrimSpace(def)), "nextval(") { + col.Extra = "auto_increment" + } + } + + columns = append(columns, col) + } + return columns, nil +} + +func (v *VastbaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) { + schema := strings.TrimSpace(dbName) + if schema == "" { + schema = "public" + } + table := strings.TrimSpace(tableName) + if table == "" { + return nil, fmt.Errorf("table name required") + } + + esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") } + + query := fmt.Sprintf(` +SELECT + i.relname AS index_name, + a.attname AS column_name, + ix.indisunique AS is_unique, + x.ordinality AS seq_in_index, + am.amname AS index_type +FROM pg_class t +JOIN pg_namespace n ON n.oid = t.relnamespace +JOIN pg_index ix ON t.oid = ix.indrelid +JOIN pg_class i ON i.oid = ix.indexrelid +JOIN pg_am am ON i.relam = am.oid +JOIN unnest(ix.indkey) WITH ORDINALITY AS x(attnum, ordinality) ON TRUE +JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = x.attnum +WHERE t.relkind IN ('r', 'p') + AND t.relname = '%s' + AND n.nspname = '%s' +ORDER BY i.relname, x.ordinality`, esc(table), esc(schema)) + + data, _, err := v.Query(query) + if err != nil { + return nil, err + } + + parseBool := func(val interface{}) bool { + switch v := val.(type) { + case bool: + return v + case string: + s := strings.ToLower(strings.TrimSpace(v)) + return s == "t" || s == "true" || s == "1" || s == "y" || s == "yes" + default: + s := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", val))) + return s == "t" || s == "true" || s == "1" || s == "y" || s == "yes" + } + } + + parseInt := func(val interface{}) int { + switch v := val.(type) { + case int: + return v + case int64: + return int(v) + case float64: + return int(v) + case string: + var n int + _, _ = fmt.Sscanf(strings.TrimSpace(v), "%d", &n) + return n + default: + var n int + _, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", val)), "%d", &n) + return n + } + } + + var indexes []connection.IndexDefinition + for _, row := range data { + isUnique := false + if val, ok := row["is_unique"]; ok && val != nil { + isUnique = parseBool(val) + } + + nonUnique := 1 + if isUnique { + nonUnique = 0 + } + + seq := 0 + if val, ok := row["seq_in_index"]; ok && val != nil { + seq = parseInt(val) + } + + indexType := "" + if val, ok := row["index_type"]; ok && val != nil { + indexType = strings.ToUpper(fmt.Sprintf("%v", val)) + } + if indexType == "" { + indexType = "BTREE" + } + + idx := connection.IndexDefinition{ + Name: fmt.Sprintf("%v", row["index_name"]), + ColumnName: fmt.Sprintf("%v", row["column_name"]), + NonUnique: nonUnique, + SeqInIndex: seq, + IndexType: indexType, + } + indexes = append(indexes, idx) + } + return indexes, nil +} + +func (v *VastbaseDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) { + schema := strings.TrimSpace(dbName) + if schema == "" { + schema = "public" + } + table := strings.TrimSpace(tableName) + if table == "" { + return nil, fmt.Errorf("table name required") + } + + esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") } + + query := fmt.Sprintf(` +SELECT + tc.constraint_name AS constraint_name, + kcu.column_name AS column_name, + ccu.table_schema AS foreign_table_schema, + ccu.table_name AS foreign_table_name, + ccu.column_name AS foreign_column_name +FROM information_schema.table_constraints AS tc +JOIN information_schema.key_column_usage AS kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema +JOIN information_schema.constraint_column_usage AS ccu + ON ccu.constraint_name = tc.constraint_name + AND ccu.table_schema = tc.table_schema +WHERE tc.constraint_type = 'FOREIGN KEY' + AND tc.table_name = '%s' + AND tc.table_schema = '%s' +ORDER BY tc.constraint_name, kcu.ordinal_position`, esc(table), esc(schema)) + + data, _, err := v.Query(query) + if err != nil { + return nil, err + } + + var fks []connection.ForeignKeyDefinition + for _, row := range data { + refSchema := "" + if val, ok := row["foreign_table_schema"]; ok && val != nil { + refSchema = fmt.Sprintf("%v", val) + } + refTable := fmt.Sprintf("%v", row["foreign_table_name"]) + refTableName := refTable + if strings.TrimSpace(refSchema) != "" { + refTableName = fmt.Sprintf("%s.%s", refSchema, refTable) + } + + fk := connection.ForeignKeyDefinition{ + Name: fmt.Sprintf("%v", row["constraint_name"]), + ColumnName: fmt.Sprintf("%v", row["column_name"]), + RefTableName: refTableName, + RefColumnName: fmt.Sprintf("%v", row["foreign_column_name"]), + ConstraintName: fmt.Sprintf("%v", row["constraint_name"]), + } + fks = append(fks, fk) + } + return fks, nil +} + +func (v *VastbaseDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) { + schema := strings.TrimSpace(dbName) + if schema == "" { + schema = "public" + } + table := strings.TrimSpace(tableName) + if table == "" { + return nil, fmt.Errorf("table name required") + } + + esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") } + + query := fmt.Sprintf(` +SELECT trigger_name, action_timing, event_manipulation, action_statement +FROM information_schema.triggers +WHERE event_object_table = '%s' + AND event_object_schema = '%s' +ORDER BY trigger_name, event_manipulation`, esc(table), esc(schema)) + + data, _, err := v.Query(query) + if err != nil { + return nil, err + } + + var triggers []connection.TriggerDefinition + for _, row := range data { + trig := connection.TriggerDefinition{ + Name: fmt.Sprintf("%v", row["trigger_name"]), + Timing: fmt.Sprintf("%v", row["action_timing"]), + Event: fmt.Sprintf("%v", row["event_manipulation"]), + Statement: fmt.Sprintf("%v", row["action_statement"]), + } + triggers = append(triggers, trig) + } + return triggers, nil +} + +func (v *VastbaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { + query := ` +SELECT table_schema, table_name, column_name, data_type +FROM information_schema.columns +WHERE table_schema NOT IN ('pg_catalog', 'information_schema') + AND table_schema NOT LIKE 'pg_%' +ORDER BY table_schema, table_name, ordinal_position` + + data, _, err := v.Query(query) + if err != nil { + return nil, err + } + + var cols []connection.ColumnDefinitionWithTable + for _, row := range data { + schema := fmt.Sprintf("%v", row["table_schema"]) + table := fmt.Sprintf("%v", row["table_name"]) + tableName := table + if strings.TrimSpace(schema) != "" { + tableName = fmt.Sprintf("%s.%s", schema, table) + } + + col := connection.ColumnDefinitionWithTable{ + TableName: tableName, + Name: fmt.Sprintf("%v", row["column_name"]), + Type: fmt.Sprintf("%v", row["data_type"]), + } + cols = append(cols, col) + } + return cols, nil +} + +func (v *VastbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet) error { + if v.conn == nil { + return fmt.Errorf("connection not open") + } + + tx, err := v.conn.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + quoteIdent := func(name string) string { + n := strings.TrimSpace(name) + n = strings.Trim(n, "\"") + n = strings.ReplaceAll(n, "\"", "\"\"") + if n == "" { + return "\"\"" + } + return `"` + n + `"` + } + + schema := "" + table := strings.TrimSpace(tableName) + if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { + schema = strings.TrimSpace(parts[0]) + table = strings.TrimSpace(parts[1]) + } + + qualifiedTable := "" + if schema != "" { + qualifiedTable = fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table)) + } else { + qualifiedTable = quoteIdent(table) + } + + // 1. Deletes + for _, pk := range changes.Deletes { + var wheres []string + var args []interface{} + idx := 0 + for k, val := range pk { + idx++ + wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteIdent(k), idx)) + args = append(args, val) + } + if len(wheres) == 0 { + continue + } + query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND ")) + if _, err := tx.Exec(query, args...); err != nil { + return fmt.Errorf("delete error: %v", err) + } + } + + // 2. Updates + for _, update := range changes.Updates { + var sets []string + var args []interface{} + idx := 0 + + for k, val := range update.Values { + idx++ + sets = append(sets, fmt.Sprintf("%s = $%d", quoteIdent(k), idx)) + args = append(args, val) + } + + if len(sets) == 0 { + continue + } + + var wheres []string + for k, val := range update.Keys { + idx++ + wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteIdent(k), idx)) + args = append(args, val) + } + + if len(wheres) == 0 { + return fmt.Errorf("update requires keys") + } + + query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND ")) + if _, err := tx.Exec(query, args...); err != nil { + return fmt.Errorf("update error: %v", err) + } + } + + // 3. Inserts + for _, row := range changes.Inserts { + var cols []string + var placeholders []string + var args []interface{} + idx := 0 + + for k, val := range row { + idx++ + cols = append(cols, quoteIdent(k)) + placeholders = append(placeholders, fmt.Sprintf("$%d", idx)) + args = append(args, val) + } + + if len(cols) == 0 { + continue + } + + query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", ")) + if _, err := tx.Exec(query, args...); err != nil { + return fmt.Errorf("insert error: %v", err) + } + } + + return tx.Commit() +} diff --git a/internal/sync/sql_helpers.go b/internal/sync/sql_helpers.go index 53b4bd2..708c6cb 100644 --- a/internal/sync/sql_helpers.go +++ b/internal/sync/sql_helpers.go @@ -22,8 +22,11 @@ func quoteIdentByType(dbType string, ident string) string { } switch dbType { - case "mysql": + case "mysql", "mariadb": return "`" + strings.ReplaceAll(ident, "`", "``") + "`" + case "sqlserver": + escaped := strings.ReplaceAll(ident, "]", "]]") + return "[" + escaped + "]" default: return `"` + strings.ReplaceAll(ident, `"`, `""`) + `"` } @@ -71,7 +74,7 @@ func normalizeSchemaAndTable(dbType string, dbName string, tableName string) (st } switch strings.ToLower(strings.TrimSpace(dbType)) { - case "postgres", "kingbase": + case "postgres", "kingbase", "vastbase": return "public", rawTable default: return rawDB, rawTable @@ -88,7 +91,7 @@ func qualifiedNameForQuery(dbType string, schema string, table string, original } switch strings.ToLower(strings.TrimSpace(dbType)) { - case "postgres", "kingbase": + case "postgres", "kingbase", "vastbase": s := strings.TrimSpace(schema) if s == "" { s = "public" @@ -97,7 +100,7 @@ func qualifiedNameForQuery(dbType string, schema string, table string, original return raw } return s + "." + table - case "mysql": + case "mysql", "mariadb": s := strings.TrimSpace(schema) if s == "" || table == "" { return table