mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-12 12:19:47 +08:00
Compare commits
30 Commits
v0.3.4
...
release/0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ecf47da81b | ||
|
|
21c8b9a102 | ||
|
|
a07b418b8f | ||
|
|
4bf10e5612 | ||
|
|
e6fe6eb026 | ||
|
|
b4f80f39df | ||
|
|
4d32dd2cb5 | ||
|
|
de8fb60a30 | ||
|
|
52abed83e6 | ||
|
|
80dc863455 | ||
|
|
fa318a9f0e | ||
|
|
78e35a5be8 | ||
|
|
35ed555857 | ||
|
|
f3130ff517 | ||
|
|
012c99be9e | ||
|
|
c8575c315b | ||
|
|
601d69faeb | ||
|
|
fdb7781a9b | ||
|
|
087578693e | ||
|
|
aceabb63f5 | ||
|
|
8587f72f81 | ||
|
|
83ad3b09d9 | ||
|
|
72811092b4 | ||
|
|
b67135e2c1 | ||
|
|
f5e16b0b70 | ||
|
|
f8535dd272 | ||
|
|
5cd8681b80 | ||
|
|
4b381c82b5 | ||
|
|
820b064e7f | ||
|
|
70cb6148c6 |
@@ -37,6 +37,7 @@
|
||||
- **Oracle**:基础数据访问与编辑支持。
|
||||
- **Dameng(达梦)**:基础数据访问与编辑支持。
|
||||
- **Kingbase(人大金仓)**:基础数据访问与编辑支持。
|
||||
- **TDengine**:时序数据库连接、库表浏览与 SQL 查询支持。
|
||||
- **Redis**:Key/Value 浏览、命令执行、视图与编码切换。
|
||||
- **自定义驱动**:支持配置 Driver/DSN 接入更多数据源。
|
||||
- **SSH 隧道**:内置 SSH 隧道支持,安全连接内网数据库。
|
||||
|
||||
164
docs/HighGo_Optional_Code_Changes.md
Normal file
164
docs/HighGo_Optional_Code_Changes.md
Normal file
@@ -0,0 +1,164 @@
|
||||
# HighGo 可选代码优化建议
|
||||
|
||||
## 一、sslmode 配置优化
|
||||
|
||||
### 当前状态
|
||||
|
||||
**文件**:`internal/db/highgo_impl.go:43`
|
||||
|
||||
**当前代码**:
|
||||
```go
|
||||
q.Set("sslmode", "disable")
|
||||
```
|
||||
|
||||
### 建议修改
|
||||
|
||||
根据瀚高官方文档,sslmode 的默认值应该是 `require`。建议修改为:
|
||||
|
||||
```go
|
||||
q.Set("sslmode", "require")
|
||||
```
|
||||
|
||||
### 修改原因
|
||||
|
||||
1. **符合官方规范**:瀚高官方文档明确指出默认 sslmode 为 `require`
|
||||
2. **安全性提升**:启用 SSL 加密可以保护数据传输安全
|
||||
3. **生产环境最佳实践**:生产环境应该启用 SSL 连接
|
||||
|
||||
### 是否需要修改?
|
||||
|
||||
**不一定需要修改**,取决于您的实际环境:
|
||||
|
||||
#### 保持 `disable` 的场景:
|
||||
- ✅ 开发/测试环境
|
||||
- ✅ HighGo 服务器未配置 SSL 证书
|
||||
- ✅ 内网环境,不需要加密传输
|
||||
- ✅ 快速测试连接功能
|
||||
|
||||
#### 修改为 `require` 的场景:
|
||||
- ✅ 生产环境
|
||||
- ✅ HighGo 服务器已配置 SSL 证书
|
||||
- ✅ 跨网络连接,需要加密保护
|
||||
- ✅ 符合安全合规要求
|
||||
|
||||
### 如何修改
|
||||
|
||||
如果您决定修改,可以使用以下命令:
|
||||
|
||||
**方式 1:直接修改(固定为 require)**
|
||||
```go
|
||||
// 文件:internal/db/highgo_impl.go 第 43 行
|
||||
q.Set("sslmode", "require")
|
||||
```
|
||||
|
||||
**方式 2:可配置(推荐)**
|
||||
|
||||
如果希望让用户可以选择 sslmode,可以修改为:
|
||||
|
||||
```go
|
||||
// 在 getDSN 方法中
|
||||
sslmode := "disable" // 默认值
|
||||
if config.SSLMode != "" {
|
||||
sslmode = config.SSLMode
|
||||
}
|
||||
q.Set("sslmode", sslmode)
|
||||
```
|
||||
|
||||
然后在 `internal/connection/connection.go` 的 `ConnectionConfig` 结构体中添加字段:
|
||||
|
||||
```go
|
||||
type ConnectionConfig struct {
|
||||
// ... 现有字段
|
||||
SSLMode string `json:"sslMode,omitempty"` // SSL 模式:disable, require, verify-ca, verify-full
|
||||
}
|
||||
```
|
||||
|
||||
前端 UI 也需要相应添加 sslmode 选择控件。
|
||||
|
||||
### 测试建议
|
||||
|
||||
修改后请务必测试:
|
||||
|
||||
1. **SSL 启用测试**:
|
||||
- 连接配置了 SSL 的 HighGo 服务器
|
||||
- 验证连接成功
|
||||
|
||||
2. **SSL 禁用测试**:
|
||||
- 连接未配置 SSL 的 HighGo 服务器
|
||||
- 验证是否会报错(如果设置为 `require` 会报错)
|
||||
|
||||
3. **兼容性测试**:
|
||||
- 测试现有的 HighGo 连接配置是否仍然可用
|
||||
|
||||
## 二、其他可选优化
|
||||
|
||||
### 1. 默认端口提示优化
|
||||
|
||||
**文件**:`frontend/src/components/ConnectionModal.tsx`
|
||||
|
||||
**当前状态**:HighGo 的默认端口已正确设置为 5866
|
||||
|
||||
**建议**:无需修改,已符合官方规范
|
||||
|
||||
### 2. 默认数据库名称
|
||||
|
||||
**文件**:`internal/db/highgo_impl.go:33`
|
||||
|
||||
**当前代码**:
|
||||
```go
|
||||
if dbname == "" {
|
||||
dbname = "highgo" // HighGo default database
|
||||
}
|
||||
```
|
||||
|
||||
**建议**:无需修改,已符合官方规范(默认数据库为 `highgo`)
|
||||
|
||||
### 3. 默认用户名
|
||||
|
||||
**当前状态**:未在代码中硬编码默认用户名
|
||||
|
||||
**瀚高官方默认**:`sysdba`
|
||||
|
||||
**建议**:
|
||||
- 可以在前端 UI 的 HighGo 连接表单中,将用户名输入框的 placeholder 设置为 `sysdba`
|
||||
- 但不建议硬编码默认值,让用户自行输入更安全
|
||||
|
||||
## 三、总结
|
||||
|
||||
### 必须修改的项目
|
||||
- ✅ **无**(当前代码已基本符合规范)
|
||||
|
||||
### 建议修改的项目
|
||||
1. **sslmode 配置**(根据实际环境决定)
|
||||
- 开发环境:保持 `disable`
|
||||
- 生产环境:修改为 `require`
|
||||
|
||||
### 可选优化的项目
|
||||
1. 将 sslmode 改为可配置(需要修改前后端)
|
||||
2. 前端 UI 添加 sslmode 选择控件
|
||||
3. 用户名输入框添加 `sysdba` 提示
|
||||
|
||||
## 四、修改优先级
|
||||
|
||||
**优先级 1(高)**:
|
||||
- 集成瀚高 SM3 驱动(参考 `HighGo_SM3_Integration_Guide.md`)
|
||||
|
||||
**优先级 2(中)**:
|
||||
- 根据部署环境调整 sslmode 配置
|
||||
|
||||
**优先级 3(低)**:
|
||||
- 将 sslmode 改为可配置
|
||||
- UI 优化(placeholder 提示等)
|
||||
|
||||
## 五、下一步行动
|
||||
|
||||
建议按以下顺序执行:
|
||||
|
||||
1. **先集成 SM3 驱动**(参考集成指南)
|
||||
2. **测试基本连接功能**(使用 sslmode=disable)
|
||||
3. **如果生产环境需要 SSL**,再修改 sslmode 配置
|
||||
4. **验证所有功能正常**后,考虑可选优化项
|
||||
|
||||
---
|
||||
|
||||
**注意**:所有代码修改都应该在集成 SM3 驱动并验证基本功能正常后再进行。
|
||||
196
docs/HighGo_SM3_Integration_Guide.md
Normal file
196
docs/HighGo_SM3_Integration_Guide.md
Normal file
@@ -0,0 +1,196 @@
|
||||
# HighGo SM3 国密驱动集成指南
|
||||
|
||||
## 一、背景说明
|
||||
|
||||
HighGo(瀚高)数据库需要使用支持 SM3 国密认证的 PostgreSQL 驱动。瀚高官方提供了基于 `lib/pq` 的安全增强版本。
|
||||
|
||||
## 二、集成步骤
|
||||
|
||||
### 步骤 1:下载瀚高 pq 驱动
|
||||
|
||||
1. 访问百度网盘链接:
|
||||
```
|
||||
https://pan.baidu.com/s/1xuz6uJz0utRgKWecXhpOiA?pwd=o0tj
|
||||
```
|
||||
|
||||
2. 下载驱动源码压缩包
|
||||
|
||||
### 步骤 2:放置驱动源码
|
||||
|
||||
1. 在项目根目录创建目录(如果不存在):
|
||||
```bash
|
||||
mkdir -p third_party/highgo-pq
|
||||
```
|
||||
|
||||
2. 解压下载的驱动源码到 `third_party/highgo-pq/` 目录
|
||||
|
||||
3. 确保目录结构如下:
|
||||
```
|
||||
GoNavi/
|
||||
├── third_party/
|
||||
│ └── highgo-pq/
|
||||
│ ├── go.mod
|
||||
│ ├── conn.go
|
||||
│ ├── ... (其他 pq 驱动源文件)
|
||||
```
|
||||
|
||||
### 步骤 3:修改 go.mod
|
||||
|
||||
在 `go.mod` 中添加独立的 HighGo 驱动依赖与本地替换:
|
||||
|
||||
```go
|
||||
require github.com/highgo/pq-sm3 v0.0.0
|
||||
replace github.com/highgo/pq-sm3 => ./third_party/highgo-pq
|
||||
```
|
||||
|
||||
完整示例:
|
||||
```go
|
||||
module GoNavi-Wails
|
||||
|
||||
go 1.24.3
|
||||
|
||||
require (
|
||||
// ... 现有依赖
|
||||
github.com/lib/pq v1.11.1
|
||||
github.com/highgo/pq-sm3 v0.0.0
|
||||
// ... 其他依赖
|
||||
)
|
||||
|
||||
// 在文件末尾添加
|
||||
replace github.com/highgo/pq-sm3 => ./third_party/highgo-pq
|
||||
```
|
||||
|
||||
并将 `third_party/highgo-pq/go.mod` 的 module 修改为:
|
||||
|
||||
```go
|
||||
module github.com/highgo/pq-sm3
|
||||
```
|
||||
|
||||
同时在驱动源码中把注册名改为 `highgo`,确保不覆盖 `postgres`:
|
||||
|
||||
```go
|
||||
sql.Register("highgo", &Driver{})
|
||||
```
|
||||
|
||||
### 步骤 4:更新 HighGo 连接配置(可选)
|
||||
|
||||
根据瀚高官方文档,建议修改 `internal/db/highgo_impl.go:43` 的 sslmode:
|
||||
|
||||
**当前代码**:
|
||||
```go
|
||||
q.Set("sslmode", "disable")
|
||||
```
|
||||
|
||||
**建议修改为**(瀚高默认):
|
||||
```go
|
||||
q.Set("sslmode", "require")
|
||||
```
|
||||
|
||||
> ⚠️ 注意:如果您的 HighGo 服务器未配置 SSL,保持 `disable` 即可。
|
||||
|
||||
### 步骤 5:验证集成
|
||||
|
||||
1. 清理依赖缓存:
|
||||
```bash
|
||||
go clean -modcache
|
||||
```
|
||||
|
||||
2. 重新下载依赖:
|
||||
```bash
|
||||
go mod download
|
||||
```
|
||||
|
||||
3. 编译项目:
|
||||
```bash
|
||||
go build ./...
|
||||
```
|
||||
|
||||
4. 测试 HighGo 连接:
|
||||
- 启动应用
|
||||
- 创建 HighGo 连接
|
||||
- 测试连接是否成功
|
||||
|
||||
## 三、重要说明
|
||||
|
||||
### ⚠️ 影响范围
|
||||
|
||||
采用独立驱动名后,影响范围如下:
|
||||
|
||||
1. **PostgreSQL 继续使用原生 `github.com/lib/pq`**
|
||||
2. **HighGo 使用 `github.com/highgo/pq-sm3`(本地替换到官方源码)**
|
||||
3. 两条连接链路互不覆盖,降低兼容性风险
|
||||
|
||||
### 兼容性验证
|
||||
|
||||
集成后,请务必测试:
|
||||
|
||||
1. ✅ HighGo 数据库连接(SM3 认证)
|
||||
2. ✅ 标准 PostgreSQL 连接(确保仍然可用)
|
||||
|
||||
若 PostgreSQL 或 HighGo 任一连接异常,优先检查驱动注册名与 `go.mod` replace 是否一致。
|
||||
|
||||
### 回滚方案
|
||||
|
||||
如果集成后出现问题,可以快速回滚:
|
||||
|
||||
1. 删除 `go.mod` 中的 replace 指令
|
||||
2. 删除 `go.mod` 中 `github.com/highgo/pq-sm3` 的 require
|
||||
3. 删除 `third_party/highgo-pq/` 目录
|
||||
4. 运行 `go mod tidy`
|
||||
5. 重新编译
|
||||
|
||||
## 四、瀚高驱动特性
|
||||
|
||||
根据官方文档:
|
||||
|
||||
- **项目内包路径**:`github.com/highgo/pq-sm3`(映射到本地 `third_party/highgo-pq`)
|
||||
- **驱动名**:`highgo`(项目内独立注册,避免覆盖 `postgres`)
|
||||
- **SM3 支持**:自动启用国密认证
|
||||
- **默认端口**:5866
|
||||
- **默认数据库**:`highgo`
|
||||
- **默认用户**:`sysdba`
|
||||
- **sslmode 默认**:`require`
|
||||
|
||||
## 五、故障排查
|
||||
|
||||
### 问题 1:编译失败
|
||||
|
||||
**现象**:`go build` 报错找不到 `github.com/highgo/pq-sm3`
|
||||
|
||||
**解决**:
|
||||
1. 检查 `third_party/highgo-pq/` 目录是否存在
|
||||
2. 检查 `go.mod` 中 `github.com/highgo/pq-sm3` 的 require/replace 是否正确
|
||||
3. 运行 `go mod download`
|
||||
|
||||
### 问题 2:HighGo 连接失败
|
||||
|
||||
**现象**:连接 HighGo 时报认证错误
|
||||
|
||||
**解决**:
|
||||
1. 确认瀚高驱动已正确替换(检查 `go.mod`)
|
||||
2. 确认项目内驱动注册名为 `highgo`
|
||||
3. 确认 HighGo 服务器支持 SM3 认证
|
||||
4. 检查用户名、密码、端口是否正确
|
||||
|
||||
### 问题 3:PostgreSQL 连接失败
|
||||
|
||||
**现象**:集成后标准 PostgreSQL 无法连接
|
||||
|
||||
**解决**:
|
||||
1. 检查是否误将 `github.com/lib/pq` 全局 replace 到 HighGo 驱动
|
||||
2. 确认 PostgreSQL 仍使用 `sql.Open("postgres", dsn)`
|
||||
3. 确认 HighGo 使用 `sql.Open("highgo", dsn)`
|
||||
|
||||
## 六、后续优化建议
|
||||
|
||||
如果后续需要增强,可考虑:
|
||||
|
||||
1. 将 HighGo `sslmode` 做成可配置项(前后端联动)
|
||||
2. 增加 HighGo/PG 驱动链路健康检查项
|
||||
3. 联系瀚高技术支持确认 SM3 + SSL 最佳参数组合
|
||||
|
||||
## 七、参考资料
|
||||
|
||||
- 瀚高官方文档:https://www.highgo.com/document/zh-cn/application/pq%E6%8E%A5%E5%8F%A3.html
|
||||
- 瀚高驱动下载:https://pan.baidu.com/s/1xuz6uJz0utRgKWecXhpOiA?pwd=o0tj
|
||||
- 标准 lib/pq:https://github.com/lib/pq
|
||||
1
frontend/.gitignore
vendored
Normal file
1
frontend/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
.ace-tool/
|
||||
@@ -63,8 +63,28 @@ body {
|
||||
}
|
||||
|
||||
body[data-theme='dark'] {
|
||||
/* Improve contrast on transparent backgrounds */
|
||||
text-shadow: 0 1px 2px rgba(0, 0, 0, 0.8);
|
||||
/* 移除全局 text-shadow:对每个文本元素增加 GPU compositing 成本,
|
||||
在透明窗口环境下会显著加剧 GPU 负载 */
|
||||
}
|
||||
|
||||
/* 连接配置弹窗:滚动仅在弹窗 body 内部,不使用外层 wrap 滚动条 */
|
||||
.connection-modal-wrap {
|
||||
overflow: hidden !important;
|
||||
}
|
||||
|
||||
.connection-modal-wrap .ant-modal-content {
|
||||
max-height: calc(100vh - 72px);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.connection-modal-wrap .ant-modal-body {
|
||||
flex: 1 1 auto;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
.connection-modal-wrap .ant-modal-footer {
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
/* Custom Title Bar Close Button Hover */
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { Layout, Button, ConfigProvider, theme, Dropdown, MenuProps, message, Modal, Spin, Slider, Popover } from 'antd';
|
||||
import { Layout, Button, ConfigProvider, theme, Dropdown, MenuProps, message, Modal, Spin, Slider, Progress } from 'antd';
|
||||
import zhCN from 'antd/locale/zh_CN';
|
||||
import { PlusOutlined, BulbOutlined, BulbFilled, ConsoleSqlOutlined, UploadOutlined, DownloadOutlined, CloudDownloadOutlined, BugOutlined, ToolOutlined, InfoCircleOutlined, GithubOutlined, SkinOutlined, CheckOutlined, MinusOutlined, BorderOutlined, CloseOutlined, SettingOutlined } from '@ant-design/icons';
|
||||
import { Environment, EventsOn } from '../wailsjs/runtime/runtime';
|
||||
import Sidebar from './components/Sidebar';
|
||||
import TabManager from './components/TabManager';
|
||||
import ConnectionModal from './components/ConnectionModal';
|
||||
@@ -9,7 +10,8 @@ import DataSyncModal from './components/DataSyncModal';
|
||||
import LogPanel from './components/LogPanel';
|
||||
import { useStore } from './store';
|
||||
import { SavedConnection } from './types';
|
||||
import { blurToFilter, normalizeBlurForPlatform, normalizeOpacityForPlatform } from './utils/appearance';
|
||||
import { blurToFilter, normalizeBlurForPlatform, normalizeOpacityForPlatform, isWindowsPlatform } from './utils/appearance';
|
||||
import { SetWindowTranslucency } from '../wailsjs/go/app/App';
|
||||
import './App.css';
|
||||
|
||||
const { Sider, Content } = Layout;
|
||||
@@ -27,6 +29,30 @@ function App() {
|
||||
const effectiveBlur = normalizeBlurForPlatform(appearance.blur);
|
||||
const blurFilter = blurToFilter(effectiveBlur);
|
||||
const windowCornerRadius = 14;
|
||||
const [isLinuxRuntime, setIsLinuxRuntime] = useState(false);
|
||||
|
||||
// 同步 macOS 窗口透明度:opacity=1.0 且 blur=0 时关闭 NSVisualEffectView,
|
||||
// 避免 GPU 持续计算窗口背后的模糊合成
|
||||
useEffect(() => {
|
||||
SetWindowTranslucency(appearance.opacity, appearance.blur).catch(() => {});
|
||||
}, [appearance.opacity, appearance.blur]);
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
Environment()
|
||||
.then((env) => {
|
||||
if (cancelled) return;
|
||||
setIsLinuxRuntime((env?.platform || '').toLowerCase() === 'linux');
|
||||
})
|
||||
.catch(() => {
|
||||
if (cancelled) return;
|
||||
const platform = typeof navigator !== 'undefined' ? navigator.platform : '';
|
||||
setIsLinuxRuntime(/linux/i.test(platform));
|
||||
});
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Background Helper
|
||||
const getBg = (darkHex: string, lightHex: string) => {
|
||||
@@ -52,6 +78,7 @@ function App() {
|
||||
const updateCheckInFlightRef = React.useRef(false);
|
||||
const updateDownloadInFlightRef = React.useRef(false);
|
||||
const updateDownloadedVersionRef = React.useRef<string | null>(null);
|
||||
const updateDownloadMetaRef = React.useRef<UpdateDownloadResultData | null>(null);
|
||||
const updateDeferredVersionRef = React.useRef<string | null>(null);
|
||||
const updateNotifiedVersionRef = React.useRef<string | null>(null);
|
||||
const updateMutedVersionRef = React.useRef<string | null>(null);
|
||||
@@ -60,6 +87,23 @@ function App() {
|
||||
const [aboutInfo, setAboutInfo] = useState<{ version: string; author: string; buildTime?: string; repoUrl?: string; issueUrl?: string; releaseUrl?: string } | null>(null);
|
||||
const [aboutUpdateStatus, setAboutUpdateStatus] = useState<string>('');
|
||||
const [lastUpdateInfo, setLastUpdateInfo] = useState<UpdateInfo | null>(null);
|
||||
const [updateDownloadProgress, setUpdateDownloadProgress] = useState<{
|
||||
open: boolean;
|
||||
version: string;
|
||||
status: 'idle' | 'start' | 'downloading' | 'done' | 'error';
|
||||
percent: number;
|
||||
downloaded: number;
|
||||
total: number;
|
||||
message: string;
|
||||
}>({
|
||||
open: false,
|
||||
version: '',
|
||||
status: 'idle',
|
||||
percent: 0,
|
||||
downloaded: 0,
|
||||
total: 0,
|
||||
message: ''
|
||||
});
|
||||
|
||||
type UpdateInfo = {
|
||||
hasUpdate: boolean;
|
||||
@@ -73,10 +117,51 @@ function App() {
|
||||
sha256?: string;
|
||||
};
|
||||
|
||||
const promptRestartForUpdate = (info: UpdateInfo) => {
|
||||
type UpdateDownloadProgressEvent = {
|
||||
status?: 'start' | 'downloading' | 'done' | 'error';
|
||||
percent?: number;
|
||||
downloaded?: number;
|
||||
total?: number;
|
||||
message?: string;
|
||||
};
|
||||
|
||||
type UpdateDownloadResultData = {
|
||||
info?: UpdateInfo;
|
||||
downloadPath?: string;
|
||||
installLogPath?: string;
|
||||
installTarget?: string;
|
||||
platform?: string;
|
||||
autoRelaunch?: boolean;
|
||||
};
|
||||
|
||||
const formatBytes = (bytes?: number) => {
|
||||
if (!bytes || bytes <= 0) return '0 B';
|
||||
const units = ['B', 'KB', 'MB', 'GB', 'TB'];
|
||||
let value = bytes;
|
||||
let idx = 0;
|
||||
while (value >= 1024 && idx < units.length - 1) {
|
||||
value /= 1024;
|
||||
idx++;
|
||||
}
|
||||
return `${value.toFixed(idx === 0 ? 0 : 1)} ${units[idx]}`;
|
||||
};
|
||||
|
||||
const promptRestartForUpdate = (info: UpdateInfo, resultData?: UpdateDownloadResultData) => {
|
||||
const downloadPathHint = resultData?.downloadPath
|
||||
? `更新包路径:${resultData.downloadPath}`
|
||||
: '';
|
||||
const installLogHint = resultData?.installLogPath
|
||||
? `安装日志:${resultData.installLogPath}`
|
||||
: '';
|
||||
Modal.confirm({
|
||||
title: '更新已下载',
|
||||
content: `版本 ${info.latestVersion} 已下载完成,是否现在重启完成更新?`,
|
||||
content: (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 6, userSelect: 'text' }}>
|
||||
<div>{`版本 ${info.latestVersion} 已下载完成,是否现在重启完成更新?`}</div>
|
||||
{downloadPathHint ? <div style={{ fontSize: 12, color: '#8c8c8c' }}>{downloadPathHint}</div> : null}
|
||||
{installLogHint ? <div style={{ fontSize: 12, color: '#8c8c8c' }}>{installLogHint}</div> : null}
|
||||
</div>
|
||||
),
|
||||
okText: '立即重启',
|
||||
cancelText: '稍后',
|
||||
onOk: async () => {
|
||||
@@ -96,25 +181,49 @@ function App() {
|
||||
if (updateDownloadInFlightRef.current) return;
|
||||
if (updateDownloadedVersionRef.current === info.latestVersion) {
|
||||
if (!silent) {
|
||||
message.info(`更新包已就绪(${info.latestVersion})`);
|
||||
const cachedDownloadPath = updateDownloadMetaRef.current?.downloadPath;
|
||||
message.info(cachedDownloadPath ? `更新包已就绪(${info.latestVersion}),路径:${cachedDownloadPath}` : `更新包已就绪(${info.latestVersion})`);
|
||||
}
|
||||
if (!silent || updateDeferredVersionRef.current !== info.latestVersion) {
|
||||
promptRestartForUpdate(info);
|
||||
promptRestartForUpdate(info, updateDownloadMetaRef.current || undefined);
|
||||
}
|
||||
return;
|
||||
}
|
||||
updateDownloadInFlightRef.current = true;
|
||||
updateDownloadMetaRef.current = null;
|
||||
const key = 'update-download';
|
||||
setUpdateDownloadProgress({
|
||||
open: true,
|
||||
version: info.latestVersion,
|
||||
status: 'start',
|
||||
percent: 0,
|
||||
downloaded: 0,
|
||||
total: info.assetSize || 0,
|
||||
message: ''
|
||||
});
|
||||
message.loading({ content: `正在下载更新 ${info.latestVersion}...`, key, duration: 0 });
|
||||
const res = await (window as any).go.app.App.DownloadUpdate();
|
||||
updateDownloadInFlightRef.current = false;
|
||||
if (res?.success) {
|
||||
const resultData = (res?.data || {}) as UpdateDownloadResultData;
|
||||
updateDownloadMetaRef.current = resultData;
|
||||
updateDownloadedVersionRef.current = info.latestVersion;
|
||||
message.success({ content: '更新下载完成', key, duration: 2 });
|
||||
setUpdateDownloadProgress(prev => ({ ...prev, status: 'done', percent: 100, open: false }));
|
||||
if (resultData?.downloadPath) {
|
||||
message.success({ content: `更新下载完成,更新包路径:${resultData.downloadPath}`, key, duration: 5 });
|
||||
} else {
|
||||
message.success({ content: '更新下载完成', key, duration: 2 });
|
||||
}
|
||||
setAboutUpdateStatus(`发现新版本 ${info.latestVersion}(已下载,待重启安装)`);
|
||||
if (!silent || updateDeferredVersionRef.current !== info.latestVersion) {
|
||||
promptRestartForUpdate(info);
|
||||
promptRestartForUpdate(info, resultData);
|
||||
}
|
||||
} else {
|
||||
setUpdateDownloadProgress(prev => ({
|
||||
...prev,
|
||||
status: 'error',
|
||||
message: res?.message || '未知错误'
|
||||
}));
|
||||
message.error({ content: '更新下载失败: ' + (res?.message || '未知错误'), key, duration: 4 });
|
||||
}
|
||||
}, []);
|
||||
@@ -277,8 +386,13 @@ function App() {
|
||||
const [isAppearanceModalOpen, setIsAppearanceModalOpen] = useState(false);
|
||||
|
||||
|
||||
// Log Panel
|
||||
const [logPanelHeight, setLogPanelHeight] = useState(200);
|
||||
// Log Panel: 最小高度按“工具栏 + 1 条日志行(微增)”限制
|
||||
const LOG_PANEL_TOOLBAR_HEIGHT = 32;
|
||||
const LOG_PANEL_SINGLE_ROW_HEIGHT = 39;
|
||||
const LOG_PANEL_MIN_VISIBLE_ROWS = 1;
|
||||
const LOG_PANEL_MIN_HEIGHT = LOG_PANEL_TOOLBAR_HEIGHT + (LOG_PANEL_SINGLE_ROW_HEIGHT * LOG_PANEL_MIN_VISIBLE_ROWS);
|
||||
const LOG_PANEL_MAX_HEIGHT = 800;
|
||||
const [logPanelHeight, setLogPanelHeight] = useState(Math.max(200, LOG_PANEL_MIN_HEIGHT));
|
||||
const [isLogPanelOpen, setIsLogPanelOpen] = useState(false);
|
||||
const logResizeRef = React.useRef<{ startY: number, startHeight: number } | null>(null);
|
||||
const logGhostRef = React.useRef<HTMLDivElement>(null);
|
||||
@@ -307,7 +421,10 @@ function App() {
|
||||
const handleLogResizeUp = (e: MouseEvent) => {
|
||||
if (logResizeRef.current) {
|
||||
const delta = logResizeRef.current.startY - e.clientY;
|
||||
const newHeight = Math.max(100, Math.min(800, logResizeRef.current.startHeight + delta));
|
||||
const newHeight = Math.max(
|
||||
LOG_PANEL_MIN_HEIGHT,
|
||||
Math.min(LOG_PANEL_MAX_HEIGHT, logResizeRef.current.startHeight + delta)
|
||||
);
|
||||
setLogPanelHeight(newHeight);
|
||||
}
|
||||
|
||||
@@ -329,6 +446,14 @@ function App() {
|
||||
setIsModalOpen(false);
|
||||
setEditingConnection(null);
|
||||
};
|
||||
|
||||
const handleTitleBarDoubleClick = (e: React.MouseEvent<HTMLDivElement>) => {
|
||||
const target = e.target as HTMLElement | null;
|
||||
if (target?.closest('[data-no-titlebar-toggle="true"]')) {
|
||||
return;
|
||||
}
|
||||
(window as any).runtime.WindowToggleMaximise();
|
||||
};
|
||||
|
||||
// Sidebar Resizing
|
||||
const [sidebarWidth, setSidebarWidth] = useState(300);
|
||||
@@ -422,6 +547,46 @@ function App() {
|
||||
};
|
||||
}, [checkForUpdates]);
|
||||
|
||||
useEffect(() => {
|
||||
const offDownloadProgress = EventsOn('update:download-progress', (event: UpdateDownloadProgressEvent) => {
|
||||
if (!event) return;
|
||||
const status = event.status || 'downloading';
|
||||
const nextStatus: 'idle' | 'start' | 'downloading' | 'done' | 'error' =
|
||||
status === 'start' || status === 'downloading' || status === 'done' || status === 'error'
|
||||
? status
|
||||
: 'downloading';
|
||||
const downloaded = typeof event.downloaded === 'number' ? event.downloaded : 0;
|
||||
const total = typeof event.total === 'number' ? event.total : 0;
|
||||
const percentRaw = typeof event.percent === 'number'
|
||||
? event.percent
|
||||
: (total > 0 ? (downloaded / total) * 100 : 0);
|
||||
const percent = Math.max(0, Math.min(100, percentRaw));
|
||||
setUpdateDownloadProgress(prev => ({
|
||||
open: nextStatus === 'start' || nextStatus === 'downloading' || nextStatus === 'error',
|
||||
version: prev.version,
|
||||
status: nextStatus,
|
||||
percent,
|
||||
downloaded,
|
||||
total,
|
||||
message: String(event.message || '')
|
||||
}));
|
||||
});
|
||||
return () => {
|
||||
offDownloadProgress();
|
||||
};
|
||||
}, []);
|
||||
|
||||
const linuxResizeHandleStyleBase = {
|
||||
position: 'fixed',
|
||||
zIndex: 12000,
|
||||
background: 'transparent',
|
||||
WebkitAppRegion: 'drag',
|
||||
'--wails-draggable': 'drag',
|
||||
userSelect: 'none'
|
||||
} as any;
|
||||
|
||||
const showLinuxResizeHandles = isLinuxRuntime;
|
||||
|
||||
return (
|
||||
<ConfigProvider
|
||||
locale={zhCN}
|
||||
@@ -465,13 +630,14 @@ function App() {
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
background: 'transparent',
|
||||
borderRadius: windowCornerRadius,
|
||||
clipPath: `inset(0 round ${windowCornerRadius}px)`,
|
||||
borderRadius: showLinuxResizeHandles ? 0 : windowCornerRadius,
|
||||
clipPath: showLinuxResizeHandles ? 'none' : `inset(0 round ${windowCornerRadius}px)`,
|
||||
backdropFilter: blurFilter,
|
||||
WebkitBackdropFilter: blurFilter,
|
||||
}}>
|
||||
{/* Custom Title Bar */}
|
||||
<div
|
||||
onDoubleClick={handleTitleBarDoubleClick}
|
||||
style={{
|
||||
height: 32,
|
||||
flexShrink: 0,
|
||||
@@ -479,8 +645,6 @@ function App() {
|
||||
alignItems: 'center',
|
||||
justifyContent: 'space-between',
|
||||
background: bgMain,
|
||||
backdropFilter: blurFilter,
|
||||
WebkitBackdropFilter: blurFilter,
|
||||
borderBottom: 'none',
|
||||
userSelect: 'none',
|
||||
WebkitAppRegion: 'drag', // Wails drag region
|
||||
@@ -492,7 +656,11 @@ function App() {
|
||||
{/* Logo can be added here if available */}
|
||||
GoNavi
|
||||
</div>
|
||||
<div style={{ display: 'flex', height: '100%', WebkitAppRegion: 'no-drag', '--wails-draggable': 'no-drag' } as any}>
|
||||
<div
|
||||
data-no-titlebar-toggle="true"
|
||||
onDoubleClick={(e) => e.stopPropagation()}
|
||||
style={{ display: 'flex', height: '100%', WebkitAppRegion: 'no-drag', '--wails-draggable': 'no-drag' } as any}
|
||||
>
|
||||
<Button
|
||||
type="text"
|
||||
icon={<MinusOutlined />}
|
||||
@@ -527,8 +695,6 @@ function App() {
|
||||
padding: '0 8px',
|
||||
borderBottom: 'none',
|
||||
background: bgMain,
|
||||
backdropFilter: blurFilter,
|
||||
WebkitBackdropFilter: blurFilter,
|
||||
}}
|
||||
>
|
||||
<Dropdown menu={{ items: toolsMenu }} placement="bottomLeft">
|
||||
@@ -543,7 +709,7 @@ function App() {
|
||||
<Sider
|
||||
width={sidebarWidth}
|
||||
style={{
|
||||
borderRight: 'none',
|
||||
borderRight: '1px solid rgba(128,128,128,0.2)',
|
||||
position: 'relative',
|
||||
background: bgMain
|
||||
}}
|
||||
@@ -591,7 +757,7 @@ function App() {
|
||||
/>
|
||||
</Sider>
|
||||
<Content style={{ background: 'transparent', overflow: 'hidden', display: 'flex', flexDirection: 'column' }}>
|
||||
<div style={{ flex: 1, minHeight: 0, overflow: 'hidden', display: 'flex', flexDirection: 'column', background: bgContent, backdropFilter: blurFilter, WebkitBackdropFilter: blurFilter }}>
|
||||
<div style={{ flex: 1, minHeight: 0, overflow: 'hidden', display: 'flex', flexDirection: 'column', background: bgContent }}>
|
||||
<TabManager />
|
||||
</div>
|
||||
{isLogPanelOpen && (
|
||||
@@ -679,31 +845,104 @@ function App() {
|
||||
min={0.1}
|
||||
max={1.0}
|
||||
step={0.05}
|
||||
value={appearance.opacity ?? 0.95}
|
||||
value={appearance.opacity ?? 1.0}
|
||||
onChange={(v) => setAppearance({ opacity: v })}
|
||||
style={{ flex: 1 }}
|
||||
/>
|
||||
<span style={{ width: 40 }}>{Math.round((appearance.opacity ?? 0.95) * 100)}%</span>
|
||||
<span style={{ width: 40 }}>{Math.round((appearance.opacity ?? 1.0) * 100)}%</span>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<div style={{ marginBottom: 8, fontWeight: 500 }}>高斯模糊 (Blur)</div>
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 16 }}>
|
||||
<Slider
|
||||
min={0}
|
||||
max={20}
|
||||
value={appearance.blur ?? 0}
|
||||
onChange={(v) => setAppearance({ blur: v })}
|
||||
style={{ flex: 1 }}
|
||||
/>
|
||||
<span style={{ width: 40 }}>{appearance.blur}px</span>
|
||||
</div>
|
||||
<div style={{ fontSize: 12, color: '#888', marginTop: 4 }}>
|
||||
* 仅控制应用内覆盖层的模糊效果
|
||||
</div>
|
||||
{isWindowsPlatform() ? (
|
||||
<div style={{ fontSize: 12, color: '#888' }}>
|
||||
Windows 使用系统 Acrylic 效果,模糊程度由系统控制
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 16 }}>
|
||||
<Slider
|
||||
min={0}
|
||||
max={20}
|
||||
value={appearance.blur ?? 0}
|
||||
onChange={(v) => setAppearance({ blur: v })}
|
||||
style={{ flex: 1 }}
|
||||
/>
|
||||
<span style={{ width: 40 }}>{appearance.blur}px</span>
|
||||
</div>
|
||||
<div style={{ fontSize: 12, color: '#888', marginTop: 4 }}>
|
||||
* 仅控制应用内覆盖层的模糊效果
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
|
||||
<Modal
|
||||
title={updateDownloadProgress.version ? `下载更新 ${updateDownloadProgress.version}` : '下载更新'}
|
||||
open={updateDownloadProgress.open}
|
||||
closable={updateDownloadProgress.status === 'error'}
|
||||
maskClosable={false}
|
||||
keyboard={updateDownloadProgress.status === 'error'}
|
||||
onCancel={() => {
|
||||
if (updateDownloadProgress.status === 'error') {
|
||||
setUpdateDownloadProgress({
|
||||
open: false,
|
||||
version: '',
|
||||
status: 'idle',
|
||||
percent: 0,
|
||||
downloaded: 0,
|
||||
total: 0,
|
||||
message: ''
|
||||
});
|
||||
}
|
||||
}}
|
||||
footer={updateDownloadProgress.status === 'error' ? [
|
||||
<Button
|
||||
key="close"
|
||||
onClick={() => setUpdateDownloadProgress({
|
||||
open: false,
|
||||
version: '',
|
||||
status: 'idle',
|
||||
percent: 0,
|
||||
downloaded: 0,
|
||||
total: 0,
|
||||
message: ''
|
||||
})}
|
||||
>
|
||||
关闭
|
||||
</Button>
|
||||
] : null}
|
||||
>
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 12 }}>
|
||||
<Progress
|
||||
percent={Math.round(updateDownloadProgress.percent)}
|
||||
status={updateDownloadProgress.status === 'error' ? 'exception' : (updateDownloadProgress.status === 'done' ? 'success' : 'active')}
|
||||
/>
|
||||
<div style={{ fontSize: 12, color: '#8c8c8c' }}>
|
||||
{`${formatBytes(updateDownloadProgress.downloaded)} / ${formatBytes(updateDownloadProgress.total)}`}
|
||||
</div>
|
||||
{updateDownloadProgress.message ? (
|
||||
<div style={{ fontSize: 12, color: '#ff4d4f' }}>{updateDownloadProgress.message}</div>
|
||||
) : null}
|
||||
</div>
|
||||
</Modal>
|
||||
|
||||
{showLinuxResizeHandles && (
|
||||
<>
|
||||
{/* Linux Mint 下 frameless 仅局部可缩放:补四边四角命中层 */}
|
||||
<div style={{ ...linuxResizeHandleStyleBase, top: 0, left: 14, right: 14, height: 6, cursor: 'ns-resize' }} />
|
||||
<div style={{ ...linuxResizeHandleStyleBase, bottom: 0, left: 14, right: 14, height: 6, cursor: 'ns-resize' }} />
|
||||
<div style={{ ...linuxResizeHandleStyleBase, top: 14, bottom: 14, left: 0, width: 6, cursor: 'ew-resize' }} />
|
||||
<div style={{ ...linuxResizeHandleStyleBase, top: 14, bottom: 14, right: 0, width: 6, cursor: 'ew-resize' }} />
|
||||
|
||||
<div style={{ ...linuxResizeHandleStyleBase, top: 0, left: 0, width: 14, height: 14, cursor: 'nwse-resize' }} />
|
||||
<div style={{ ...linuxResizeHandleStyleBase, top: 0, right: 0, width: 14, height: 14, cursor: 'nesw-resize' }} />
|
||||
<div style={{ ...linuxResizeHandleStyleBase, bottom: 0, left: 0, width: 14, height: 14, cursor: 'nesw-resize' }} />
|
||||
<div style={{ ...linuxResizeHandleStyleBase, bottom: 0, right: 0, width: 14, height: 14, cursor: 'nwse-resize' }} />
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Ghost Resize Line for Sidebar */}
|
||||
<div
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,7 @@ import { TabData, ColumnDefinition } from '../types';
|
||||
import { useStore } from '../store';
|
||||
import { DBQuery, DBGetColumns } from '../../wailsjs/go/app/App';
|
||||
import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
|
||||
import { buildWhereSQL, quoteIdentPart, quoteQualifiedIdent } from '../utils/sql';
|
||||
import { buildOrderBySQL, buildWhereSQL, quoteQualifiedIdent, type FilterCondition } from '../utils/sql';
|
||||
|
||||
const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
const [data, setData] = useState<any[]>([]);
|
||||
@@ -29,7 +29,9 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
const [sortInfo, setSortInfo] = useState<{ columnKey: string, order: string } | null>(null);
|
||||
|
||||
const [showFilter, setShowFilter] = useState(false);
|
||||
const [filterConditions, setFilterConditions] = useState<any[]>([]);
|
||||
const [filterConditions, setFilterConditions] = useState<FilterCondition[]>([]);
|
||||
const currentConnType = (connections.find(c => c.id === tab.connectionId)?.config?.type || '').toLowerCase();
|
||||
const forceReadOnly = currentConnType === 'tdengine';
|
||||
|
||||
useEffect(() => {
|
||||
setPkColumns([]);
|
||||
@@ -67,9 +69,7 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
const countSql = `SELECT COUNT(*) as total FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
|
||||
|
||||
let sql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
|
||||
if (sortInfo && sortInfo.order) {
|
||||
sql += ` ORDER BY ${quoteIdentPart(dbType, sortInfo.columnKey)} ${sortInfo.order === 'ascend' ? 'ASC' : 'DESC'}`;
|
||||
}
|
||||
sql += buildOrderBySQL(dbType, sortInfo, pkColumns);
|
||||
const offset = (page - 1) * size;
|
||||
// 大表性能:打开表不阻塞在 COUNT(*),先通过多取 1 条判断是否还有下一页;总数在后台统计并异步回填。
|
||||
sql += ` LIMIT ${size + 1} OFFSET ${offset}`;
|
||||
@@ -203,13 +203,9 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
});
|
||||
}
|
||||
if (fetchSeqRef.current === seq) setLoading(false);
|
||||
}, [connections, tab, sortInfo, filterConditions, pkColumns.length]);
|
||||
// Depend on pkColumns.length to avoid loop? No, pkColumns is updated inside.
|
||||
// Actually, 'pkColumns' state shouldn't trigger re-fetch.
|
||||
// The 'if (pkColumns.length === 0)' check is inside.
|
||||
// So adding pkColumns to dependency is safer but might trigger double fetch if not careful?
|
||||
// Only if pkColumns changes. It changes once from [] to [...].
|
||||
// So it's fine.
|
||||
}, [connections, tab, sortInfo, filterConditions, pkColumns]);
|
||||
// 依赖 pkColumns:在无手动排序时可回退到主键稳定排序。
|
||||
// 主键信息只会在首次加载后更新一次,避免循环查询。
|
||||
|
||||
// Handlers memoized
|
||||
const handleReload = useCallback(() => {
|
||||
@@ -218,7 +214,7 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
const handleSort = useCallback((field: string, order: string) => setSortInfo({ columnKey: field, order }), []);
|
||||
const handlePageChange = useCallback((page: number, size: number) => fetchData(page, size), [fetchData]);
|
||||
const handleToggleFilter = useCallback(() => setShowFilter(prev => !prev), []);
|
||||
const handleApplyFilter = useCallback((conditions: any[]) => setFilterConditions(conditions), []);
|
||||
const handleApplyFilter = useCallback((conditions: FilterCondition[]) => setFilterConditions(conditions), []);
|
||||
|
||||
useEffect(() => {
|
||||
fetchData(1, pagination.pageSize);
|
||||
@@ -241,6 +237,7 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
showFilter={showFilter}
|
||||
onToggleFilter={handleToggleFilter}
|
||||
onApplyFilter={handleApplyFilter}
|
||||
readOnly={forceReadOnly}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
393
frontend/src/components/DefinitionViewer.tsx
Normal file
393
frontend/src/components/DefinitionViewer.tsx
Normal file
@@ -0,0 +1,393 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import Editor from '@monaco-editor/react';
|
||||
import { Spin, Alert } from 'antd';
|
||||
import { TabData } from '../types';
|
||||
import { useStore } from '../store';
|
||||
import { DBQuery } from '../../wailsjs/go/app/App';
|
||||
|
||||
interface DefinitionViewerProps {
|
||||
tab: TabData;
|
||||
}
|
||||
|
||||
const DefinitionViewer: React.FC<DefinitionViewerProps> = ({ tab }) => {
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [definition, setDefinition] = useState<string>('');
|
||||
|
||||
const connections = useStore(state => state.connections);
|
||||
const theme = useStore(state => state.theme);
|
||||
const darkMode = theme === 'dark';
|
||||
|
||||
const escapeSQLLiteral = (raw: string): string => String(raw || '').replace(/'/g, "''");
|
||||
|
||||
const getMetadataDialect = (conn: any): string => {
|
||||
const type = String(conn?.config?.type || '').trim().toLowerCase();
|
||||
if (type === 'custom') {
|
||||
return String(conn?.config?.driver || '').trim().toLowerCase();
|
||||
}
|
||||
if (type === 'mariadb' || type === 'sphinx') return 'mysql';
|
||||
if (type === 'dameng') return 'dm';
|
||||
return type;
|
||||
};
|
||||
|
||||
const isSphinxConnection = (conn: any): boolean => {
|
||||
const type = String(conn?.config?.type || '').trim().toLowerCase();
|
||||
if (type === 'sphinx') return true;
|
||||
if (type !== 'custom') return false;
|
||||
const driver = String(conn?.config?.driver || '').trim().toLowerCase();
|
||||
return driver === 'sphinx' || driver === 'sphinxql';
|
||||
};
|
||||
|
||||
const parseSchemaAndName = (fullName: string): { schema: string; name: string } => {
|
||||
const raw = String(fullName || '').trim();
|
||||
const idx = raw.lastIndexOf('.');
|
||||
if (idx > 0 && idx < raw.length - 1) {
|
||||
return { schema: raw.substring(0, idx), name: raw.substring(idx + 1) };
|
||||
}
|
||||
return { schema: '', name: raw };
|
||||
};
|
||||
|
||||
const buildShowViewQueries = (dialect: string, viewName: string, dbName: string): string[] => {
|
||||
const { schema, name } = parseSchemaAndName(viewName);
|
||||
const safeName = escapeSQLLiteral(name);
|
||||
const safeDbName = escapeSQLLiteral(dbName);
|
||||
|
||||
switch (dialect) {
|
||||
case 'mysql':
|
||||
return [
|
||||
`SHOW CREATE VIEW \`${name.replace(/`/g, '``')}\``,
|
||||
safeDbName
|
||||
? `SELECT VIEW_DEFINITION AS view_definition FROM information_schema.views WHERE table_schema = '${safeDbName}' AND table_name = '${safeName}' LIMIT 1`
|
||||
: '',
|
||||
`SHOW CREATE TABLE \`${name.replace(/`/g, '``')}\``,
|
||||
].filter(Boolean);
|
||||
case 'postgres':
|
||||
case 'kingbase':
|
||||
case 'highgo':
|
||||
case 'vastbase': {
|
||||
const schemaRef = schema || 'public';
|
||||
return [`SELECT pg_get_viewdef('${escapeSQLLiteral(schemaRef)}.${safeName}'::regclass, true) AS view_definition`];
|
||||
}
|
||||
case 'sqlserver':
|
||||
return [`SELECT OBJECT_DEFINITION(OBJECT_ID('${escapeSQLLiteral(viewName)}')) AS view_definition`];
|
||||
case 'oracle':
|
||||
case 'dm':
|
||||
if (schema) {
|
||||
return [`SELECT TEXT AS view_definition FROM ALL_VIEWS WHERE OWNER = '${escapeSQLLiteral(schema).toUpperCase()}' AND VIEW_NAME = '${safeName.toUpperCase()}'`];
|
||||
}
|
||||
if (safeDbName) {
|
||||
return [`SELECT TEXT AS view_definition FROM ALL_VIEWS WHERE OWNER = '${safeDbName.toUpperCase()}' AND VIEW_NAME = '${safeName.toUpperCase()}'`];
|
||||
}
|
||||
return [`SELECT TEXT AS view_definition FROM USER_VIEWS WHERE VIEW_NAME = '${safeName.toUpperCase()}'`];
|
||||
case 'sqlite':
|
||||
return [`SELECT sql AS view_definition FROM sqlite_master WHERE type='view' AND name='${safeName}'`];
|
||||
default:
|
||||
return [`-- 暂不支持该数据库类型的视图定义查看`];
|
||||
}
|
||||
};
|
||||
|
||||
const buildShowRoutineQueries = (dialect: string, routineName: string, routineType: string, dbName: string): string[] => {
|
||||
const { schema, name } = parseSchemaAndName(routineName);
|
||||
const safeName = escapeSQLLiteral(name);
|
||||
const safeDbName = escapeSQLLiteral(dbName);
|
||||
const upperType = (routineType || 'FUNCTION').toUpperCase();
|
||||
|
||||
switch (dialect) {
|
||||
case 'mysql':
|
||||
return [
|
||||
`SHOW CREATE ${upperType} \`${name.replace(/`/g, '``')}\``,
|
||||
safeDbName
|
||||
? `SELECT ROUTINE_DEFINITION AS routine_definition, ROUTINE_TYPE AS routine_type FROM information_schema.routines WHERE routine_schema = '${safeDbName}' AND routine_name = '${safeName}' LIMIT 1`
|
||||
: '',
|
||||
upperType === 'PROCEDURE'
|
||||
? `SHOW PROCEDURE STATUS LIKE '${safeName}'`
|
||||
: `SHOW FUNCTION STATUS LIKE '${safeName}'`,
|
||||
].filter(Boolean);
|
||||
case 'postgres':
|
||||
case 'kingbase':
|
||||
case 'highgo':
|
||||
case 'vastbase': {
|
||||
const schemaRef = schema || 'public';
|
||||
return [`SELECT pg_get_functiondef(p.oid) AS routine_definition FROM pg_proc p JOIN pg_namespace n ON p.pronamespace = n.oid WHERE n.nspname = '${escapeSQLLiteral(schemaRef)}' AND p.proname = '${safeName}' LIMIT 1`];
|
||||
}
|
||||
case 'sqlserver':
|
||||
return [`SELECT OBJECT_DEFINITION(OBJECT_ID('${escapeSQLLiteral(routineName)}')) AS routine_definition`];
|
||||
case 'oracle':
|
||||
case 'dm': {
|
||||
const owner = schema ? escapeSQLLiteral(schema).toUpperCase() : (safeDbName ? safeDbName.toUpperCase() : '');
|
||||
if (owner) {
|
||||
return [`SELECT TEXT FROM ALL_SOURCE WHERE OWNER = '${owner}' AND NAME = '${safeName.toUpperCase()}' AND TYPE = '${upperType}' ORDER BY LINE`];
|
||||
}
|
||||
return [`SELECT TEXT FROM USER_SOURCE WHERE NAME = '${safeName.toUpperCase()}' AND TYPE = '${upperType}' ORDER BY LINE`];
|
||||
}
|
||||
case 'sqlite':
|
||||
return [`-- SQLite 不支持存储函数/存储过程`];
|
||||
default:
|
||||
return [`-- 暂不支持该数据库类型的函数/存储过程定义查看`];
|
||||
}
|
||||
};
|
||||
|
||||
const runQueryCandidates = async (
|
||||
config: Record<string, any>,
|
||||
dbName: string,
|
||||
queries: string[]
|
||||
): Promise<{ success: boolean; data: any[]; message?: string }> => {
|
||||
let lastMessage = '';
|
||||
let hasSuccessfulQuery = false;
|
||||
for (const query of queries) {
|
||||
const sql = String(query || '').trim();
|
||||
if (!sql) continue;
|
||||
try {
|
||||
const result = await DBQuery(config as any, dbName, sql);
|
||||
if (!result.success || !Array.isArray(result.data)) {
|
||||
lastMessage = result.message || lastMessage;
|
||||
continue;
|
||||
}
|
||||
hasSuccessfulQuery = true;
|
||||
if (result.data.length > 0) {
|
||||
return { success: true, data: result.data };
|
||||
}
|
||||
} catch (error: any) {
|
||||
lastMessage = error?.message || String(error);
|
||||
}
|
||||
}
|
||||
if (hasSuccessfulQuery) {
|
||||
return { success: true, data: [] };
|
||||
}
|
||||
return { success: false, data: [], message: lastMessage };
|
||||
};
|
||||
|
||||
const getVersionHint = async (config: Record<string, any>, dbName: string): Promise<string> => {
|
||||
const candidates = [
|
||||
`SELECT VERSION() AS version`,
|
||||
`SHOW VARIABLES LIKE 'version'`,
|
||||
];
|
||||
for (const query of candidates) {
|
||||
try {
|
||||
const result = await DBQuery(config as any, dbName, query);
|
||||
if (!result.success || !Array.isArray(result.data) || result.data.length === 0) {
|
||||
continue;
|
||||
}
|
||||
const row = result.data[0] as Record<string, any>;
|
||||
const version =
|
||||
row.version
|
||||
|| row.VERSION
|
||||
|| row.Value
|
||||
|| row.value
|
||||
|| Object.values(row)[1]
|
||||
|| Object.values(row)[0];
|
||||
const text = String(version || '').trim();
|
||||
if (text) return text;
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
}
|
||||
return '';
|
||||
};
|
||||
|
||||
const extractViewDefinition = (dialect: string, data: any[]): string => {
|
||||
if (!data || data.length === 0) return '-- 未找到视图定义';
|
||||
const row = data[0];
|
||||
|
||||
switch (dialect) {
|
||||
case 'mysql': {
|
||||
const keys = Object.keys(row);
|
||||
const textDefinition = row.view_definition || row.VIEW_DEFINITION;
|
||||
if (textDefinition) return String(textDefinition);
|
||||
const sqlKey = keys.find(k => k.toLowerCase().includes('create view') || k.toLowerCase() === 'create view');
|
||||
if (sqlKey) return row[sqlKey];
|
||||
const tableSqlKey = keys.find(k => k.toLowerCase().includes('create table'));
|
||||
if (tableSqlKey) return row[tableSqlKey];
|
||||
for (const key of keys) {
|
||||
const val = String(row[key] || '');
|
||||
if (val.toUpperCase().includes('CREATE') && (val.toUpperCase().includes('VIEW') || val.toUpperCase().includes('TABLE'))) {
|
||||
return val;
|
||||
}
|
||||
}
|
||||
return JSON.stringify(row, null, 2);
|
||||
}
|
||||
case 'oracle':
|
||||
case 'dm':
|
||||
return row.view_definition || row.VIEW_DEFINITION || row.text || row.TEXT || Object.values(row)[0] || '';
|
||||
default:
|
||||
return row.view_definition || row.VIEW_DEFINITION || row.sql || row.SQL || Object.values(row)[0] || '';
|
||||
}
|
||||
};
|
||||
|
||||
const extractRoutineDefinition = (dialect: string, data: any[]): string => {
|
||||
if (!data || data.length === 0) return '-- 未找到函数/存储过程定义';
|
||||
|
||||
switch (dialect) {
|
||||
case 'mysql': {
|
||||
const row = data[0];
|
||||
const keys = Object.keys(row);
|
||||
if (row.routine_definition || row.ROUTINE_DEFINITION) {
|
||||
return String(row.routine_definition || row.ROUTINE_DEFINITION);
|
||||
}
|
||||
const sqlKey = keys.find(k => k.toLowerCase().includes('create function') || k.toLowerCase().includes('create procedure'));
|
||||
if (sqlKey) return row[sqlKey];
|
||||
for (const key of keys) {
|
||||
const val = String(row[key] || '');
|
||||
if (val.toUpperCase().includes('CREATE') && (val.toUpperCase().includes('FUNCTION') || val.toUpperCase().includes('PROCEDURE'))) {
|
||||
return val;
|
||||
}
|
||||
}
|
||||
const routineName = String(row.Name || row.name || '').trim();
|
||||
if (routineName) {
|
||||
const routineType = String(row.Type || row.type || row.ROUTINE_TYPE || row.routine_type || 'FUNCTION').trim().toUpperCase();
|
||||
return `-- 当前数据源未返回可执行定义文本,已返回元数据\n-- 名称: ${routineName}\n-- 类型: ${routineType}\n${JSON.stringify(row, null, 2)}`;
|
||||
}
|
||||
return JSON.stringify(row, null, 2);
|
||||
}
|
||||
case 'oracle':
|
||||
case 'dm': {
|
||||
// Oracle/DM ALL_SOURCE returns multiple rows, one per line
|
||||
return data.map(row => row.text || row.TEXT || Object.values(row)[0] || '').join('');
|
||||
}
|
||||
default: {
|
||||
const row = data[0];
|
||||
return row.routine_definition || row.ROUTINE_DEFINITION || Object.values(row)[0] || '';
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const loadDefinition = async () => {
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
|
||||
const conn = connections.find(c => c.id === tab.connectionId);
|
||||
if (!conn) {
|
||||
setError('未找到数据库连接');
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const dbName = tab.dbName || '';
|
||||
const dialect = getMetadataDialect(conn);
|
||||
const sphinxLike = isSphinxConnection(conn) && dialect === 'mysql';
|
||||
|
||||
let queries: string[];
|
||||
let extractFn: (dialect: string, data: any[]) => string;
|
||||
let objectLabel: string;
|
||||
|
||||
if (tab.type === 'view-def') {
|
||||
const viewName = tab.viewName || '';
|
||||
if (!viewName) {
|
||||
setError('视图名称为空');
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
queries = buildShowViewQueries(dialect, viewName, dbName);
|
||||
extractFn = extractViewDefinition;
|
||||
objectLabel = '视图';
|
||||
} else {
|
||||
const routineName = tab.routineName || '';
|
||||
const routineType = tab.routineType || 'FUNCTION';
|
||||
if (!routineName) {
|
||||
setError('函数/存储过程名称为空');
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
queries = buildShowRoutineQueries(dialect, routineName, routineType, dbName);
|
||||
extractFn = extractRoutineDefinition;
|
||||
objectLabel = '函数/存储过程';
|
||||
}
|
||||
|
||||
if (!queries.length || String(queries[0] || '').startsWith('--')) {
|
||||
setDefinition(String(queries[0] || '-- 暂不支持该对象定义查看'));
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const config = {
|
||||
...conn.config,
|
||||
port: Number(conn.config.port),
|
||||
password: conn.config.password || '',
|
||||
database: conn.config.database || '',
|
||||
useSSH: conn.config.useSSH || false,
|
||||
ssh: conn.config.ssh || { host: '', port: 22, user: '', password: '', keyPath: '' }
|
||||
};
|
||||
|
||||
const result = await runQueryCandidates(config, dbName, queries);
|
||||
|
||||
if (result.success && Array.isArray(result.data) && result.data.length > 0) {
|
||||
const def = extractFn(dialect, result.data);
|
||||
setDefinition(def);
|
||||
return;
|
||||
}
|
||||
|
||||
if (result.success) {
|
||||
if (sphinxLike) {
|
||||
const version = await getVersionHint(config, dbName);
|
||||
const versionText = version ? `(版本: ${version})` : '';
|
||||
setDefinition(`-- 当前 Sphinx 实例${versionText}未返回${objectLabel}定义。\n-- 已执行多套兼容查询,可能是版本能力限制或对象类型不支持。`);
|
||||
return;
|
||||
}
|
||||
setDefinition(`-- 未找到${objectLabel}定义`);
|
||||
} else if (sphinxLike) {
|
||||
const version = await getVersionHint(config, dbName);
|
||||
const versionText = version ? `(版本: ${version})` : '';
|
||||
setDefinition(`-- 当前 Sphinx 实例${versionText}不支持${objectLabel}定义查询。\n-- 已自动尝试兼容语句,返回失败信息: ${result.message || 'unknown error'}`);
|
||||
} else {
|
||||
setError(result.message || '查询定义失败');
|
||||
}
|
||||
} catch (e: any) {
|
||||
setError('查询定义失败: ' + (e?.message || String(e)));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
loadDefinition();
|
||||
}, [tab.connectionId, tab.dbName, tab.viewName, tab.routineName, tab.routineType, tab.type, connections]);
|
||||
|
||||
const objectLabel = tab.type === 'view-def' ? '视图' : '函数/存储过程';
|
||||
const objectName = tab.type === 'view-def' ? tab.viewName : tab.routineName;
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<div style={{ display: 'flex', justifyContent: 'center', alignItems: 'center', height: '100%' }}>
|
||||
<Spin tip={`加载${objectLabel}定义...`} />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div style={{ padding: 16 }}>
|
||||
<Alert type="error" message="加载失败" description={error} showIcon />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', height: '100%' }}>
|
||||
<div style={{ padding: '8px 16px', borderBottom: darkMode ? '1px solid #303030' : '1px solid #f0f0f0' }}>
|
||||
<strong>{objectLabel}: </strong>{objectName}
|
||||
{tab.dbName && <span style={{ marginLeft: 16, color: '#888' }}>数据库: {tab.dbName}</span>}
|
||||
{tab.routineType && <span style={{ marginLeft: 16, color: '#888' }}>类型: {tab.routineType}</span>}
|
||||
</div>
|
||||
<div style={{ flex: 1, minHeight: 0 }}>
|
||||
<Editor
|
||||
height="100%"
|
||||
language="sql"
|
||||
theme={darkMode ? 'transparent-dark' : 'transparent-light'}
|
||||
value={definition}
|
||||
options={{
|
||||
readOnly: true,
|
||||
minimap: { enabled: false },
|
||||
fontSize: 14,
|
||||
lineNumbers: 'on',
|
||||
scrollBeyondLastLine: false,
|
||||
wordWrap: 'on',
|
||||
automaticLayout: true,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default DefinitionViewer;
|
||||
250
frontend/src/components/ImportPreviewModal.tsx
Normal file
250
frontend/src/components/ImportPreviewModal.tsx
Normal file
@@ -0,0 +1,250 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { Modal, Table, Alert, Progress, Button, Space } from 'antd';
|
||||
import { CheckCircleOutlined, CloseCircleOutlined } from '@ant-design/icons';
|
||||
import { PreviewImportFile, ImportDataWithProgress } from '../../wailsjs/go/app/App';
|
||||
import { EventsOn, EventsOff } from '../../wailsjs/runtime/runtime';
|
||||
import { useStore } from '../store';
|
||||
|
||||
interface ImportPreviewModalProps {
|
||||
visible: boolean;
|
||||
filePath: string;
|
||||
connectionId: string;
|
||||
dbName: string;
|
||||
tableName: string;
|
||||
onClose: () => void;
|
||||
onSuccess: () => void;
|
||||
}
|
||||
|
||||
interface PreviewData {
|
||||
columns: string[];
|
||||
totalRows: number;
|
||||
previewRows: any[];
|
||||
}
|
||||
|
||||
interface ImportProgress {
|
||||
current: number;
|
||||
total: number;
|
||||
success: number;
|
||||
errors: number;
|
||||
}
|
||||
|
||||
const ImportPreviewModal: React.FC<ImportPreviewModalProps> = ({
|
||||
visible,
|
||||
filePath,
|
||||
connectionId,
|
||||
dbName,
|
||||
tableName,
|
||||
onClose,
|
||||
onSuccess
|
||||
}) => {
|
||||
const connections = useStore(state => state.connections);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [previewData, setPreviewData] = useState<PreviewData | null>(null);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [importing, setImporting] = useState(false);
|
||||
const [progress, setProgress] = useState<ImportProgress | null>(null);
|
||||
const [importResult, setImportResult] = useState<any>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (visible && filePath) {
|
||||
loadPreview();
|
||||
}
|
||||
}, [visible, filePath]);
|
||||
|
||||
useEffect(() => {
|
||||
if (importing) {
|
||||
const unsubscribe = EventsOn('import:progress', (data: ImportProgress) => {
|
||||
setProgress(data);
|
||||
});
|
||||
return () => {
|
||||
EventsOff('import:progress');
|
||||
};
|
||||
}
|
||||
}, [importing]);
|
||||
|
||||
const loadPreview = async () => {
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
try {
|
||||
const res = await PreviewImportFile(filePath);
|
||||
if (res.success && res.data) {
|
||||
setPreviewData({
|
||||
columns: res.data.columns || [],
|
||||
totalRows: res.data.totalRows || 0,
|
||||
previewRows: res.data.previewRows || []
|
||||
});
|
||||
} else {
|
||||
setError(res.message || '预览失败');
|
||||
}
|
||||
} catch (e: any) {
|
||||
setError('预览失败: ' + e.message);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleImport = async () => {
|
||||
if (!previewData) return;
|
||||
|
||||
setImporting(true);
|
||||
setProgress({ current: 0, total: previewData.totalRows, success: 0, errors: 0 });
|
||||
setImportResult(null);
|
||||
|
||||
try {
|
||||
const conn = connections.find(c => c.id === connectionId);
|
||||
if (!conn) {
|
||||
setError('连接配置未找到');
|
||||
setImporting(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const config = {
|
||||
...conn.config,
|
||||
port: Number(conn.config.port),
|
||||
password: conn.config.password || '',
|
||||
database: conn.config.database || '',
|
||||
useSSH: conn.config.useSSH || false,
|
||||
ssh: conn.config.ssh || { host: '', port: 22, user: '', password: '', keyPath: '' }
|
||||
};
|
||||
|
||||
const res = await ImportDataWithProgress(config as any, dbName, tableName, filePath);
|
||||
|
||||
if (res.success && res.data) {
|
||||
setImportResult(res.data);
|
||||
if (res.data.failed === 0) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
setError(res.message || '导入失败');
|
||||
}
|
||||
} catch (e: any) {
|
||||
setError('导入失败: ' + e.message);
|
||||
} finally {
|
||||
setImporting(false);
|
||||
}
|
||||
};
|
||||
|
||||
const columns = previewData?.columns.map(col => ({
|
||||
title: col,
|
||||
dataIndex: col,
|
||||
key: col,
|
||||
ellipsis: true,
|
||||
width: 150
|
||||
})) || [];
|
||||
|
||||
const progressPercent = progress ? Math.round((progress.current / progress.total) * 100) : 0;
|
||||
|
||||
return (
|
||||
<Modal
|
||||
title="导入数据预览"
|
||||
open={visible}
|
||||
onCancel={onClose}
|
||||
width={900}
|
||||
footer={
|
||||
importResult ? (
|
||||
<Space>
|
||||
<Button onClick={onClose}>关闭</Button>
|
||||
</Space>
|
||||
) : importing ? null : (
|
||||
<Space>
|
||||
<Button onClick={onClose}>取消</Button>
|
||||
<Button
|
||||
type="primary"
|
||||
onClick={handleImport}
|
||||
disabled={!previewData || loading}
|
||||
>
|
||||
开始导入
|
||||
</Button>
|
||||
</Space>
|
||||
)
|
||||
}
|
||||
>
|
||||
{error && <Alert type="error" message={error} style={{ marginBottom: 16 }} showIcon />}
|
||||
|
||||
{loading && <div style={{ textAlign: 'center', padding: 40 }}>加载预览数据...</div>}
|
||||
|
||||
{!loading && previewData && !importing && !importResult && (
|
||||
<>
|
||||
<Alert
|
||||
type="info"
|
||||
message={`共 ${previewData.totalRows} 行数据,${previewData.columns.length} 个字段`}
|
||||
description='以下是前 5 行预览数据,确认无误后点击“开始导入”'
|
||||
style={{ marginBottom: 16 }}
|
||||
showIcon
|
||||
/>
|
||||
<div style={{ marginBottom: 8, fontWeight: 600 }}>字段列表:</div>
|
||||
<div style={{ marginBottom: 16, padding: 8, background: '#f5f5f5', borderRadius: 4 }}>
|
||||
{previewData.columns.join(', ')}
|
||||
</div>
|
||||
<div style={{ marginBottom: 8, fontWeight: 600 }}>数据预览(前 5 行):</div>
|
||||
<Table
|
||||
dataSource={previewData.previewRows}
|
||||
columns={columns}
|
||||
pagination={false}
|
||||
scroll={{ x: 'max-content' }}
|
||||
size="small"
|
||||
bordered
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
|
||||
{importing && progress && (
|
||||
<div style={{ padding: '40px 20px' }}>
|
||||
<div style={{ marginBottom: 16, fontSize: 16, fontWeight: 600, textAlign: 'center' }}>
|
||||
正在导入数据...
|
||||
</div>
|
||||
<Progress percent={progressPercent} status="active" />
|
||||
<div style={{ marginTop: 16, textAlign: 'center', color: '#666' }}>
|
||||
已处理 {progress.current} / {progress.total} 行
|
||||
<span style={{ marginLeft: 16, color: '#52c41a' }}>
|
||||
<CheckCircleOutlined /> 成功 {progress.success}
|
||||
</span>
|
||||
{progress.errors > 0 && (
|
||||
<span style={{ marginLeft: 16, color: '#ff4d4f' }}>
|
||||
<CloseCircleOutlined /> 失败 {progress.errors}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{importResult && (
|
||||
<div style={{ padding: 20 }}>
|
||||
<Alert
|
||||
type={importResult.failed === 0 ? 'success' : 'warning'}
|
||||
message="导入完成"
|
||||
description={
|
||||
<div>
|
||||
<div>成功导入 {importResult.success} 行</div>
|
||||
{importResult.failed > 0 && <div>失败 {importResult.failed} 行</div>}
|
||||
</div>
|
||||
}
|
||||
showIcon
|
||||
style={{ marginBottom: 16 }}
|
||||
/>
|
||||
{importResult.errorLogs && importResult.errorLogs.length > 0 && (
|
||||
<>
|
||||
<div style={{ marginBottom: 8, fontWeight: 600, color: '#ff4d4f' }}>错误日志:</div>
|
||||
<div style={{
|
||||
maxHeight: 300,
|
||||
overflow: 'auto',
|
||||
background: '#fff1f0',
|
||||
border: '1px solid #ffccc7',
|
||||
borderRadius: 4,
|
||||
padding: 12,
|
||||
fontSize: 12,
|
||||
fontFamily: 'monospace'
|
||||
}}>
|
||||
{importResult.errorLogs.map((log: string, idx: number) => (
|
||||
<div key={idx} style={{ marginBottom: 4 }}>{log}</div>
|
||||
))}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
export default ImportPreviewModal;
|
||||
@@ -2,7 +2,7 @@ import React, { useRef, useEffect } from 'react';
|
||||
import { Table, Tag, Button, Tooltip } from 'antd';
|
||||
import { ClearOutlined, CloseOutlined, CaretRightOutlined, BugOutlined } from '@ant-design/icons';
|
||||
import { useStore } from '../store';
|
||||
import { blurToFilter, normalizeBlurForPlatform, normalizeOpacityForPlatform } from '../utils/appearance';
|
||||
import { normalizeOpacityForPlatform } from '../utils/appearance';
|
||||
|
||||
interface LogPanelProps {
|
||||
height: number;
|
||||
@@ -17,7 +17,6 @@ const LogPanel: React.FC<LogPanelProps> = ({ height, onClose, onResizeStart }) =
|
||||
const appearance = useStore(state => state.appearance);
|
||||
const darkMode = theme === 'dark';
|
||||
const opacity = normalizeOpacityForPlatform(appearance.opacity);
|
||||
const blur = normalizeBlurForPlatform(appearance.blur);
|
||||
|
||||
// Background Helper
|
||||
const getBg = (darkHex: string) => {
|
||||
@@ -30,7 +29,8 @@ const LogPanel: React.FC<LogPanelProps> = ({ height, onClose, onResizeStart }) =
|
||||
};
|
||||
const bgMain = getBg('#1f1f1f');
|
||||
const bgToolbar = getBg('#2a2a2a');
|
||||
const blurFilter = blurToFilter(blur);
|
||||
const logScrollbarThumb = darkMode ? 'rgba(255, 255, 255, 0.34)' : 'rgba(0, 0, 0, 0.26)';
|
||||
const logScrollbarThumbHover = darkMode ? 'rgba(255, 255, 255, 0.5)' : 'rgba(0, 0, 0, 0.36)';
|
||||
|
||||
const columns = [
|
||||
{
|
||||
@@ -73,8 +73,6 @@ const LogPanel: React.FC<LogPanelProps> = ({ height, onClose, onResizeStart }) =
|
||||
height,
|
||||
borderTop: 'none',
|
||||
background: bgMain,
|
||||
backdropFilter: blurFilter,
|
||||
WebkitBackdropFilter: blurFilter,
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
position: 'relative',
|
||||
@@ -117,8 +115,9 @@ const LogPanel: React.FC<LogPanelProps> = ({ height, onClose, onResizeStart }) =
|
||||
</div>
|
||||
|
||||
{/* List */}
|
||||
<div style={{ flex: 1, overflow: 'auto' }}>
|
||||
<div className="log-panel-scroll" style={{ flex: 1, overflow: 'auto' }}>
|
||||
<Table
|
||||
className="log-panel-table"
|
||||
dataSource={sqlLogs}
|
||||
columns={columns}
|
||||
size="small"
|
||||
@@ -128,6 +127,35 @@ const LogPanel: React.FC<LogPanelProps> = ({ height, onClose, onResizeStart }) =
|
||||
// scroll={{ y: height - 32 }} // Let flex handle it
|
||||
/>
|
||||
</div>
|
||||
<style>{`
|
||||
.log-panel-scroll {
|
||||
scrollbar-width: thin;
|
||||
scrollbar-color: ${logScrollbarThumb} transparent;
|
||||
}
|
||||
.log-panel-scroll::-webkit-scrollbar {
|
||||
width: 10px;
|
||||
height: 10px;
|
||||
}
|
||||
.log-panel-scroll::-webkit-scrollbar-track,
|
||||
.log-panel-scroll::-webkit-scrollbar-corner {
|
||||
background: transparent;
|
||||
}
|
||||
.log-panel-scroll::-webkit-scrollbar-thumb {
|
||||
background: ${logScrollbarThumb};
|
||||
border-radius: 8px;
|
||||
border: 2px solid transparent;
|
||||
background-clip: padding-box;
|
||||
}
|
||||
.log-panel-scroll::-webkit-scrollbar-thumb:hover {
|
||||
background: ${logScrollbarThumbHover};
|
||||
background-clip: padding-box;
|
||||
}
|
||||
.log-panel-table .ant-table,
|
||||
.log-panel-table .ant-table-container,
|
||||
.log-panel-table .ant-table-tbody > tr > td {
|
||||
background: transparent !important;
|
||||
}
|
||||
`}</style>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -196,6 +196,9 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
editorRef.current = editor;
|
||||
monacoRef.current = monaco;
|
||||
|
||||
// 应用透明主题(主题已在 main.tsx 全局注册)
|
||||
monaco.editor.setTheme(darkMode ? 'transparent-dark' : 'transparent-light');
|
||||
|
||||
monaco.languages.registerCompletionItemProvider('sql', {
|
||||
triggerCharacters: ['.'],
|
||||
provideCompletionItems: async (model: any, position: any) => {
|
||||
@@ -919,7 +922,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
|
||||
const applyAutoLimit = (sql: string, dbType: string, maxRows: number): { sql: string; applied: boolean; maxRows: number } => {
|
||||
const normalizedType = (dbType || 'mysql').toLowerCase();
|
||||
const supportsLimit = normalizedType === 'mysql' || normalizedType === 'postgres' || normalizedType === 'kingbase' || normalizedType === 'sqlite' || normalizedType === '';
|
||||
const supportsLimit = normalizedType === 'mysql' || normalizedType === 'mariadb' || normalizedType === 'sphinx' || normalizedType === 'postgres' || normalizedType === 'kingbase' || normalizedType === 'sqlite' || normalizedType === 'tdengine' || normalizedType === '';
|
||||
if (!supportsLimit) return { sql, applied: false, maxRows };
|
||||
if (!Number.isFinite(maxRows) || maxRows <= 0) return { sql, applied: false, maxRows };
|
||||
|
||||
@@ -997,6 +1000,8 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
const nextResultSets: ResultSet[] = [];
|
||||
const maxRows = Number(queryOptions?.maxRows) || 0;
|
||||
const dbType = String((config as any).type || 'mysql');
|
||||
const normalizedDbType = dbType.toLowerCase();
|
||||
const forceReadOnlyResult = normalizedDbType === 'tdengine';
|
||||
const wantsLimitProbe = Number.isFinite(maxRows) && maxRows > 0;
|
||||
const probeLimit = wantsLimitProbe ? (maxRows + 1) : 0;
|
||||
let anyTruncated = false;
|
||||
@@ -1053,7 +1058,9 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
const tableMatch = rawStatement.match(/^\s*SELECT\s+\*\s+FROM\s+[`"]?(\w+)[`"]?\s*(?:WHERE.*)?(?:ORDER BY.*)?(?:LIMIT.*)?$/i);
|
||||
if (tableMatch) {
|
||||
simpleTableName = tableMatch[1];
|
||||
pendingPk.push({ resultKey: `result-${idx + 1}`, tableName: simpleTableName });
|
||||
if (!forceReadOnlyResult) {
|
||||
pendingPk.push({ resultKey: `result-${idx + 1}`, tableName: simpleTableName });
|
||||
}
|
||||
}
|
||||
|
||||
nextResultSets.push({
|
||||
@@ -1207,7 +1214,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
transition: none !important;
|
||||
}
|
||||
`}</style>
|
||||
<div style={{ padding: '8px', borderBottom: '1px solid #eee', display: 'flex', gap: '8px', flexShrink: 0, alignItems: 'center' }}>
|
||||
<div style={{ padding: '8px', display: 'flex', gap: '8px', flexShrink: 0, alignItems: 'center' }}>
|
||||
<Select
|
||||
style={{ width: 150 }}
|
||||
placeholder="选择连接"
|
||||
@@ -1261,11 +1268,11 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
</Button.Group>
|
||||
</div>
|
||||
|
||||
<div style={{ height: editorHeight, minHeight: '100px', borderBottom: '1px solid #eee' }}>
|
||||
<div style={{ height: editorHeight, minHeight: '100px' }}>
|
||||
<Editor
|
||||
height="100%"
|
||||
defaultLanguage="sql"
|
||||
theme={darkMode ? "vs-dark" : "light"}
|
||||
theme={darkMode ? "transparent-dark" : "transparent-light"}
|
||||
value={query}
|
||||
onChange={(val) => setQuery(val || '')}
|
||||
onMount={handleEditorDidMount}
|
||||
@@ -1283,7 +1290,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
style={{
|
||||
height: '5px',
|
||||
cursor: 'row-resize',
|
||||
background: darkMode ? '#333' : '#f0f0f0',
|
||||
background: darkMode ? 'rgba(255,255,255,0.08)' : 'rgba(0,0,0,0.04)',
|
||||
flexShrink: 0,
|
||||
zIndex: 10
|
||||
}}
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
import React, { useState, useEffect, useCallback, useRef } from 'react';
|
||||
import { Table, Input, Button, Space, Tag, message, Modal, Form, InputNumber, Popconfirm, Tooltip, Radio } from 'antd';
|
||||
import { ReloadOutlined, DeleteOutlined, PlusOutlined, EditOutlined, SearchOutlined, ClockCircleOutlined, CopyOutlined } from '@ant-design/icons';
|
||||
import React, { useState, useEffect, useCallback, useMemo, useRef } from 'react';
|
||||
import { Table, Input, Button, Space, Tag, Tree, Spin, message, Modal, Form, InputNumber, Popconfirm, Tooltip, Radio } from 'antd';
|
||||
import { ReloadOutlined, DeleteOutlined, PlusOutlined, EditOutlined, SearchOutlined, ClockCircleOutlined, CopyOutlined, FolderOpenOutlined, KeyOutlined } from '@ant-design/icons';
|
||||
import { useStore } from '../store';
|
||||
import { RedisKeyInfo, RedisValue } from '../types';
|
||||
import Editor from '@monaco-editor/react';
|
||||
import type { ColumnType } from 'antd/es/table';
|
||||
import type { DataNode } from 'antd/es/tree';
|
||||
|
||||
const { Search } = Input;
|
||||
|
||||
const KEY_GROUP_DELIMITER = ':';
|
||||
const EMPTY_SEGMENT_LABEL = '(empty)';
|
||||
const REDIS_TREE_KEY_TYPE_WIDTH = 92;
|
||||
const REDIS_TREE_KEY_TYPE_WIDTH_NARROW = 84;
|
||||
const REDIS_TREE_KEY_TTL_WIDTH = 92;
|
||||
const REDIS_TREE_HIDE_TTL_THRESHOLD = 460;
|
||||
|
||||
interface RedisViewerProps {
|
||||
connectionId: string;
|
||||
redisDB: number;
|
||||
@@ -222,86 +229,186 @@ const ResizableDivider: React.FC<{
|
||||
};
|
||||
|
||||
// 可拖拽列头组件 - 纯 DOM 操作实现
|
||||
const ResizableTitle: React.FC<any> = (props) => {
|
||||
const { onResize, width, children, ...restProps } = props;
|
||||
const thRef = useRef<HTMLTableCellElement>(null);
|
||||
type RedisKeyTreeLeaf = {
|
||||
keyInfo: RedisKeyInfo;
|
||||
label: string;
|
||||
};
|
||||
|
||||
// 如果没有 onResize 或 width,说明这列不需要拖拽(如复选框列)
|
||||
if (!onResize || !width) {
|
||||
return <th {...restProps}>{children}</th>;
|
||||
}
|
||||
type RedisKeyTreeGroup = {
|
||||
name: string;
|
||||
path: string;
|
||||
children: Map<string, RedisKeyTreeGroup>;
|
||||
leaves: RedisKeyTreeLeaf[];
|
||||
};
|
||||
|
||||
const handleMouseDown = (e: React.MouseEvent) => {
|
||||
e.stopPropagation();
|
||||
e.preventDefault();
|
||||
type RedisKeyTreeResult = {
|
||||
treeData: DataNode[];
|
||||
rawKeyByNodeKey: Map<string, string>;
|
||||
leafNodeKeyByRawKey: Map<string, string>;
|
||||
groupKeys: string[];
|
||||
};
|
||||
|
||||
const startX = e.clientX;
|
||||
const startWidth = width;
|
||||
const th = thRef.current;
|
||||
if (!th) return;
|
||||
const normalizeKeySegment = (segment: string): string => {
|
||||
return segment === '' ? EMPTY_SEGMENT_LABEL : segment;
|
||||
};
|
||||
|
||||
// 找到对应的 colgroup col 元素来同步更新列宽
|
||||
const table = th.closest('table');
|
||||
const thIndex = Array.from(th.parentElement?.children || []).indexOf(th);
|
||||
const col = table?.querySelector(`colgroup col:nth-child(${thIndex + 1})`) as HTMLElement | null;
|
||||
const createTreeGroup = (name: string, path: string): RedisKeyTreeGroup => {
|
||||
return { name, path, children: new Map(), leaves: [] };
|
||||
};
|
||||
|
||||
// 创建遮罩层防止文本选择
|
||||
const overlay = document.createElement('div');
|
||||
overlay.style.cssText = 'position:fixed;top:0;left:0;right:0;bottom:0;cursor:col-resize;z-index:9999;';
|
||||
document.body.appendChild(overlay);
|
||||
const countGroupLeafNodes = (group: RedisKeyTreeGroup): number => {
|
||||
let count = group.leaves.length;
|
||||
group.children.forEach((child) => {
|
||||
count += countGroupLeafNodes(child);
|
||||
});
|
||||
return count;
|
||||
};
|
||||
|
||||
let currentWidth = startWidth;
|
||||
const buildRedisKeyTree = (
|
||||
keys: RedisKeyInfo[],
|
||||
formatTTL: (ttl: number) => string,
|
||||
getTypeColor: (type: string) => string,
|
||||
showTTL: boolean
|
||||
): RedisKeyTreeResult => {
|
||||
const root = createTreeGroup('__root__', '__root__');
|
||||
|
||||
const handleMouseMove = (moveEvent: MouseEvent) => {
|
||||
moveEvent.preventDefault();
|
||||
const delta = moveEvent.clientX - startX;
|
||||
currentWidth = Math.max(50, startWidth + delta);
|
||||
// 直接操作 DOM
|
||||
th.style.width = `${currentWidth}px`;
|
||||
if (col) {
|
||||
col.style.width = `${currentWidth}px`;
|
||||
keys.forEach((keyInfo) => {
|
||||
const segments = keyInfo.key.split(KEY_GROUP_DELIMITER);
|
||||
if (segments.length <= 1) {
|
||||
root.leaves.push({ keyInfo, label: keyInfo.key });
|
||||
return;
|
||||
}
|
||||
|
||||
const groupSegments = segments.slice(0, -1);
|
||||
const leafLabel = normalizeKeySegment(segments[segments.length - 1]);
|
||||
let current = root;
|
||||
const pathParts: string[] = [];
|
||||
|
||||
groupSegments.forEach((segment) => {
|
||||
const normalized = normalizeKeySegment(segment);
|
||||
pathParts.push(normalized);
|
||||
const groupPath = pathParts.join(KEY_GROUP_DELIMITER);
|
||||
let child = current.children.get(normalized);
|
||||
if (!child) {
|
||||
child = createTreeGroup(normalized, groupPath);
|
||||
current.children.set(normalized, child);
|
||||
}
|
||||
};
|
||||
current = child;
|
||||
});
|
||||
|
||||
const handleMouseUp = () => {
|
||||
document.removeEventListener('mousemove', handleMouseMove);
|
||||
document.removeEventListener('mouseup', handleMouseUp);
|
||||
document.body.removeChild(overlay);
|
||||
// 拖拽结束时更新 React state
|
||||
onResize(null, { size: { width: currentWidth } });
|
||||
};
|
||||
current.leaves.push({ keyInfo, label: leafLabel });
|
||||
});
|
||||
|
||||
document.addEventListener('mousemove', handleMouseMove);
|
||||
document.addEventListener('mouseup', handleMouseUp);
|
||||
const rawKeyByNodeKey = new Map<string, string>();
|
||||
const leafNodeKeyByRawKey = new Map<string, string>();
|
||||
const groupKeys: string[] = [];
|
||||
|
||||
const toTreeNodes = (group: RedisKeyTreeGroup): DataNode[] => {
|
||||
const childGroups = Array.from(group.children.values()).sort((a, b) => a.name.localeCompare(b.name));
|
||||
const childLeaves = [...group.leaves].sort((a, b) => a.keyInfo.key.localeCompare(b.keyInfo.key));
|
||||
|
||||
const groupNodes: DataNode[] = childGroups.map((child) => {
|
||||
const groupNodeKey = `group:${child.path}`;
|
||||
groupKeys.push(groupNodeKey);
|
||||
return {
|
||||
key: groupNodeKey,
|
||||
title: (
|
||||
<Space size={6}>
|
||||
<FolderOpenOutlined style={{ color: '#8c8c8c' }} />
|
||||
<span>{child.name}</span>
|
||||
<span style={{ fontSize: 12, color: '#999' }}>({countGroupLeafNodes(child)})</span>
|
||||
</Space>
|
||||
),
|
||||
selectable: false,
|
||||
disableCheckbox: true,
|
||||
children: toTreeNodes(child),
|
||||
};
|
||||
});
|
||||
|
||||
const leafNodes: DataNode[] = childLeaves.map((leaf) => {
|
||||
const nodeKey = `key:${leaf.keyInfo.key}`;
|
||||
rawKeyByNodeKey.set(nodeKey, leaf.keyInfo.key);
|
||||
leafNodeKeyByRawKey.set(leaf.keyInfo.key, nodeKey);
|
||||
return {
|
||||
key: nodeKey,
|
||||
isLeaf: true,
|
||||
title: (
|
||||
<div
|
||||
style={{
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
gap: 8,
|
||||
minWidth: 0,
|
||||
width: '100%',
|
||||
overflow: 'hidden',
|
||||
}}
|
||||
>
|
||||
<div
|
||||
style={{
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
gap: 6,
|
||||
minWidth: 0,
|
||||
flex: 1,
|
||||
overflow: 'hidden',
|
||||
}}
|
||||
>
|
||||
<KeyOutlined style={{ color: '#1677ff', flexShrink: 0 }} />
|
||||
<Tooltip title={leaf.keyInfo.key}>
|
||||
<span
|
||||
style={{
|
||||
minWidth: 0,
|
||||
overflow: 'hidden',
|
||||
textOverflow: 'ellipsis',
|
||||
whiteSpace: 'nowrap',
|
||||
display: 'block',
|
||||
}}
|
||||
>
|
||||
{leaf.label}
|
||||
</span>
|
||||
</Tooltip>
|
||||
</div>
|
||||
<Tag
|
||||
color={getTypeColor(leaf.keyInfo.type)}
|
||||
style={{
|
||||
marginInlineEnd: 0,
|
||||
width: showTTL ? REDIS_TREE_KEY_TYPE_WIDTH : REDIS_TREE_KEY_TYPE_WIDTH_NARROW,
|
||||
textAlign: 'center',
|
||||
flexShrink: 0
|
||||
}}
|
||||
>
|
||||
{leaf.keyInfo.type}
|
||||
</Tag>
|
||||
{showTTL && (
|
||||
<span
|
||||
style={{
|
||||
width: REDIS_TREE_KEY_TTL_WIDTH,
|
||||
fontSize: 12,
|
||||
color: '#999',
|
||||
textAlign: 'left',
|
||||
whiteSpace: 'nowrap',
|
||||
flexShrink: 0,
|
||||
overflow: 'hidden',
|
||||
textOverflow: 'ellipsis',
|
||||
}}
|
||||
>
|
||||
{formatTTL(leaf.keyInfo.ttl)}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
),
|
||||
};
|
||||
});
|
||||
|
||||
return [...groupNodes, ...leafNodes];
|
||||
};
|
||||
|
||||
return (
|
||||
<th
|
||||
ref={thRef}
|
||||
{...restProps}
|
||||
style={{
|
||||
...restProps.style,
|
||||
position: 'relative'
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
<div
|
||||
style={{
|
||||
position: 'absolute',
|
||||
right: 0,
|
||||
top: 0,
|
||||
bottom: 0,
|
||||
width: 10,
|
||||
cursor: 'col-resize',
|
||||
zIndex: 1,
|
||||
background: 'transparent'
|
||||
}}
|
||||
onMouseDown={handleMouseDown}
|
||||
onMouseOver={(e) => { e.currentTarget.style.background = 'rgba(0,0,0,0.06)'; }}
|
||||
onMouseOut={(e) => { e.currentTarget.style.background = 'transparent'; }}
|
||||
/>
|
||||
</th>
|
||||
);
|
||||
return {
|
||||
treeData: toTreeNodes(root),
|
||||
rawKeyByNodeKey,
|
||||
leafNodeKeyByRawKey,
|
||||
groupKeys,
|
||||
};
|
||||
};
|
||||
|
||||
const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
@@ -317,7 +424,6 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
const [keyValue, setKeyValue] = useState<RedisValue | null>(null);
|
||||
const [valueLoading, setValueLoading] = useState(false);
|
||||
const [editModalOpen, setEditModalOpen] = useState(false);
|
||||
const [editForm] = Form.useForm();
|
||||
const [newKeyModalOpen, setNewKeyModalOpen] = useState(false);
|
||||
const [newKeyForm] = Form.useForm();
|
||||
const [ttlModalOpen, setTtlModalOpen] = useState(false);
|
||||
@@ -341,15 +447,8 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
// 面板宽度状态和 ref - 默认占据 50% 宽度
|
||||
const [leftPanelWidth, setLeftPanelWidth] = useState<number | string>('50%');
|
||||
const leftPanelRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// 列宽状态 - 复选框列约 32px,总宽度需要接近面板宽度
|
||||
// Key 列自适应剩余空间,其他列固定宽度
|
||||
const [columnWidths, setColumnWidths] = useState({
|
||||
key: 220, // Key 名称,需要较宽
|
||||
type: 65, // 类型标签
|
||||
ttl: 80, // TTL 显示
|
||||
action: 50 // 操作按钮
|
||||
});
|
||||
const [showTreeKeyTTL, setShowTreeKeyTTL] = useState(true);
|
||||
const [expandedGroupKeys, setExpandedGroupKeys] = useState<string[]>([]);
|
||||
|
||||
const getConfig = useCallback(() => {
|
||||
if (!connection) return null;
|
||||
@@ -373,7 +472,12 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
if (res.success) {
|
||||
const result = res.data;
|
||||
if (append) {
|
||||
setKeys(prev => [...prev, ...result.keys]);
|
||||
setKeys(prev => {
|
||||
const keyMap = new Map<string, RedisKeyInfo>();
|
||||
prev.forEach(item => keyMap.set(item.key, item));
|
||||
result.keys.forEach((item: RedisKeyInfo) => keyMap.set(item.key, item));
|
||||
return Array.from(keyMap.values());
|
||||
});
|
||||
} else {
|
||||
setKeys(result.keys);
|
||||
}
|
||||
@@ -451,6 +555,11 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
}
|
||||
};
|
||||
|
||||
const handleDeleteCurrentKey = async () => {
|
||||
if (!selectedKey) return;
|
||||
await handleDeleteKeys([selectedKey]);
|
||||
};
|
||||
|
||||
const handleSetTTL = async () => {
|
||||
const config = getConfig();
|
||||
if (!config || !selectedKey) return;
|
||||
@@ -529,65 +638,81 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
return `${Math.floor(ttl / 86400)}天${Math.floor((ttl % 86400) / 3600)}时`;
|
||||
};
|
||||
|
||||
// 处理列宽调整 - react-resizable 的 onResize 回调格式
|
||||
const handleColumnResize = (key: string) => (_e: any, { size }: { size: { width: number } }) => {
|
||||
setColumnWidths(prev => ({ ...prev, [key]: size.width }));
|
||||
useEffect(() => {
|
||||
const target = leftPanelRef.current;
|
||||
if (!target) return;
|
||||
|
||||
const updateTTLVisibility = (width: number) => {
|
||||
const nextShowTTL = width > REDIS_TREE_HIDE_TTL_THRESHOLD;
|
||||
setShowTreeKeyTTL((prev) => (prev === nextShowTTL ? prev : nextShowTTL));
|
||||
};
|
||||
|
||||
updateTTLVisibility(Math.round(target.getBoundingClientRect().width));
|
||||
|
||||
if (typeof ResizeObserver !== 'undefined') {
|
||||
const observer = new ResizeObserver((entries) => {
|
||||
const width = Math.round(entries[0]?.contentRect.width || target.getBoundingClientRect().width);
|
||||
updateTTLVisibility(width);
|
||||
});
|
||||
observer.observe(target);
|
||||
return () => observer.disconnect();
|
||||
}
|
||||
|
||||
const handleWindowResize = () => {
|
||||
updateTTLVisibility(Math.round(target.getBoundingClientRect().width));
|
||||
};
|
||||
window.addEventListener('resize', handleWindowResize);
|
||||
return () => window.removeEventListener('resize', handleWindowResize);
|
||||
}, []);
|
||||
|
||||
const keyTree = useMemo(() => {
|
||||
return buildRedisKeyTree(keys, formatTTL, getTypeColor, showTreeKeyTTL);
|
||||
}, [keys, showTreeKeyTTL]);
|
||||
|
||||
const selectedTreeNodeKeys = useMemo(() => {
|
||||
if (!selectedKey) {
|
||||
return [] as string[];
|
||||
}
|
||||
const nodeKey = keyTree.leafNodeKeyByRawKey.get(selectedKey);
|
||||
return nodeKey ? [nodeKey] : [];
|
||||
}, [selectedKey, keyTree]);
|
||||
|
||||
const checkedTreeNodeKeys = useMemo(() => {
|
||||
return selectedKeys
|
||||
.map(rawKey => keyTree.leafNodeKeyByRawKey.get(rawKey))
|
||||
.filter((nodeKey): nodeKey is string => Boolean(nodeKey));
|
||||
}, [selectedKeys, keyTree]);
|
||||
|
||||
useEffect(() => {
|
||||
const existingKeySet = new Set(keys.map(item => item.key));
|
||||
setSelectedKeys(prev => prev.filter(rawKey => existingKeySet.has(rawKey)));
|
||||
}, [keys]);
|
||||
|
||||
useEffect(() => {
|
||||
setExpandedGroupKeys((prev) => {
|
||||
const validKeys = prev.filter(nodeKey => keyTree.groupKeys.includes(nodeKey));
|
||||
return validKeys;
|
||||
});
|
||||
}, [keyTree]);
|
||||
|
||||
const handleTreeSelect = (nodeKeys: React.Key[]) => {
|
||||
if (nodeKeys.length === 0) {
|
||||
return;
|
||||
}
|
||||
const rawKey = keyTree.rawKeyByNodeKey.get(String(nodeKeys[0]));
|
||||
if (!rawKey) {
|
||||
return;
|
||||
}
|
||||
loadKeyValue(rawKey);
|
||||
};
|
||||
|
||||
const columns: ColumnType<RedisKeyInfo>[] = [
|
||||
{
|
||||
title: 'Key',
|
||||
dataIndex: 'key',
|
||||
key: 'key',
|
||||
width: columnWidths.key,
|
||||
ellipsis: true,
|
||||
onHeaderCell: (column: any) => ({
|
||||
width: column.width,
|
||||
onResize: handleColumnResize('key')
|
||||
}),
|
||||
render: (text: string) => (
|
||||
<Tooltip title={text}>
|
||||
<span style={{ cursor: 'pointer' }} onClick={() => loadKeyValue(text)}>{text}</span>
|
||||
</Tooltip>
|
||||
)
|
||||
},
|
||||
{
|
||||
title: '类型',
|
||||
dataIndex: 'type',
|
||||
key: 'type',
|
||||
width: columnWidths.type,
|
||||
onHeaderCell: (column: any) => ({
|
||||
width: column.width,
|
||||
onResize: handleColumnResize('type')
|
||||
}),
|
||||
render: (type: string) => <Tag color={getTypeColor(type)}>{type}</Tag>
|
||||
},
|
||||
{
|
||||
title: 'TTL',
|
||||
dataIndex: 'ttl',
|
||||
key: 'ttl',
|
||||
width: columnWidths.ttl,
|
||||
onHeaderCell: (column: any) => ({
|
||||
width: column.width,
|
||||
onResize: handleColumnResize('ttl')
|
||||
}),
|
||||
render: (ttl: number) => formatTTL(ttl)
|
||||
},
|
||||
{
|
||||
title: '操作',
|
||||
key: 'action',
|
||||
width: columnWidths.action,
|
||||
onHeaderCell: (column: any) => ({
|
||||
width: column.width,
|
||||
onResize: handleColumnResize('action')
|
||||
}),
|
||||
render: (_: any, record: RedisKeyInfo) => (
|
||||
<Popconfirm title="确定删除此 Key?" onConfirm={() => handleDeleteKeys([record.key])}>
|
||||
<Button type="text" danger size="small" icon={<DeleteOutlined />} />
|
||||
</Popconfirm>
|
||||
)
|
||||
}
|
||||
];
|
||||
const handleTreeCheck = (checked: React.Key[] | { checked: React.Key[]; halfChecked: React.Key[] }) => {
|
||||
const checkedNodeKeys = Array.isArray(checked) ? checked : checked.checked;
|
||||
const rawKeys = checkedNodeKeys
|
||||
.map(nodeKey => keyTree.rawKeyByNodeKey.get(String(nodeKey)))
|
||||
.filter((rawKey): rawKey is string => Boolean(rawKey));
|
||||
setSelectedKeys(rawKeys);
|
||||
};
|
||||
|
||||
const renderValueEditor = () => {
|
||||
if (!keyValue || !selectedKey) {
|
||||
@@ -1375,6 +1500,9 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
setTtlModalOpen(true);
|
||||
}}>设置 TTL</Button>
|
||||
<Button size="small" onClick={() => loadKeyValue(selectedKey)} icon={<ReloadOutlined />}>刷新</Button>
|
||||
<Popconfirm title={`确定删除 Key "${selectedKey}"?`} onConfirm={handleDeleteCurrentKey}>
|
||||
<Button size="small" danger icon={<DeleteOutlined />}>删除 Key</Button>
|
||||
</Popconfirm>
|
||||
</Space>
|
||||
</div>
|
||||
<div style={{ flex: 1, minHeight: 0, overflow: 'hidden' }}>
|
||||
@@ -1410,36 +1538,35 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
<Button size="small" icon={<ReloadOutlined />} onClick={handleRefresh}>刷新</Button>
|
||||
<Button size="small" icon={<PlusOutlined />} onClick={() => setNewKeyModalOpen(true)}>新建</Button>
|
||||
</Space>
|
||||
{selectedKeys.length > 0 && (
|
||||
<Popconfirm title={`确定删除选中的 ${selectedKeys.length} 个 Key?`} onConfirm={() => handleDeleteKeys(selectedKeys)}>
|
||||
<Button size="small" danger icon={<DeleteOutlined />}>删除选中</Button>
|
||||
</Popconfirm>
|
||||
)}
|
||||
<Popconfirm
|
||||
title={`确定删除选中的 ${selectedKeys.length} 个 Key?`}
|
||||
onConfirm={() => handleDeleteKeys(selectedKeys)}
|
||||
disabled={selectedKeys.length === 0}
|
||||
>
|
||||
<Button size="small" danger icon={<DeleteOutlined />} disabled={selectedKeys.length === 0}>
|
||||
删除选中({selectedKeys.length})
|
||||
</Button>
|
||||
</Popconfirm>
|
||||
</div>
|
||||
</div>
|
||||
<div style={{ flex: 1, overflow: 'auto' }}>
|
||||
<Table
|
||||
dataSource={keys}
|
||||
columns={columns}
|
||||
rowKey="key"
|
||||
size="small"
|
||||
loading={loading}
|
||||
pagination={false}
|
||||
components={{
|
||||
header: {
|
||||
cell: ResizableTitle
|
||||
}
|
||||
}}
|
||||
rowSelection={{
|
||||
selectedRowKeys: selectedKeys,
|
||||
onChange: (keys) => setSelectedKeys(keys as string[])
|
||||
}}
|
||||
onRow={(record) => ({
|
||||
onClick: () => loadKeyValue(record.key),
|
||||
style: { cursor: 'pointer', background: selectedKey === record.key ? '#e6f7ff' : undefined }
|
||||
})}
|
||||
style={{ width: '100%' }}
|
||||
/>
|
||||
<Spin spinning={loading} size="small">
|
||||
<Tree
|
||||
blockNode
|
||||
showIcon={false}
|
||||
checkable
|
||||
checkStrictly
|
||||
selectable
|
||||
treeData={keyTree.treeData}
|
||||
selectedKeys={selectedTreeNodeKeys}
|
||||
checkedKeys={checkedTreeNodeKeys}
|
||||
expandedKeys={expandedGroupKeys}
|
||||
onExpand={(nextExpandedKeys) => setExpandedGroupKeys(nextExpandedKeys as string[])}
|
||||
onSelect={(nodeKeys) => handleTreeSelect(nodeKeys)}
|
||||
onCheck={(checked) => handleTreeCheck(checked)}
|
||||
style={{ padding: '8px 6px' }}
|
||||
/>
|
||||
</Spin>
|
||||
{hasMore && (
|
||||
<div style={{ padding: 8, textAlign: 'center' }}>
|
||||
<Button onClick={handleLoadMore} loading={loading}>加载更多</Button>
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -7,9 +7,31 @@ import QueryEditor from './QueryEditor';
|
||||
import TableDesigner from './TableDesigner';
|
||||
import RedisViewer from './RedisViewer';
|
||||
import RedisCommandEditor from './RedisCommandEditor';
|
||||
import TriggerViewer from './TriggerViewer';
|
||||
import DefinitionViewer from './DefinitionViewer';
|
||||
import type { TabData } from '../types';
|
||||
|
||||
const detectConnectionEnvLabel = (connectionName: string): string | null => {
|
||||
const tokens = connectionName.toLowerCase().split(/[^a-z0-9]+/).filter(Boolean);
|
||||
if (tokens.includes('prod') || tokens.includes('production')) return 'PROD';
|
||||
if (tokens.includes('uat')) return 'UAT';
|
||||
if (tokens.includes('dev') || tokens.includes('development')) return 'DEV';
|
||||
if (tokens.includes('sit')) return 'SIT';
|
||||
if (tokens.includes('stg') || tokens.includes('stage') || tokens.includes('staging') || tokens.includes('pre')) return 'STG';
|
||||
if (tokens.includes('test') || tokens.includes('qa')) return 'TEST';
|
||||
return null;
|
||||
};
|
||||
|
||||
const buildTabDisplayTitle = (tab: TabData, connectionName: string | undefined): string => {
|
||||
if (tab.type !== 'table' && tab.type !== 'design') return tab.title;
|
||||
if (!connectionName) return tab.title;
|
||||
const prefix = detectConnectionEnvLabel(connectionName) || connectionName;
|
||||
return `[${prefix}] ${tab.title}`;
|
||||
};
|
||||
|
||||
const TabManager: React.FC = () => {
|
||||
const tabs = useStore(state => state.tabs);
|
||||
const connections = useStore(state => state.connections);
|
||||
const activeTabId = useStore(state => state.activeTabId);
|
||||
const setActiveTab = useStore(state => state.setActiveTab);
|
||||
const closeTab = useStore(state => state.closeTab);
|
||||
@@ -29,6 +51,8 @@ const TabManager: React.FC = () => {
|
||||
};
|
||||
|
||||
const items = useMemo(() => tabs.map((tab, index) => {
|
||||
const connectionName = connections.find((conn) => conn.id === tab.connectionId)?.name;
|
||||
const displayTitle = buildTabDisplayTitle(tab, connectionName);
|
||||
let content;
|
||||
if (tab.type === 'query') {
|
||||
content = <QueryEditor tab={tab} />;
|
||||
@@ -40,6 +64,10 @@ const TabManager: React.FC = () => {
|
||||
content = <RedisViewer connectionId={tab.connectionId} redisDB={tab.redisDB ?? 0} />;
|
||||
} else if (tab.type === 'redis-command') {
|
||||
content = <RedisCommandEditor connectionId={tab.connectionId} redisDB={tab.redisDB ?? 0} />;
|
||||
} else if (tab.type === 'trigger') {
|
||||
content = <TriggerViewer tab={tab} />;
|
||||
} else if (tab.type === 'view-def' || tab.type === 'routine-def') {
|
||||
content = <DefinitionViewer tab={tab} />;
|
||||
}
|
||||
|
||||
const menuItems: MenuProps['items'] = [
|
||||
@@ -73,13 +101,13 @@ const TabManager: React.FC = () => {
|
||||
return {
|
||||
label: (
|
||||
<Dropdown menu={{ items: menuItems }} trigger={['contextMenu']}>
|
||||
<span onContextMenu={(e) => e.preventDefault()}>{tab.title}</span>
|
||||
<span onContextMenu={(e) => e.preventDefault()}>{displayTitle}</span>
|
||||
</Dropdown>
|
||||
),
|
||||
key: tab.id,
|
||||
children: content,
|
||||
};
|
||||
}), [tabs, closeOtherTabs, closeTabsToLeft, closeTabsToRight, closeAllTabs]);
|
||||
}), [tabs, connections, closeOtherTabs, closeTabsToLeft, closeTabsToRight, closeAllTabs]);
|
||||
|
||||
return (
|
||||
<>
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
import React, { useEffect, useState, useContext, useMemo, useRef } from 'react';
|
||||
import { Table, Tabs, Button, message, Input, Checkbox, Modal, AutoComplete, Tooltip, Select } from 'antd';
|
||||
import { ReloadOutlined, SaveOutlined, PlusOutlined, DeleteOutlined, MenuOutlined, FileTextOutlined } from '@ant-design/icons';
|
||||
import React, { useEffect, useState, useContext, useMemo, useRef, useCallback } from 'react';
|
||||
import { Table, Tabs, Button, message, Input, Checkbox, Modal, AutoComplete, Tooltip, Select, Empty, Space } from 'antd';
|
||||
import { ReloadOutlined, SaveOutlined, PlusOutlined, DeleteOutlined, MenuOutlined, FileTextOutlined, EyeOutlined, EditOutlined, ExclamationCircleOutlined } from '@ant-design/icons';
|
||||
import { DndContext, closestCenter, KeyboardSensor, PointerSensor, useSensor, useSensors, DragOverlay } from '@dnd-kit/core';
|
||||
import { arrayMove, SortableContext, sortableKeyboardCoordinates, verticalListSortingStrategy, useSortable } from '@dnd-kit/sortable';
|
||||
import { CSS } from '@dnd-kit/utilities';
|
||||
import { Resizable } from 'react-resizable';
|
||||
import Editor, { loader } from '@monaco-editor/react';
|
||||
import { TabData, ColumnDefinition, IndexDefinition, ForeignKeyDefinition, TriggerDefinition } from '../types';
|
||||
import { useStore } from '../store';
|
||||
import { DBGetColumns, DBGetIndexes, DBQuery, DBGetForeignKeys, DBGetTriggers, DBShowCreateTable } from '../../wailsjs/go/app/App';
|
||||
|
||||
// Need styles for react-resizable
|
||||
import 'react-resizable/css/styles.css';
|
||||
|
||||
interface EditableColumn extends ColumnDefinition {
|
||||
_key: string;
|
||||
isNew?: boolean;
|
||||
@@ -57,45 +54,43 @@ const COLLATIONS = {
|
||||
]
|
||||
};
|
||||
|
||||
// --- Resizable Header Component ---
|
||||
// --- Resizable Header Component (Native, same interaction as DataGrid) ---
|
||||
const ResizableTitle = (props: any) => {
|
||||
const { onResize, width, ...restProps } = props;
|
||||
const { onResizeStart, width, ...restProps } = props;
|
||||
const nextStyle = { ...(restProps.style || {}) } as React.CSSProperties;
|
||||
|
||||
if (width) {
|
||||
nextStyle.width = width;
|
||||
}
|
||||
|
||||
if (!width) {
|
||||
return <th {...restProps} />;
|
||||
return <th {...restProps} style={nextStyle} />;
|
||||
}
|
||||
|
||||
return (
|
||||
<Resizable
|
||||
width={width}
|
||||
height={0}
|
||||
handle={
|
||||
<span
|
||||
className="react-resizable-handle"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
e.preventDefault();
|
||||
}}
|
||||
onMouseDown={(e) => {
|
||||
e.stopPropagation();
|
||||
e.preventDefault(); // Prevent text selection and focus hijacking
|
||||
}}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
right: -5,
|
||||
bottom: 0,
|
||||
top: 0,
|
||||
width: 10,
|
||||
cursor: 'col-resize',
|
||||
zIndex: 10
|
||||
}}
|
||||
/>
|
||||
}
|
||||
onResize={onResize}
|
||||
draggableOpts={{ enableUserSelectHack: true }}
|
||||
>
|
||||
<th {...restProps} style={{ ...restProps.style, position: 'relative' }} />
|
||||
</Resizable>
|
||||
<th {...restProps} style={{ ...nextStyle, position: 'relative' }}>
|
||||
{restProps.children}
|
||||
<span
|
||||
className="react-resizable-handle"
|
||||
onMouseDown={(e) => {
|
||||
e.stopPropagation();
|
||||
if (typeof onResizeStart === 'function') {
|
||||
onResizeStart(e);
|
||||
}
|
||||
}}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
right: 0,
|
||||
bottom: 0,
|
||||
top: 0,
|
||||
width: 10,
|
||||
cursor: 'col-resize',
|
||||
zIndex: 10,
|
||||
touchAction: 'none',
|
||||
}}
|
||||
/>
|
||||
</th>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -162,13 +157,47 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
const [previewSql, setPreviewSql] = useState<string>('');
|
||||
const [isPreviewOpen, setIsPreviewOpen] = useState(false);
|
||||
const [activeKey, setActiveKey] = useState(tab.initialTab || "columns");
|
||||
const [selectedTrigger, setSelectedTrigger] = useState<TriggerDefinition | null>(null);
|
||||
const [isTriggerModalOpen, setIsTriggerModalOpen] = useState(false);
|
||||
const [isTriggerEditModalOpen, setIsTriggerEditModalOpen] = useState(false);
|
||||
const [triggerEditMode, setTriggerEditMode] = useState<'create' | 'edit'>('create');
|
||||
const [triggerEditSql, setTriggerEditSql] = useState<string>('');
|
||||
const [triggerExecuting, setTriggerExecuting] = useState(false);
|
||||
|
||||
const connections = useStore(state => state.connections);
|
||||
const theme = useStore(state => state.theme);
|
||||
const darkMode = theme === 'dark';
|
||||
const readOnly = !!tab.readOnly;
|
||||
|
||||
const [tableHeight, setTableHeight] = useState(500);
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// 初始化透明 Monaco Editor 主题
|
||||
useEffect(() => {
|
||||
loader.init().then(monaco => {
|
||||
monaco.editor.defineTheme('transparent-dark', {
|
||||
base: 'vs-dark',
|
||||
inherit: true,
|
||||
rules: [],
|
||||
colors: {
|
||||
'editor.background': '#00000000',
|
||||
'editor.lineHighlightBackground': '#ffffff10',
|
||||
'editorGutter.background': '#00000000',
|
||||
}
|
||||
});
|
||||
monaco.editor.defineTheme('transparent-light', {
|
||||
base: 'vs',
|
||||
inherit: true,
|
||||
rules: [],
|
||||
colors: {
|
||||
'editor.background': '#00000000',
|
||||
'editor.lineHighlightBackground': '#00000010',
|
||||
'editorGutter.background': '#00000000',
|
||||
}
|
||||
});
|
||||
});
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (!containerRef.current) return;
|
||||
const resizeObserver = new ResizeObserver(entries => {
|
||||
@@ -183,6 +212,14 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
|
||||
// --- Resizable Columns State ---
|
||||
const [tableColumns, setTableColumns] = useState<any[]>([]);
|
||||
const resizeDragRef = useRef<{ startX: number; startWidth: number; index: number; containerLeft: number } | null>(null);
|
||||
const resizeRafRef = useRef<number | null>(null);
|
||||
const latestResizeXRef = useRef<number | null>(null);
|
||||
const ghostRef = useRef<HTMLDivElement>(null);
|
||||
const resizeListenerRef = useRef<{ move: ((e: MouseEvent) => void) | null; up: ((e: MouseEvent) => void) | null }>({
|
||||
move: null,
|
||||
up: null,
|
||||
});
|
||||
|
||||
const sensors = useSensors(
|
||||
useSensor(PointerSensor),
|
||||
@@ -283,25 +320,97 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
setTableColumns(initialCols);
|
||||
}, [readOnly]); // Re-create if readOnly changes
|
||||
|
||||
const rafRef = React.useRef<number | null>(null);
|
||||
const flushResizeGhost = useCallback(() => {
|
||||
resizeRafRef.current = null;
|
||||
if (!resizeDragRef.current || !ghostRef.current) return;
|
||||
if (latestResizeXRef.current === null) return;
|
||||
const relativeLeft = latestResizeXRef.current - resizeDragRef.current.containerLeft;
|
||||
ghostRef.current.style.transform = `translateX(${relativeLeft}px)`;
|
||||
}, []);
|
||||
|
||||
// Resize Handler
|
||||
const handleResize = (index: number) => (_: React.SyntheticEvent, { size }: { size: { width: number } }) => {
|
||||
if (rafRef.current) {
|
||||
cancelAnimationFrame(rafRef.current);
|
||||
}
|
||||
rafRef.current = requestAnimationFrame(() => {
|
||||
setTableColumns((columns) => {
|
||||
const nextColumns = [...columns];
|
||||
nextColumns[index] = {
|
||||
...nextColumns[index],
|
||||
width: size.width,
|
||||
};
|
||||
return nextColumns;
|
||||
const detachResizeListeners = useCallback(() => {
|
||||
if (resizeListenerRef.current.move) {
|
||||
document.removeEventListener('mousemove', resizeListenerRef.current.move);
|
||||
resizeListenerRef.current.move = null;
|
||||
}
|
||||
if (resizeListenerRef.current.up) {
|
||||
document.removeEventListener('mouseup', resizeListenerRef.current.up);
|
||||
resizeListenerRef.current.up = null;
|
||||
}
|
||||
}, []);
|
||||
|
||||
const cleanupResizeState = useCallback(() => {
|
||||
if (resizeRafRef.current !== null) {
|
||||
cancelAnimationFrame(resizeRafRef.current);
|
||||
resizeRafRef.current = null;
|
||||
}
|
||||
latestResizeXRef.current = null;
|
||||
resizeDragRef.current = null;
|
||||
if (ghostRef.current) {
|
||||
ghostRef.current.style.display = 'none';
|
||||
}
|
||||
document.body.style.cursor = '';
|
||||
document.body.style.userSelect = '';
|
||||
}, []);
|
||||
|
||||
const handleResizeStart = useCallback((index: number) => (e: React.MouseEvent) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
|
||||
const startX = e.clientX;
|
||||
const currentWidth = Number(tableColumns[index]?.width || 200);
|
||||
const containerLeft = containerRef.current?.getBoundingClientRect().left ?? 0;
|
||||
resizeDragRef.current = { startX, startWidth: currentWidth, index, containerLeft };
|
||||
latestResizeXRef.current = startX;
|
||||
|
||||
if (ghostRef.current && containerRef.current) {
|
||||
const relativeLeft = startX - containerLeft;
|
||||
ghostRef.current.style.transform = `translateX(${relativeLeft}px)`;
|
||||
ghostRef.current.style.display = 'block';
|
||||
}
|
||||
|
||||
detachResizeListeners();
|
||||
|
||||
const onMove = (event: MouseEvent) => {
|
||||
if (!resizeDragRef.current) return;
|
||||
latestResizeXRef.current = event.clientX;
|
||||
if (resizeRafRef.current !== null) return;
|
||||
resizeRafRef.current = requestAnimationFrame(flushResizeGhost);
|
||||
};
|
||||
|
||||
const onUp = (event: MouseEvent) => {
|
||||
if (resizeDragRef.current) {
|
||||
const { startX: dragStartX, startWidth, index: dragIndex } = resizeDragRef.current;
|
||||
const deltaX = event.clientX - dragStartX;
|
||||
const newWidth = Math.max(50, startWidth + deltaX);
|
||||
setTableColumns((prevColumns) => {
|
||||
if (!prevColumns[dragIndex]) return prevColumns;
|
||||
const nextColumns = [...prevColumns];
|
||||
nextColumns[dragIndex] = {
|
||||
...nextColumns[dragIndex],
|
||||
width: newWidth,
|
||||
};
|
||||
return nextColumns;
|
||||
});
|
||||
rafRef.current = null;
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
detachResizeListeners();
|
||||
cleanupResizeState();
|
||||
};
|
||||
|
||||
resizeListenerRef.current = { move: onMove, up: onUp };
|
||||
document.addEventListener('mousemove', onMove);
|
||||
document.addEventListener('mouseup', onUp);
|
||||
document.body.style.cursor = 'col-resize';
|
||||
document.body.style.userSelect = 'none';
|
||||
}, [cleanupResizeState, detachResizeListeners, flushResizeGhost, tableColumns]);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
detachResizeListeners();
|
||||
cleanupResizeState();
|
||||
};
|
||||
}, [cleanupResizeState, detachResizeListeners]);
|
||||
|
||||
const fetchData = async () => {
|
||||
if (isNewTable) return; // Don't fetch for new table
|
||||
@@ -365,6 +474,215 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
fetchData();
|
||||
}, [tab]);
|
||||
|
||||
// --- Trigger Handlers ---
|
||||
|
||||
const getDbType = (): string => {
|
||||
const conn = connections.find(c => c.id === tab.connectionId);
|
||||
const type = String(conn?.config?.type || '').toLowerCase();
|
||||
if (type === 'mariadb' || type === 'sphinx') return 'mysql';
|
||||
if (type === 'dameng') return 'dm';
|
||||
return type;
|
||||
};
|
||||
|
||||
const generateTriggerTemplate = (): string => {
|
||||
const dbType = getDbType();
|
||||
const tblName = tab.tableName || 'table_name';
|
||||
|
||||
switch (dbType) {
|
||||
case 'mysql':
|
||||
return `CREATE TRIGGER trigger_name
|
||||
BEFORE INSERT ON \`${tblName}\`
|
||||
FOR EACH ROW
|
||||
BEGIN
|
||||
-- 触发器逻辑
|
||||
END;`;
|
||||
case 'postgres':
|
||||
case 'kingbase':
|
||||
case 'highgo':
|
||||
case 'vastbase':
|
||||
return `CREATE OR REPLACE FUNCTION trigger_function_name()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
-- 触发器逻辑
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
CREATE TRIGGER trigger_name
|
||||
BEFORE INSERT ON "${tblName}"
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION trigger_function_name();`;
|
||||
case 'sqlserver':
|
||||
return `CREATE TRIGGER trigger_name
|
||||
ON [${tblName}]
|
||||
AFTER INSERT
|
||||
AS
|
||||
BEGIN
|
||||
SET NOCOUNT ON;
|
||||
-- 触发器逻辑
|
||||
END;`;
|
||||
case 'oracle':
|
||||
case 'dm':
|
||||
return `CREATE OR REPLACE TRIGGER trigger_name
|
||||
BEFORE INSERT ON "${tblName}"
|
||||
FOR EACH ROW
|
||||
BEGIN
|
||||
-- 触发器逻辑
|
||||
NULL;
|
||||
END;`;
|
||||
case 'sqlite':
|
||||
return `CREATE TRIGGER trigger_name
|
||||
AFTER INSERT ON "${tblName}"
|
||||
BEGIN
|
||||
-- 触发器逻辑
|
||||
END;`;
|
||||
default:
|
||||
return `-- 请输入 CREATE TRIGGER 语句`;
|
||||
}
|
||||
};
|
||||
|
||||
const buildDropTriggerSql = (triggerName: string): string => {
|
||||
const dbType = getDbType();
|
||||
const tblName = tab.tableName || '';
|
||||
|
||||
switch (dbType) {
|
||||
case 'mysql':
|
||||
return `DROP TRIGGER IF EXISTS \`${triggerName}\``;
|
||||
case 'postgres':
|
||||
case 'kingbase':
|
||||
case 'highgo':
|
||||
case 'vastbase':
|
||||
return `DROP TRIGGER IF EXISTS "${triggerName}" ON "${tblName}"`;
|
||||
case 'sqlserver':
|
||||
return `DROP TRIGGER IF EXISTS [${triggerName}]`;
|
||||
case 'oracle':
|
||||
case 'dm':
|
||||
return `DROP TRIGGER "${triggerName}"`;
|
||||
case 'sqlite':
|
||||
return `DROP TRIGGER IF EXISTS "${triggerName}"`;
|
||||
default:
|
||||
return `DROP TRIGGER ${triggerName}`;
|
||||
}
|
||||
};
|
||||
|
||||
const handleCreateTrigger = () => {
|
||||
setTriggerEditMode('create');
|
||||
setTriggerEditSql(generateTriggerTemplate());
|
||||
setIsTriggerEditModalOpen(true);
|
||||
};
|
||||
|
||||
const handleEditTrigger = () => {
|
||||
if (!selectedTrigger) return;
|
||||
setTriggerEditMode('edit');
|
||||
// 构建完整的 CREATE TRIGGER 语句
|
||||
const dbType = getDbType();
|
||||
const tblName = tab.tableName || '';
|
||||
let createSql = '';
|
||||
|
||||
if (dbType === 'mysql') {
|
||||
createSql = `CREATE TRIGGER \`${selectedTrigger.name}\`
|
||||
${selectedTrigger.timing} ${selectedTrigger.event} ON \`${tblName}\`
|
||||
FOR EACH ROW
|
||||
${selectedTrigger.statement}`;
|
||||
} else {
|
||||
createSql = selectedTrigger.statement || '-- 无法获取完整的触发器定义';
|
||||
}
|
||||
|
||||
setTriggerEditSql(createSql);
|
||||
setIsTriggerEditModalOpen(true);
|
||||
};
|
||||
|
||||
const handleDeleteTrigger = () => {
|
||||
if (!selectedTrigger) return;
|
||||
|
||||
Modal.confirm({
|
||||
title: '确认删除触发器',
|
||||
icon: <ExclamationCircleOutlined />,
|
||||
content: `确定要删除触发器 "${selectedTrigger.name}" 吗?此操作不可撤销。`,
|
||||
okText: '删除',
|
||||
okType: 'danger',
|
||||
cancelText: '取消',
|
||||
onOk: async () => {
|
||||
const conn = connections.find(c => c.id === tab.connectionId);
|
||||
if (!conn) {
|
||||
message.error('未找到连接');
|
||||
return;
|
||||
}
|
||||
|
||||
const config = {
|
||||
...conn.config,
|
||||
port: Number(conn.config.port),
|
||||
password: conn.config.password || "",
|
||||
database: conn.config.database || "",
|
||||
useSSH: conn.config.useSSH || false,
|
||||
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
};
|
||||
|
||||
const dropSql = buildDropTriggerSql(selectedTrigger.name);
|
||||
|
||||
try {
|
||||
const res = await DBQuery(config as any, tab.dbName || '', dropSql);
|
||||
if (res.success) {
|
||||
message.success('触发器删除成功');
|
||||
setSelectedTrigger(null);
|
||||
fetchData(); // 刷新列表
|
||||
} else {
|
||||
message.error('删除失败: ' + res.message);
|
||||
}
|
||||
} catch (e: any) {
|
||||
message.error('删除失败: ' + (e?.message || String(e)));
|
||||
}
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
const handleExecuteTriggerSql = async () => {
|
||||
const conn = connections.find(c => c.id === tab.connectionId);
|
||||
if (!conn) {
|
||||
message.error('未找到连接');
|
||||
return;
|
||||
}
|
||||
|
||||
const config = {
|
||||
...conn.config,
|
||||
port: Number(conn.config.port),
|
||||
password: conn.config.password || "",
|
||||
database: conn.config.database || "",
|
||||
useSSH: conn.config.useSSH || false,
|
||||
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
};
|
||||
|
||||
setTriggerExecuting(true);
|
||||
|
||||
try {
|
||||
// 如果是编辑模式,先删除旧触发器
|
||||
if (triggerEditMode === 'edit' && selectedTrigger) {
|
||||
const dropSql = buildDropTriggerSql(selectedTrigger.name);
|
||||
const dropRes = await DBQuery(config as any, tab.dbName || '', dropSql);
|
||||
if (!dropRes.success) {
|
||||
message.error('删除旧触发器失败: ' + dropRes.message);
|
||||
setTriggerExecuting(false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// 执行创建语句
|
||||
const res = await DBQuery(config as any, tab.dbName || '', triggerEditSql);
|
||||
if (res.success) {
|
||||
message.success(triggerEditMode === 'create' ? '触发器创建成功' : '触发器修改成功');
|
||||
setIsTriggerEditModalOpen(false);
|
||||
setSelectedTrigger(null);
|
||||
fetchData(); // 刷新列表
|
||||
} else {
|
||||
message.error('执行失败: ' + res.message);
|
||||
}
|
||||
} catch (e: any) {
|
||||
message.error('执行失败: ' + (e?.message || String(e)));
|
||||
} finally {
|
||||
setTriggerExecuting(false);
|
||||
}
|
||||
};
|
||||
|
||||
// --- Handlers ---
|
||||
|
||||
const handleColumnChange = (key: string, field: keyof EditableColumn, value: any) => {
|
||||
@@ -542,7 +860,7 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
...col,
|
||||
onHeaderCell: (column: any) => ({
|
||||
width: column.width,
|
||||
onResize: handleResize(index),
|
||||
onResizeStart: handleResizeStart(index),
|
||||
}),
|
||||
}));
|
||||
|
||||
@@ -589,6 +907,21 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
</SortableContext>
|
||||
</DndContext>
|
||||
)}
|
||||
<div
|
||||
ref={ghostRef}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
bottom: 0,
|
||||
left: 0,
|
||||
width: '2px',
|
||||
background: '#1890ff',
|
||||
zIndex: 9999,
|
||||
display: 'none',
|
||||
pointerEvents: 'none',
|
||||
willChange: 'transform',
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -680,19 +1013,61 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
key: 'triggers',
|
||||
label: '触发器',
|
||||
children: (
|
||||
<Table
|
||||
dataSource={triggers}
|
||||
columns={[
|
||||
{ title: '名', dataIndex: 'name', key: 'name' },
|
||||
{ title: '时间', dataIndex: 'timing', key: 'timing' },
|
||||
{ title: '事件', dataIndex: 'event', key: 'event' },
|
||||
{ title: '语句', dataIndex: 'statement', key: 'statement', ellipsis: true },
|
||||
]}
|
||||
rowKey="name"
|
||||
size="small"
|
||||
pagination={false}
|
||||
loading={loading}
|
||||
/>
|
||||
<div>
|
||||
<div style={{ marginBottom: 8, display: 'flex', gap: 8 }}>
|
||||
<Button
|
||||
size="small"
|
||||
icon={<EyeOutlined />}
|
||||
disabled={!selectedTrigger}
|
||||
onClick={() => setIsTriggerModalOpen(true)}
|
||||
>
|
||||
查看语句
|
||||
</Button>
|
||||
<Button size="small" icon={<PlusOutlined />} onClick={handleCreateTrigger}>新增</Button>
|
||||
<Button size="small" icon={<EditOutlined />} disabled={!selectedTrigger} onClick={handleEditTrigger}>修改</Button>
|
||||
<Button size="small" icon={<DeleteOutlined />} danger disabled={!selectedTrigger} onClick={handleDeleteTrigger}>删除</Button>
|
||||
<span style={{ marginLeft: 'auto', color: '#888', fontSize: 12, alignSelf: 'center' }}>
|
||||
{selectedTrigger ? `已选择: ${selectedTrigger.name}` : '请点击选择触发器'}
|
||||
</span>
|
||||
</div>
|
||||
<Table
|
||||
dataSource={triggers}
|
||||
columns={[
|
||||
{ title: '名称', dataIndex: 'name', key: 'name' },
|
||||
{ title: '时机', dataIndex: 'timing', key: 'timing', width: 100 },
|
||||
{ title: '事件', dataIndex: 'event', key: 'event', width: 100 },
|
||||
]}
|
||||
rowKey="name"
|
||||
size="small"
|
||||
pagination={false}
|
||||
loading={loading}
|
||||
locale={{ emptyText: <Empty description="该表暂无触发器" image={Empty.PRESENTED_IMAGE_SIMPLE} /> }}
|
||||
rowSelection={{
|
||||
type: 'radio',
|
||||
selectedRowKeys: selectedTrigger ? [selectedTrigger.name] : [],
|
||||
onChange: (_, selectedRows) => setSelectedTrigger(selectedRows[0] || null),
|
||||
onSelect: (record, selected) => {
|
||||
// 点击单选按钮时,如果已选中则取消
|
||||
if (selectedTrigger?.name === record.name) {
|
||||
setSelectedTrigger(null);
|
||||
} else {
|
||||
setSelectedTrigger(record);
|
||||
}
|
||||
},
|
||||
}}
|
||||
onRow={(record) => ({
|
||||
onClick: () => {
|
||||
// 点击已选中的行时取消选择
|
||||
if (selectedTrigger?.name === record.name) {
|
||||
setSelectedTrigger(null);
|
||||
} else {
|
||||
setSelectedTrigger(record);
|
||||
}
|
||||
},
|
||||
style: { cursor: 'pointer' }
|
||||
})}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
] : []),
|
||||
@@ -701,8 +1076,22 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
label: 'DDL',
|
||||
icon: <FileTextOutlined />,
|
||||
children: (
|
||||
<div style={{ height: 'calc(100vh - 200px)', overflow: 'auto', padding: 10, background: '#f5f5f5', border: '1px solid #eee' }}>
|
||||
<pre>{ddl}</pre>
|
||||
<div style={{ height: 'calc(100vh - 200px)', border: darkMode ? '1px solid #303030' : '1px solid #d9d9d9', borderRadius: 4 }}>
|
||||
<Editor
|
||||
height="100%"
|
||||
language="sql"
|
||||
theme={darkMode ? 'transparent-dark' : 'transparent-light'}
|
||||
value={ddl}
|
||||
options={{
|
||||
readOnly: true,
|
||||
minimap: { enabled: false },
|
||||
fontSize: 14,
|
||||
lineNumbers: 'on',
|
||||
scrollBeyondLastLine: false,
|
||||
wordWrap: 'on',
|
||||
automaticLayout: true,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}] : [])
|
||||
@@ -725,6 +1114,75 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
</div>
|
||||
<p style={{ marginTop: 10, color: '#faad14' }}>请仔细检查 SQL,执行后不可撤销。</p>
|
||||
</Modal>
|
||||
|
||||
<Modal
|
||||
title={selectedTrigger ? `触发器: ${selectedTrigger.name}` : '触发器详情'}
|
||||
open={isTriggerModalOpen}
|
||||
onCancel={() => setIsTriggerModalOpen(false)}
|
||||
footer={null}
|
||||
width={700}
|
||||
>
|
||||
{selectedTrigger && (
|
||||
<div>
|
||||
<div style={{ marginBottom: 12, display: 'flex', gap: 24 }}>
|
||||
<span><strong>时机:</strong> {selectedTrigger.timing}</span>
|
||||
<span><strong>事件:</strong> {selectedTrigger.event}</span>
|
||||
</div>
|
||||
<div style={{ border: darkMode ? '1px solid #303030' : '1px solid #d9d9d9', borderRadius: 4 }}>
|
||||
<Editor
|
||||
height="350px"
|
||||
language="sql"
|
||||
theme={darkMode ? 'transparent-dark' : 'transparent-light'}
|
||||
value={selectedTrigger.statement}
|
||||
options={{
|
||||
readOnly: true,
|
||||
minimap: { enabled: false },
|
||||
fontSize: 14,
|
||||
lineNumbers: 'on',
|
||||
scrollBeyondLastLine: false,
|
||||
wordWrap: 'on',
|
||||
automaticLayout: true,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</Modal>
|
||||
|
||||
<Modal
|
||||
title={triggerEditMode === 'create' ? '新增触发器' : '修改触发器'}
|
||||
open={isTriggerEditModalOpen}
|
||||
onCancel={() => setIsTriggerEditModalOpen(false)}
|
||||
width={800}
|
||||
okText={triggerEditMode === 'create' ? '创建' : '保存'}
|
||||
cancelText="取消"
|
||||
confirmLoading={triggerExecuting}
|
||||
onOk={handleExecuteTriggerSql}
|
||||
>
|
||||
<div style={{ marginBottom: 8, color: '#888', fontSize: 12 }}>
|
||||
{triggerEditMode === 'edit' && selectedTrigger && (
|
||||
<span>修改触发器时会先删除原触发器,再创建新触发器。</span>
|
||||
)}
|
||||
</div>
|
||||
<div style={{ border: darkMode ? '1px solid #303030' : '1px solid #d9d9d9', borderRadius: 4 }}>
|
||||
<Editor
|
||||
height="350px"
|
||||
language="sql"
|
||||
theme={darkMode ? 'vs-dark' : 'light'}
|
||||
value={triggerEditSql}
|
||||
onChange={(val) => setTriggerEditSql(val || '')}
|
||||
options={{
|
||||
minimap: { enabled: false },
|
||||
fontSize: 14,
|
||||
lineNumbers: 'on',
|
||||
scrollBeyondLastLine: false,
|
||||
wordWrap: 'on',
|
||||
automaticLayout: true,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<p style={{ marginTop: 10, color: '#faad14' }}>请仔细检查 SQL 语句,执行后不可撤销。</p>
|
||||
</Modal>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
336
frontend/src/components/TriggerViewer.tsx
Normal file
336
frontend/src/components/TriggerViewer.tsx
Normal file
@@ -0,0 +1,336 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import Editor, { loader } from '@monaco-editor/react';
|
||||
import { Spin, Alert } from 'antd';
|
||||
import { TabData } from '../types';
|
||||
import { useStore } from '../store';
|
||||
import { DBQuery } from '../../wailsjs/go/app/App';
|
||||
|
||||
interface TriggerViewerProps {
|
||||
tab: TabData;
|
||||
}
|
||||
|
||||
const TriggerViewer: React.FC<TriggerViewerProps> = ({ tab }) => {
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [triggerDefinition, setTriggerDefinition] = useState<string>('');
|
||||
|
||||
const connections = useStore(state => state.connections);
|
||||
const theme = useStore(state => state.theme);
|
||||
const darkMode = theme === 'dark';
|
||||
|
||||
// 初始化透明 Monaco Editor 主题
|
||||
useEffect(() => {
|
||||
loader.init().then(monaco => {
|
||||
monaco.editor.defineTheme('transparent-dark', {
|
||||
base: 'vs-dark',
|
||||
inherit: true,
|
||||
rules: [],
|
||||
colors: {
|
||||
'editor.background': '#00000000',
|
||||
'editor.lineHighlightBackground': '#ffffff10',
|
||||
'editorGutter.background': '#00000000',
|
||||
}
|
||||
});
|
||||
monaco.editor.defineTheme('transparent-light', {
|
||||
base: 'vs',
|
||||
inherit: true,
|
||||
rules: [],
|
||||
colors: {
|
||||
'editor.background': '#00000000',
|
||||
'editor.lineHighlightBackground': '#00000010',
|
||||
'editorGutter.background': '#00000000',
|
||||
}
|
||||
});
|
||||
});
|
||||
}, []);
|
||||
|
||||
const escapeSQLLiteral = (raw: string): string => String(raw || '').replace(/'/g, "''");
|
||||
const quoteSqlServerIdentifier = (raw: string): string => `[${String(raw || '').replace(/]/g, ']]')}]`;
|
||||
|
||||
const getMetadataDialect = (conn: any): string => {
|
||||
const type = String(conn?.config?.type || '').trim().toLowerCase();
|
||||
if (type === 'custom') {
|
||||
return String(conn?.config?.driver || '').trim().toLowerCase();
|
||||
}
|
||||
if (type === 'mariadb' || type === 'sphinx') return 'mysql';
|
||||
if (type === 'dameng') return 'dm';
|
||||
return type;
|
||||
};
|
||||
|
||||
const isSphinxConnection = (conn: any): boolean => {
|
||||
const type = String(conn?.config?.type || '').trim().toLowerCase();
|
||||
if (type === 'sphinx') return true;
|
||||
if (type !== 'custom') return false;
|
||||
const driver = String(conn?.config?.driver || '').trim().toLowerCase();
|
||||
return driver === 'sphinx' || driver === 'sphinxql';
|
||||
};
|
||||
|
||||
const buildShowTriggerQueries = (dialect: string, triggerName: string, dbName: string): string[] => {
|
||||
const safeTriggerName = escapeSQLLiteral(triggerName);
|
||||
const safeDbName = escapeSQLLiteral(dbName);
|
||||
switch (dialect) {
|
||||
case 'mysql':
|
||||
return [
|
||||
`SHOW CREATE TRIGGER \`${triggerName.replace(/`/g, '``')}\``,
|
||||
safeDbName
|
||||
? `SELECT ACTION_STATEMENT AS trigger_definition FROM information_schema.triggers WHERE trigger_schema = '${safeDbName}' AND trigger_name = '${safeTriggerName}' LIMIT 1`
|
||||
: '',
|
||||
safeDbName
|
||||
? `SHOW TRIGGERS FROM \`${dbName.replace(/`/g, '``')}\` LIKE '${safeTriggerName}'`
|
||||
: `SHOW TRIGGERS LIKE '${safeTriggerName}'`,
|
||||
].filter(Boolean);
|
||||
case 'postgres':
|
||||
case 'kingbase':
|
||||
case 'highgo':
|
||||
case 'vastbase':
|
||||
return [`SELECT pg_get_triggerdef(t.oid, true) AS trigger_definition
|
||||
FROM pg_trigger t
|
||||
JOIN pg_class c ON t.tgrelid = c.oid
|
||||
WHERE t.tgname = '${safeTriggerName}'
|
||||
AND NOT t.tgisinternal
|
||||
LIMIT 1`];
|
||||
case 'sqlserver': {
|
||||
return [`SELECT OBJECT_DEFINITION(OBJECT_ID('${safeTriggerName.replace(/'/g, "''")}')) AS trigger_definition`];
|
||||
}
|
||||
case 'oracle':
|
||||
case 'dm':
|
||||
if (!safeDbName) {
|
||||
return [`SELECT TRIGGER_BODY FROM USER_TRIGGERS WHERE TRIGGER_NAME = '${safeTriggerName.toUpperCase()}'`];
|
||||
}
|
||||
return [`SELECT TRIGGER_BODY FROM ALL_TRIGGERS WHERE OWNER = '${safeDbName.toUpperCase()}' AND TRIGGER_NAME = '${safeTriggerName.toUpperCase()}'`];
|
||||
case 'sqlite':
|
||||
return [`SELECT sql FROM sqlite_master WHERE type = 'trigger' AND name = '${safeTriggerName}'`];
|
||||
case 'tdengine':
|
||||
return [`-- TDengine 不支持触发器`];
|
||||
case 'mongodb':
|
||||
return [`-- MongoDB 不支持触发器`];
|
||||
default:
|
||||
return [`-- 暂不支持该数据库类型的触发器定义查看`];
|
||||
}
|
||||
};
|
||||
|
||||
const runQueryCandidates = async (
|
||||
config: Record<string, any>,
|
||||
dbName: string,
|
||||
queries: string[]
|
||||
): Promise<{ success: boolean; data: any[]; message?: string }> => {
|
||||
let lastMessage = '';
|
||||
let hasSuccessfulQuery = false;
|
||||
for (const query of queries) {
|
||||
const sql = String(query || '').trim();
|
||||
if (!sql) continue;
|
||||
try {
|
||||
const result = await DBQuery(config as any, dbName, sql);
|
||||
if (!result.success || !Array.isArray(result.data)) {
|
||||
lastMessage = result.message || lastMessage;
|
||||
continue;
|
||||
}
|
||||
hasSuccessfulQuery = true;
|
||||
if (result.data.length > 0) {
|
||||
return { success: true, data: result.data };
|
||||
}
|
||||
} catch (error: any) {
|
||||
lastMessage = error?.message || String(error);
|
||||
}
|
||||
}
|
||||
if (hasSuccessfulQuery) {
|
||||
return { success: true, data: [] };
|
||||
}
|
||||
return { success: false, data: [], message: lastMessage };
|
||||
};
|
||||
|
||||
const getVersionHint = async (config: Record<string, any>, dbName: string): Promise<string> => {
|
||||
const candidates = [
|
||||
`SELECT VERSION() AS version`,
|
||||
`SHOW VARIABLES LIKE 'version'`,
|
||||
];
|
||||
for (const query of candidates) {
|
||||
try {
|
||||
const result = await DBQuery(config as any, dbName, query);
|
||||
if (!result.success || !Array.isArray(result.data) || result.data.length === 0) {
|
||||
continue;
|
||||
}
|
||||
const row = result.data[0] as Record<string, any>;
|
||||
const version =
|
||||
row.version
|
||||
|| row.VERSION
|
||||
|| row.Value
|
||||
|| row.value
|
||||
|| Object.values(row)[1]
|
||||
|| Object.values(row)[0];
|
||||
const text = String(version || '').trim();
|
||||
if (text) return text;
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
}
|
||||
return '';
|
||||
};
|
||||
|
||||
const extractTriggerDefinition = (dialect: string, data: any[]): string => {
|
||||
if (!data || data.length === 0) {
|
||||
return '-- 未找到触发器定义';
|
||||
}
|
||||
|
||||
const row = data[0];
|
||||
|
||||
switch (dialect) {
|
||||
case 'mysql': {
|
||||
// MySQL SHOW CREATE TRIGGER returns: Trigger, sql_mode, SQL Original Statement, ...
|
||||
const keys = Object.keys(row);
|
||||
if (row.trigger_definition || row.TRIGGER_DEFINITION) {
|
||||
return String(row.trigger_definition || row.TRIGGER_DEFINITION);
|
||||
}
|
||||
if (row.ACTION_STATEMENT || row.action_statement) {
|
||||
return String(row.ACTION_STATEMENT || row.action_statement);
|
||||
}
|
||||
const sqlKey = keys.find(k => k.toLowerCase().includes('statement') || k.toLowerCase() === 'sql original statement');
|
||||
if (sqlKey) return row[sqlKey];
|
||||
// Fallback: try to find any key containing CREATE TRIGGER
|
||||
for (const key of keys) {
|
||||
const val = String(row[key] || '');
|
||||
if (val.toUpperCase().includes('CREATE TRIGGER')) {
|
||||
return val;
|
||||
}
|
||||
}
|
||||
return JSON.stringify(row, null, 2);
|
||||
}
|
||||
case 'postgres':
|
||||
case 'kingbase':
|
||||
case 'highgo':
|
||||
case 'vastbase': {
|
||||
return row.trigger_definition || row.TRIGGER_DEFINITION || Object.values(row)[0] || '';
|
||||
}
|
||||
case 'sqlserver': {
|
||||
return row.trigger_definition || row.TRIGGER_DEFINITION || Object.values(row)[0] || '';
|
||||
}
|
||||
case 'oracle':
|
||||
case 'dm': {
|
||||
return row.trigger_body || row.TRIGGER_BODY || Object.values(row)[0] || '';
|
||||
}
|
||||
case 'sqlite': {
|
||||
return row.sql || row.SQL || Object.values(row)[0] || '';
|
||||
}
|
||||
default:
|
||||
return JSON.stringify(row, null, 2);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const loadTriggerDefinition = async () => {
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
|
||||
const conn = connections.find(c => c.id === tab.connectionId);
|
||||
if (!conn) {
|
||||
setError('未找到数据库连接');
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const triggerName = tab.triggerName || '';
|
||||
const dbName = tab.dbName || '';
|
||||
|
||||
if (!triggerName) {
|
||||
setError('触发器名称为空');
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const dialect = getMetadataDialect(conn);
|
||||
const queries = buildShowTriggerQueries(dialect, triggerName, dbName);
|
||||
const sphinxLike = isSphinxConnection(conn) && dialect === 'mysql';
|
||||
|
||||
if (!queries.length || String(queries[0] || '').startsWith('--')) {
|
||||
setTriggerDefinition(String(queries[0] || '-- 暂不支持该数据库类型的触发器定义查看'));
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const config = {
|
||||
...conn.config,
|
||||
port: Number(conn.config.port),
|
||||
password: conn.config.password || '',
|
||||
database: conn.config.database || '',
|
||||
useSSH: conn.config.useSSH || false,
|
||||
ssh: conn.config.ssh || { host: '', port: 22, user: '', password: '', keyPath: '' }
|
||||
};
|
||||
|
||||
const result = await runQueryCandidates(config, dbName, queries);
|
||||
|
||||
if (result.success && Array.isArray(result.data) && result.data.length > 0) {
|
||||
const definition = extractTriggerDefinition(dialect, result.data);
|
||||
setTriggerDefinition(definition);
|
||||
return;
|
||||
}
|
||||
|
||||
if (result.success) {
|
||||
if (sphinxLike) {
|
||||
const version = await getVersionHint(config, dbName);
|
||||
const versionText = version ? `(版本: ${version})` : '';
|
||||
setTriggerDefinition(`-- 当前 Sphinx 实例${versionText}未返回触发器定义。\n-- 已执行多套兼容查询,可能是版本能力限制或对象类型不支持。`);
|
||||
return;
|
||||
}
|
||||
setTriggerDefinition('-- 未找到触发器定义');
|
||||
} else if (sphinxLike) {
|
||||
const version = await getVersionHint(config, dbName);
|
||||
const versionText = version ? `(版本: ${version})` : '';
|
||||
setTriggerDefinition(`-- 当前 Sphinx 实例${versionText}不支持触发器定义查询。\n-- 已自动尝试兼容语句,返回失败信息: ${result.message || 'unknown error'}`);
|
||||
} else {
|
||||
setError(result.message || '查询触发器定义失败');
|
||||
}
|
||||
} catch (e: any) {
|
||||
setError('查询触发器定义失败: ' + (e?.message || String(e)));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
loadTriggerDefinition();
|
||||
}, [tab.connectionId, tab.dbName, tab.triggerName, connections]);
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<div style={{ display: 'flex', justifyContent: 'center', alignItems: 'center', height: '100%' }}>
|
||||
<Spin tip="加载触发器定义..." />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div style={{ padding: 16 }}>
|
||||
<Alert type="error" message="加载失败" description={error} showIcon />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', height: '100%' }}>
|
||||
<div style={{ padding: '8px 16px', borderBottom: darkMode ? '1px solid #303030' : '1px solid #f0f0f0' }}>
|
||||
<strong>触发器: </strong>{tab.triggerName}
|
||||
{tab.dbName && <span style={{ marginLeft: 16, color: '#888' }}>数据库: {tab.dbName}</span>}
|
||||
</div>
|
||||
<div style={{ flex: 1, minHeight: 0 }}>
|
||||
<Editor
|
||||
height="100%"
|
||||
language="sql"
|
||||
theme={darkMode ? 'transparent-dark' : 'transparent-light'}
|
||||
value={triggerDefinition}
|
||||
options={{
|
||||
readOnly: true,
|
||||
minimap: { enabled: false },
|
||||
fontSize: 14,
|
||||
lineNumbers: 'on',
|
||||
scrollBeyondLastLine: false,
|
||||
wordWrap: 'on',
|
||||
automaticLayout: true,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default TriggerViewer;
|
||||
@@ -3,6 +3,22 @@ import ReactDOM from 'react-dom/client'
|
||||
import App from './App'
|
||||
// import './index.css' // Optional global styles
|
||||
|
||||
// 全局配置 Monaco Editor 使用本地打包的文件,避免从 CDN (jsdelivr) 加载。
|
||||
// Windows WebView2 环境下访问外部 CDN 可能失败,导致编辑器一直显示 Loading。
|
||||
import { loader } from '@monaco-editor/react'
|
||||
import * as monaco from 'monaco-editor'
|
||||
loader.config({ monaco })
|
||||
|
||||
// 全局注册透明主题,避免每个 Editor 组件 beforeMount 中重复定义
|
||||
monaco.editor.defineTheme('transparent-dark', {
|
||||
base: 'vs-dark', inherit: true, rules: [],
|
||||
colors: { 'editor.background': '#00000000', 'editor.lineHighlightBackground': '#ffffff10', 'editorGutter.background': '#00000000' }
|
||||
})
|
||||
monaco.editor.defineTheme('transparent-light', {
|
||||
base: 'vs', inherit: true, rules: [],
|
||||
colors: { 'editor.background': '#00000000', 'editor.lineHighlightBackground': '#00000010', 'editorGutter.background': '#00000000' }
|
||||
})
|
||||
|
||||
ReactDOM.createRoot(document.getElementById('root')!).render(
|
||||
<React.StrictMode>
|
||||
<App />
|
||||
|
||||
@@ -1,6 +1,239 @@
|
||||
import { create } from 'zustand';
|
||||
import { persist } from 'zustand/middleware';
|
||||
import { SavedConnection, TabData, SavedQuery } from './types';
|
||||
import { ConnectionConfig, SavedConnection, TabData, SavedQuery } from './types';
|
||||
|
||||
const DEFAULT_APPEARANCE = { opacity: 1.0, blur: 0 };
|
||||
const LEGACY_DEFAULT_OPACITY = 0.95;
|
||||
const OPACITY_EPSILON = 1e-6;
|
||||
const MAX_URI_LENGTH = 4096;
|
||||
const MAX_HOST_ENTRY_LENGTH = 512;
|
||||
const MAX_HOST_ENTRIES = 64;
|
||||
const DEFAULT_TIMEOUT_SECONDS = 30;
|
||||
const MAX_TIMEOUT_SECONDS = 3600;
|
||||
const DEFAULT_CONNECTION_TYPE = 'mysql';
|
||||
const SUPPORTED_CONNECTION_TYPES = new Set([
|
||||
'mysql',
|
||||
'mariadb',
|
||||
'sphinx',
|
||||
'postgres',
|
||||
'redis',
|
||||
'tdengine',
|
||||
'oracle',
|
||||
'dameng',
|
||||
'kingbase',
|
||||
'sqlserver',
|
||||
'mongodb',
|
||||
'highgo',
|
||||
'vastbase',
|
||||
'sqlite',
|
||||
'custom',
|
||||
]);
|
||||
|
||||
const getDefaultPortByType = (type: string): number => {
|
||||
switch (type) {
|
||||
case 'mysql':
|
||||
case 'mariadb':
|
||||
return 3306;
|
||||
case 'sphinx':
|
||||
return 9306;
|
||||
case 'postgres':
|
||||
case 'vastbase':
|
||||
return 5432;
|
||||
case 'redis':
|
||||
return 6379;
|
||||
case 'tdengine':
|
||||
return 6041;
|
||||
case 'oracle':
|
||||
return 1521;
|
||||
case 'dameng':
|
||||
return 5236;
|
||||
case 'kingbase':
|
||||
return 54321;
|
||||
case 'sqlserver':
|
||||
return 1433;
|
||||
case 'mongodb':
|
||||
return 27017;
|
||||
case 'highgo':
|
||||
return 5866;
|
||||
default:
|
||||
return 3306;
|
||||
}
|
||||
};
|
||||
|
||||
const toTrimmedString = (value: unknown, fallback = ''): string => {
|
||||
if (typeof value === 'string') {
|
||||
return value.trim();
|
||||
}
|
||||
if (typeof value === 'number' || typeof value === 'boolean') {
|
||||
return String(value).trim();
|
||||
}
|
||||
return fallback;
|
||||
};
|
||||
|
||||
const normalizePort = (value: unknown, fallbackPort: number): number => {
|
||||
const parsed = Number(value);
|
||||
if (!Number.isFinite(parsed)) return fallbackPort;
|
||||
const port = Math.trunc(parsed);
|
||||
if (port <= 0 || port > 65535) return fallbackPort;
|
||||
return port;
|
||||
};
|
||||
|
||||
const normalizeIntegerInRange = (value: unknown, fallbackValue: number, min: number, max: number): number => {
|
||||
const parsed = Number(value);
|
||||
if (!Number.isFinite(parsed)) return fallbackValue;
|
||||
const normalized = Math.trunc(parsed);
|
||||
if (normalized < min || normalized > max) return fallbackValue;
|
||||
return normalized;
|
||||
};
|
||||
|
||||
const isValidHostEntry = (entry: string): boolean => {
|
||||
if (!entry) return false;
|
||||
if (entry.length > MAX_HOST_ENTRY_LENGTH) return false;
|
||||
if (/[()\\/\s]/.test(entry)) return false;
|
||||
return true;
|
||||
};
|
||||
|
||||
const sanitizeStringArray = (value: unknown, maxLength = 256): string[] => {
|
||||
if (!Array.isArray(value)) return [];
|
||||
const seen = new Set<string>();
|
||||
const result: string[] = [];
|
||||
value.forEach((entry) => {
|
||||
const normalized = toTrimmedString(entry);
|
||||
if (!normalized || normalized.length > maxLength) return;
|
||||
if (seen.has(normalized)) return;
|
||||
seen.add(normalized);
|
||||
result.push(normalized);
|
||||
});
|
||||
return result;
|
||||
};
|
||||
|
||||
const sanitizeNumberArray = (value: unknown, min: number, max: number): number[] => {
|
||||
if (!Array.isArray(value)) return [];
|
||||
const seen = new Set<number>();
|
||||
const result: number[] = [];
|
||||
value.forEach((entry) => {
|
||||
const parsed = Number(entry);
|
||||
if (!Number.isFinite(parsed)) return;
|
||||
const num = Math.trunc(parsed);
|
||||
if (num < min || num > max) return;
|
||||
if (seen.has(num)) return;
|
||||
seen.add(num);
|
||||
result.push(num);
|
||||
});
|
||||
return result;
|
||||
};
|
||||
|
||||
const sanitizeAddressList = (value: unknown): string[] => {
|
||||
const all = sanitizeStringArray(value, MAX_HOST_ENTRY_LENGTH)
|
||||
.filter((entry) => isValidHostEntry(entry));
|
||||
return all.slice(0, MAX_HOST_ENTRIES);
|
||||
};
|
||||
|
||||
const normalizeConnectionType = (value: unknown): string => {
|
||||
const type = toTrimmedString(value).toLowerCase();
|
||||
return SUPPORTED_CONNECTION_TYPES.has(type) ? type : DEFAULT_CONNECTION_TYPE;
|
||||
};
|
||||
|
||||
const sanitizeConnectionConfig = (value: unknown): ConnectionConfig => {
|
||||
const raw = (value && typeof value === 'object') ? value as Record<string, unknown> : {};
|
||||
const type = normalizeConnectionType(raw.type);
|
||||
const defaultPort = getDefaultPortByType(type);
|
||||
const savePassword = typeof raw.savePassword === 'boolean' ? raw.savePassword : true;
|
||||
const mongoSrv = !!raw.mongoSrv;
|
||||
|
||||
const sshRaw = (raw.ssh && typeof raw.ssh === 'object') ? raw.ssh as Record<string, unknown> : {};
|
||||
const ssh = {
|
||||
host: toTrimmedString(sshRaw.host),
|
||||
port: normalizePort(sshRaw.port, 22),
|
||||
user: toTrimmedString(sshRaw.user),
|
||||
password: toTrimmedString(sshRaw.password),
|
||||
keyPath: toTrimmedString(sshRaw.keyPath),
|
||||
};
|
||||
|
||||
const safeConfig: ConnectionConfig & Record<string, unknown> = {
|
||||
...raw,
|
||||
type,
|
||||
host: toTrimmedString(raw.host, 'localhost') || 'localhost',
|
||||
port: normalizePort(raw.port, defaultPort),
|
||||
user: toTrimmedString(raw.user),
|
||||
password: savePassword ? toTrimmedString(raw.password) : '',
|
||||
savePassword,
|
||||
database: toTrimmedString(raw.database),
|
||||
useSSH: !!raw.useSSH,
|
||||
ssh,
|
||||
uri: toTrimmedString(raw.uri).slice(0, MAX_URI_LENGTH),
|
||||
hosts: sanitizeAddressList(raw.hosts),
|
||||
topology: raw.topology === 'replica' ? 'replica' : 'single',
|
||||
mysqlReplicaUser: toTrimmedString(raw.mysqlReplicaUser),
|
||||
mysqlReplicaPassword: savePassword ? toTrimmedString(raw.mysqlReplicaPassword) : '',
|
||||
replicaSet: toTrimmedString(raw.replicaSet),
|
||||
authSource: toTrimmedString(raw.authSource),
|
||||
readPreference: toTrimmedString(raw.readPreference),
|
||||
mongoSrv,
|
||||
mongoAuthMechanism: toTrimmedString(raw.mongoAuthMechanism),
|
||||
mongoReplicaUser: toTrimmedString(raw.mongoReplicaUser),
|
||||
mongoReplicaPassword: savePassword ? toTrimmedString(raw.mongoReplicaPassword) : '',
|
||||
timeout: normalizeIntegerInRange(raw.timeout, DEFAULT_TIMEOUT_SECONDS, 1, MAX_TIMEOUT_SECONDS),
|
||||
};
|
||||
|
||||
if (type === 'redis') {
|
||||
safeConfig.redisDB = normalizeIntegerInRange(raw.redisDB, 0, 0, 15);
|
||||
}
|
||||
|
||||
if (type === 'custom') {
|
||||
safeConfig.driver = toTrimmedString(raw.driver);
|
||||
safeConfig.dsn = toTrimmedString(raw.dsn).slice(0, MAX_URI_LENGTH);
|
||||
}
|
||||
|
||||
return safeConfig;
|
||||
};
|
||||
|
||||
const sanitizeSavedConnection = (value: unknown, index: number): SavedConnection | null => {
|
||||
if (!value || typeof value !== 'object') return null;
|
||||
const raw = value as Record<string, unknown>;
|
||||
const config = sanitizeConnectionConfig(raw.config);
|
||||
const id = toTrimmedString(raw.id, `conn-${index + 1}`) || `conn-${index + 1}`;
|
||||
const fallbackName = config.host ? `${config.type}-${config.host}` : `连接-${index + 1}`;
|
||||
const name = toTrimmedString(raw.name, fallbackName) || fallbackName;
|
||||
const includeDatabases = sanitizeStringArray(raw.includeDatabases, 256);
|
||||
const includeRedisDatabases = sanitizeNumberArray(raw.includeRedisDatabases, 0, 15);
|
||||
|
||||
return {
|
||||
id,
|
||||
name,
|
||||
config,
|
||||
includeDatabases: includeDatabases.length > 0 ? includeDatabases : undefined,
|
||||
includeRedisDatabases: includeRedisDatabases.length > 0 ? includeRedisDatabases : undefined,
|
||||
};
|
||||
};
|
||||
|
||||
const sanitizeConnections = (value: unknown): SavedConnection[] => {
|
||||
if (!Array.isArray(value)) return [];
|
||||
const result: SavedConnection[] = [];
|
||||
const idSet = new Set<string>();
|
||||
|
||||
value.forEach((entry, index) => {
|
||||
const conn = sanitizeSavedConnection(entry, index);
|
||||
if (!conn) return;
|
||||
let nextId = conn.id;
|
||||
if (idSet.has(nextId)) {
|
||||
nextId = `${nextId}-${index + 1}`;
|
||||
}
|
||||
idSet.add(nextId);
|
||||
result.push({ ...conn, id: nextId });
|
||||
});
|
||||
|
||||
return result;
|
||||
};
|
||||
|
||||
const isLegacyDefaultAppearance = (appearance: Partial<{ opacity: number; blur: number }> | undefined): boolean => {
|
||||
if (!appearance) {
|
||||
return true;
|
||||
}
|
||||
const opacity = typeof appearance.opacity === 'number' ? appearance.opacity : LEGACY_DEFAULT_OPACITY;
|
||||
const blur = typeof appearance.blur === 'number' ? appearance.blur : 0;
|
||||
return Math.abs(opacity - LEGACY_DEFAULT_OPACITY) < OPACITY_EPSILON && blur === 0;
|
||||
};
|
||||
|
||||
export interface SqlLog {
|
||||
id: string;
|
||||
@@ -24,11 +257,13 @@ interface AppState {
|
||||
sqlFormatOptions: { keywordCase: 'upper' | 'lower' };
|
||||
queryOptions: { maxRows: number };
|
||||
sqlLogs: SqlLog[];
|
||||
|
||||
tableAccessCount: Record<string, number>;
|
||||
tableSortPreference: Record<string, 'name' | 'frequency'>;
|
||||
|
||||
addConnection: (conn: SavedConnection) => void;
|
||||
updateConnection: (conn: SavedConnection) => void;
|
||||
removeConnection: (id: string) => void;
|
||||
|
||||
|
||||
addTab: (tab: TabData) => void;
|
||||
closeTab: (id: string) => void;
|
||||
closeOtherTabs: (id: string) => void;
|
||||
@@ -45,11 +280,90 @@ interface AppState {
|
||||
setAppearance: (appearance: Partial<{ opacity: number; blur: number }>) => void;
|
||||
setSqlFormatOptions: (options: { keywordCase: 'upper' | 'lower' }) => void;
|
||||
setQueryOptions: (options: Partial<{ maxRows: number }>) => void;
|
||||
|
||||
|
||||
addSqlLog: (log: SqlLog) => void;
|
||||
clearSqlLogs: () => void;
|
||||
|
||||
recordTableAccess: (connectionId: string, dbName: string, tableName: string) => void;
|
||||
setTableSortPreference: (connectionId: string, dbName: string, sortBy: 'name' | 'frequency') => void;
|
||||
}
|
||||
|
||||
const sanitizeSavedQueries = (value: unknown): SavedQuery[] => {
|
||||
if (!Array.isArray(value)) return [];
|
||||
const result: SavedQuery[] = [];
|
||||
value.forEach((entry, index) => {
|
||||
if (!entry || typeof entry !== 'object') return;
|
||||
const raw = entry as Record<string, unknown>;
|
||||
const id = toTrimmedString(raw.id, `query-${index + 1}`) || `query-${index + 1}`;
|
||||
const sql = toTrimmedString(raw.sql);
|
||||
const connectionId = toTrimmedString(raw.connectionId);
|
||||
const dbName = toTrimmedString(raw.dbName);
|
||||
if (!sql || !connectionId || !dbName) return;
|
||||
result.push({
|
||||
id,
|
||||
name: toTrimmedString(raw.name, `查询-${index + 1}`) || `查询-${index + 1}`,
|
||||
sql,
|
||||
connectionId,
|
||||
dbName,
|
||||
createdAt: Number.isFinite(Number(raw.createdAt)) ? Number(raw.createdAt) : Date.now(),
|
||||
});
|
||||
});
|
||||
return result;
|
||||
};
|
||||
|
||||
const sanitizeTheme = (value: unknown): 'light' | 'dark' => (value === 'dark' ? 'dark' : 'light');
|
||||
|
||||
const sanitizeSqlFormatOptions = (value: unknown): { keywordCase: 'upper' | 'lower' } => {
|
||||
const raw = (value && typeof value === 'object') ? value as Record<string, unknown> : {};
|
||||
return { keywordCase: raw.keywordCase === 'lower' ? 'lower' : 'upper' };
|
||||
};
|
||||
|
||||
const sanitizeQueryOptions = (value: unknown): { maxRows: number } => {
|
||||
const raw = (value && typeof value === 'object') ? value as Record<string, unknown> : {};
|
||||
const maxRows = Number(raw.maxRows);
|
||||
if (!Number.isFinite(maxRows) || maxRows <= 0) {
|
||||
return { maxRows: 5000 };
|
||||
}
|
||||
return { maxRows: Math.min(50000, Math.trunc(maxRows)) };
|
||||
};
|
||||
|
||||
const sanitizeTableAccessCount = (value: unknown): Record<string, number> => {
|
||||
const raw = (value && typeof value === 'object') ? value as Record<string, unknown> : {};
|
||||
const result: Record<string, number> = {};
|
||||
Object.entries(raw).forEach(([key, count]) => {
|
||||
const parsed = Number(count);
|
||||
if (!Number.isFinite(parsed) || parsed < 0) return;
|
||||
result[key] = Math.trunc(parsed);
|
||||
});
|
||||
return result;
|
||||
};
|
||||
|
||||
const sanitizeTableSortPreference = (value: unknown): Record<string, 'name' | 'frequency'> => {
|
||||
const raw = (value && typeof value === 'object') ? value as Record<string, unknown> : {};
|
||||
const result: Record<string, 'name' | 'frequency'> = {};
|
||||
Object.entries(raw).forEach(([key, preference]) => {
|
||||
result[key] = preference === 'frequency' ? 'frequency' : 'name';
|
||||
});
|
||||
return result;
|
||||
};
|
||||
|
||||
const sanitizeAppearance = (
|
||||
appearance: Partial<{ opacity: number; blur: number }> | undefined,
|
||||
version: number
|
||||
): { opacity: number; blur: number } => {
|
||||
if (!appearance || typeof appearance !== 'object') {
|
||||
return { ...DEFAULT_APPEARANCE };
|
||||
}
|
||||
const nextAppearance = {
|
||||
opacity: typeof appearance.opacity === 'number' ? appearance.opacity : DEFAULT_APPEARANCE.opacity,
|
||||
blur: typeof appearance.blur === 'number' ? appearance.blur : DEFAULT_APPEARANCE.blur,
|
||||
};
|
||||
if (version < 2 && isLegacyDefaultAppearance(appearance)) {
|
||||
return { ...DEFAULT_APPEARANCE };
|
||||
}
|
||||
return nextAppearance;
|
||||
};
|
||||
|
||||
export const useStore = create<AppState>()(
|
||||
persist(
|
||||
(set) => ({
|
||||
@@ -59,14 +373,16 @@ export const useStore = create<AppState>()(
|
||||
activeContext: null,
|
||||
savedQueries: [],
|
||||
theme: 'light',
|
||||
appearance: { opacity: 0.95, blur: 0 },
|
||||
appearance: { ...DEFAULT_APPEARANCE },
|
||||
sqlFormatOptions: { keywordCase: 'upper' },
|
||||
queryOptions: { maxRows: 5000 },
|
||||
sqlLogs: [],
|
||||
tableAccessCount: {},
|
||||
tableSortPreference: {},
|
||||
|
||||
addConnection: (conn) => set((state) => ({ connections: [...state.connections, conn] })),
|
||||
updateConnection: (conn) => set((state) => ({
|
||||
connections: state.connections.map(c => c.id === conn.id ? conn : c)
|
||||
updateConnection: (conn) => set((state) => ({
|
||||
connections: state.connections.map(c => c.id === conn.id ? conn : c)
|
||||
})),
|
||||
removeConnection: (id) => set((state) => ({ connections: state.connections.filter(c => c.id !== id) })),
|
||||
|
||||
@@ -132,13 +448,77 @@ export const useStore = create<AppState>()(
|
||||
setAppearance: (appearance) => set((state) => ({ appearance: { ...state.appearance, ...appearance } })),
|
||||
setSqlFormatOptions: (options) => set({ sqlFormatOptions: options }),
|
||||
setQueryOptions: (options) => set((state) => ({ queryOptions: { ...state.queryOptions, ...options } })),
|
||||
|
||||
|
||||
addSqlLog: (log) => set((state) => ({ sqlLogs: [log, ...state.sqlLogs].slice(0, 1000) })), // Keep last 1000 logs
|
||||
clearSqlLogs: () => set({ sqlLogs: [] }),
|
||||
|
||||
recordTableAccess: (connectionId, dbName, tableName) => set((state) => {
|
||||
const key = `${connectionId}-${dbName}-${tableName}`;
|
||||
const currentCount = state.tableAccessCount[key] || 0;
|
||||
return {
|
||||
tableAccessCount: {
|
||||
...state.tableAccessCount,
|
||||
[key]: currentCount + 1
|
||||
}
|
||||
};
|
||||
}),
|
||||
|
||||
setTableSortPreference: (connectionId, dbName, sortBy) => set((state) => {
|
||||
const key = `${connectionId}-${dbName}`;
|
||||
return {
|
||||
tableSortPreference: {
|
||||
...state.tableSortPreference,
|
||||
[key]: sortBy
|
||||
}
|
||||
};
|
||||
}),
|
||||
}),
|
||||
{
|
||||
name: 'lite-db-storage', // name of the item in the storage (must be unique)
|
||||
partialize: (state) => ({ connections: state.connections, savedQueries: state.savedQueries, theme: state.theme, appearance: state.appearance, sqlFormatOptions: state.sqlFormatOptions, queryOptions: state.queryOptions }), // Don't persist logs
|
||||
version: 3,
|
||||
migrate: (persistedState: unknown, version: number) => {
|
||||
if (!persistedState || typeof persistedState !== 'object') {
|
||||
return persistedState as AppState;
|
||||
}
|
||||
const state = persistedState as Partial<AppState>;
|
||||
const nextState: Partial<AppState> = { ...state };
|
||||
nextState.connections = sanitizeConnections(state.connections);
|
||||
nextState.savedQueries = sanitizeSavedQueries(state.savedQueries);
|
||||
nextState.theme = sanitizeTheme(state.theme);
|
||||
nextState.appearance = sanitizeAppearance(state.appearance, version);
|
||||
nextState.sqlFormatOptions = sanitizeSqlFormatOptions(state.sqlFormatOptions);
|
||||
nextState.queryOptions = sanitizeQueryOptions(state.queryOptions);
|
||||
nextState.tableAccessCount = sanitizeTableAccessCount(state.tableAccessCount);
|
||||
nextState.tableSortPreference = sanitizeTableSortPreference(state.tableSortPreference);
|
||||
return nextState as AppState;
|
||||
},
|
||||
merge: (persistedState, currentState) => {
|
||||
const state = (persistedState && typeof persistedState === 'object')
|
||||
? persistedState as Partial<AppState>
|
||||
: {};
|
||||
return {
|
||||
...currentState,
|
||||
...state,
|
||||
connections: sanitizeConnections(state.connections),
|
||||
savedQueries: sanitizeSavedQueries(state.savedQueries),
|
||||
theme: sanitizeTheme(state.theme),
|
||||
appearance: sanitizeAppearance(state.appearance, 3),
|
||||
sqlFormatOptions: sanitizeSqlFormatOptions(state.sqlFormatOptions),
|
||||
queryOptions: sanitizeQueryOptions(state.queryOptions),
|
||||
tableAccessCount: sanitizeTableAccessCount(state.tableAccessCount),
|
||||
tableSortPreference: sanitizeTableSortPreference(state.tableSortPreference),
|
||||
};
|
||||
},
|
||||
partialize: (state) => ({
|
||||
connections: state.connections,
|
||||
savedQueries: state.savedQueries,
|
||||
theme: state.theme,
|
||||
appearance: state.appearance,
|
||||
sqlFormatOptions: state.sqlFormatOptions,
|
||||
queryOptions: state.queryOptions,
|
||||
tableAccessCount: state.tableAccessCount,
|
||||
tableSortPreference: state.tableSortPreference
|
||||
}), // Don't persist logs
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
@@ -12,10 +12,35 @@ export interface ConnectionConfig {
|
||||
port: number;
|
||||
user: string;
|
||||
password?: string;
|
||||
savePassword?: boolean;
|
||||
database?: string;
|
||||
useSSH?: boolean;
|
||||
ssh?: SSHConfig;
|
||||
driver?: string;
|
||||
dsn?: string;
|
||||
timeout?: number;
|
||||
redisDB?: number; // Redis database index (0-15)
|
||||
uri?: string; // Connection URI for copy/paste
|
||||
hosts?: string[]; // Multi-host addresses: host:port
|
||||
topology?: 'single' | 'replica';
|
||||
mysqlReplicaUser?: string;
|
||||
mysqlReplicaPassword?: string;
|
||||
replicaSet?: string;
|
||||
authSource?: string;
|
||||
readPreference?: string;
|
||||
mongoSrv?: boolean;
|
||||
mongoAuthMechanism?: string;
|
||||
mongoReplicaUser?: string;
|
||||
mongoReplicaPassword?: string;
|
||||
}
|
||||
|
||||
export interface MongoMemberInfo {
|
||||
host: string;
|
||||
role: string;
|
||||
state: string;
|
||||
stateCode?: number;
|
||||
healthy: boolean;
|
||||
isSelf?: boolean;
|
||||
}
|
||||
|
||||
export interface SavedConnection {
|
||||
@@ -62,7 +87,7 @@ export interface TriggerDefinition {
|
||||
export interface TabData {
|
||||
id: string;
|
||||
title: string;
|
||||
type: 'query' | 'table' | 'design' | 'redis-keys' | 'redis-command';
|
||||
type: 'query' | 'table' | 'design' | 'redis-keys' | 'redis-command' | 'trigger' | 'view-def' | 'routine-def';
|
||||
connectionId: string;
|
||||
dbName?: string;
|
||||
tableName?: string;
|
||||
@@ -70,6 +95,10 @@ export interface TabData {
|
||||
initialTab?: string;
|
||||
readOnly?: boolean;
|
||||
redisDB?: number; // Redis database index for redis tabs
|
||||
triggerName?: string; // Trigger name for trigger tabs
|
||||
viewName?: string; // View name for view definition tabs
|
||||
routineName?: string; // Routine name for function/procedure definition tabs
|
||||
routineType?: string; // 'FUNCTION' or 'PROCEDURE'
|
||||
}
|
||||
|
||||
export interface DatabaseNode {
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
const DEFAULT_OPACITY = 0.95;
|
||||
const DEFAULT_OPACITY = 1.0;
|
||||
const MIN_OPACITY = 0.1;
|
||||
const MAX_OPACITY = 1.0;
|
||||
|
||||
// macOS 端进一步增强通透感:同滑块值下更低等效不透明度、降低过重模糊。
|
||||
const MAC_OPACITY_FACTOR = 0.20;
|
||||
// 平台透明度映射因子:值越大,滑块变化越平滑(1.0 = 线性映射)
|
||||
const MAC_OPACITY_FACTOR = 0.60;
|
||||
const MAC_BLUR_FACTOR = 1.00;
|
||||
const WINDOWS_OPACITY_FACTOR = 0.20;
|
||||
const WINDOWS_OPACITY_FACTOR = 0.70;
|
||||
const WINDOWS_BLUR_FACTOR = 1.00;
|
||||
|
||||
const clamp = (value: number, min: number, max: number) => Math.min(max, Math.max(min, value));
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
export type FilterCondition = {
|
||||
id?: number;
|
||||
enabled?: boolean;
|
||||
column?: string;
|
||||
op?: string;
|
||||
value?: string;
|
||||
@@ -23,6 +24,8 @@ const needsQuote = (ident: string): boolean => {
|
||||
if (!ident) return false;
|
||||
// 如果包含特殊字符(非字母、数字、下划线)则需要引号
|
||||
if (!/^[a-zA-Z_][a-zA-Z0-9_]*$/.test(ident)) return true;
|
||||
// PostgreSQL 会将未加引号的标识符折叠为小写,含大写字母时必须加引号
|
||||
if (/[A-Z]/.test(ident)) return true;
|
||||
// 常见 SQL 保留字列表(简化版)
|
||||
const reserved = ['select', 'from', 'where', 'table', 'index', 'user', 'order', 'group', 'by', 'limit', 'offset', 'and', 'or', 'not', 'null', 'true', 'false', 'key', 'primary', 'foreign', 'references', 'default', 'constraint', 'create', 'drop', 'alter', 'insert', 'update', 'delete', 'set', 'values', 'into', 'join', 'left', 'right', 'inner', 'outer', 'on', 'as', 'is', 'in', 'like', 'between', 'case', 'when', 'then', 'else', 'end', 'having', 'distinct', 'all', 'any', 'exists', 'union', 'except', 'intersect'];
|
||||
return reserved.includes(ident.toLowerCase());
|
||||
@@ -33,7 +36,7 @@ export const quoteIdentPart = (dbType: string, ident: string) => {
|
||||
if (!raw) return raw;
|
||||
const dbTypeLower = (dbType || '').toLowerCase();
|
||||
|
||||
if (dbTypeLower === 'mysql') {
|
||||
if (dbTypeLower === 'mysql' || dbTypeLower === 'mariadb' || dbTypeLower === 'sphinx' || dbTypeLower === 'tdengine') {
|
||||
return `\`${raw.replace(/`/g, '``')}\``;
|
||||
}
|
||||
|
||||
@@ -60,6 +63,41 @@ export const quoteQualifiedIdent = (dbType: string, ident: string) => {
|
||||
|
||||
export const escapeLiteral = (val: string) => (val || '').replace(/'/g, "''");
|
||||
|
||||
type SortInfo = {
|
||||
columnKey?: string;
|
||||
order?: string;
|
||||
} | null | undefined;
|
||||
|
||||
export const buildOrderBySQL = (
|
||||
dbType: string,
|
||||
sortInfo: SortInfo,
|
||||
fallbackColumns: string[] = [],
|
||||
) => {
|
||||
const sortColumn = normalizeIdentPart(String(sortInfo?.columnKey || ''));
|
||||
const sortOrder = String(sortInfo?.order || '');
|
||||
const direction = sortOrder === 'ascend' ? 'ASC' : sortOrder === 'descend' ? 'DESC' : '';
|
||||
if (sortColumn && direction) {
|
||||
return ` ORDER BY ${quoteIdentPart(dbType, sortColumn)} ${direction}`;
|
||||
}
|
||||
|
||||
const seen = new Set<string>();
|
||||
const stableColumns = (fallbackColumns || [])
|
||||
.map((col) => normalizeIdentPart(String(col || '')))
|
||||
.filter((col) => {
|
||||
if (!col) return false;
|
||||
const key = col.toLowerCase();
|
||||
if (seen.has(key)) return false;
|
||||
seen.add(key);
|
||||
return true;
|
||||
});
|
||||
if (stableColumns.length > 0) {
|
||||
const parts = stableColumns.map((col) => `${quoteIdentPart(dbType, col)} ASC`);
|
||||
return ` ORDER BY ${parts.join(', ')}`;
|
||||
}
|
||||
|
||||
return '';
|
||||
};
|
||||
|
||||
export const parseListValues = (val: string) => {
|
||||
const raw = (val || '').trim();
|
||||
if (!raw) return [];
|
||||
@@ -73,6 +111,8 @@ export const buildWhereSQL = (dbType: string, conditions: FilterCondition[]) =>
|
||||
const whereParts: string[] = [];
|
||||
|
||||
(conditions || []).forEach((cond) => {
|
||||
if (cond?.enabled === false) return;
|
||||
|
||||
const op = (cond?.op || '').trim();
|
||||
const column = (cond?.column || '').trim();
|
||||
const value = (cond?.value ?? '').toString();
|
||||
@@ -195,4 +235,3 @@ export const buildWhereSQL = (dbType: string, conditions: FilterCondition[]) =>
|
||||
|
||||
return whereParts.length > 0 ? `WHERE ${whereParts.join(' AND ')}` : '';
|
||||
};
|
||||
|
||||
|
||||
24
frontend/wailsjs/go/app/App.d.ts
vendored
24
frontend/wailsjs/go/app/App.d.ts
vendored
@@ -38,6 +38,14 @@ export function DataSyncPreview(arg1:sync.SyncConfig,arg2:string,arg3:number):Pr
|
||||
|
||||
export function DownloadUpdate():Promise<connection.QueryResult>;
|
||||
|
||||
export function DropDatabase(arg1:connection.ConnectionConfig,arg2:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function DropFunction(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function DropTable(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function DropView(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function ExportData(arg1:Array<Record<string, any>>,arg2:Array<string>,arg3:string,arg4:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function ExportDatabaseSQL(arg1:connection.ConnectionConfig,arg2:string,arg3:boolean):Promise<connection.QueryResult>;
|
||||
@@ -46,6 +54,8 @@ export function ExportQuery(arg1:connection.ConnectionConfig,arg2:string,arg3:st
|
||||
|
||||
export function ExportTable(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function ExportTablesDataSQL(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>):Promise<connection.QueryResult>;
|
||||
|
||||
export function ExportTablesSQL(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>,arg4:boolean):Promise<connection.QueryResult>;
|
||||
|
||||
export function GetAppInfo():Promise<connection.QueryResult>;
|
||||
@@ -54,8 +64,12 @@ export function ImportConfigFile():Promise<connection.QueryResult>;
|
||||
|
||||
export function ImportData(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function ImportDataWithProgress(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function InstallUpdateAndRestart():Promise<connection.QueryResult>;
|
||||
|
||||
export function MongoDiscoverMembers(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;
|
||||
|
||||
export function MySQLConnect(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;
|
||||
|
||||
export function MySQLGetDatabases(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;
|
||||
@@ -68,6 +82,8 @@ export function MySQLShowCreateTable(arg1:connection.ConnectionConfig,arg2:strin
|
||||
|
||||
export function OpenSQLFile():Promise<connection.QueryResult>;
|
||||
|
||||
export function PreviewImportFile(arg1:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisConnect(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;
|
||||
|
||||
export function RedisDeleteHashField(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>):Promise<connection.QueryResult>;
|
||||
@@ -110,4 +126,12 @@ export function RedisZSetAdd(arg1:connection.ConnectionConfig,arg2:string,arg3:A
|
||||
|
||||
export function RedisZSetRemove(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>):Promise<connection.QueryResult>;
|
||||
|
||||
export function RenameDatabase(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function RenameTable(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function RenameView(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function SetWindowTranslucency(arg1:number,arg2:number):Promise<void>;
|
||||
|
||||
export function TestConnection(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;
|
||||
|
||||
@@ -70,6 +70,22 @@ export function DownloadUpdate() {
|
||||
return window['go']['app']['App']['DownloadUpdate']();
|
||||
}
|
||||
|
||||
export function DropDatabase(arg1, arg2) {
|
||||
return window['go']['app']['App']['DropDatabase'](arg1, arg2);
|
||||
}
|
||||
|
||||
export function DropFunction(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['DropFunction'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function DropTable(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['DropTable'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function DropView(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['DropView'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function ExportData(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['ExportData'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
@@ -86,6 +102,10 @@ export function ExportTable(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['ExportTable'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function ExportTablesDataSQL(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['ExportTablesDataSQL'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function ExportTablesSQL(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['ExportTablesSQL'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
@@ -102,10 +122,18 @@ export function ImportData(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['ImportData'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function ImportDataWithProgress(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['ImportDataWithProgress'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function InstallUpdateAndRestart() {
|
||||
return window['go']['app']['App']['InstallUpdateAndRestart']();
|
||||
}
|
||||
|
||||
export function MongoDiscoverMembers(arg1) {
|
||||
return window['go']['app']['App']['MongoDiscoverMembers'](arg1);
|
||||
}
|
||||
|
||||
export function MySQLConnect(arg1) {
|
||||
return window['go']['app']['App']['MySQLConnect'](arg1);
|
||||
}
|
||||
@@ -130,6 +158,10 @@ export function OpenSQLFile() {
|
||||
return window['go']['app']['App']['OpenSQLFile']();
|
||||
}
|
||||
|
||||
export function PreviewImportFile(arg1) {
|
||||
return window['go']['app']['App']['PreviewImportFile'](arg1);
|
||||
}
|
||||
|
||||
export function RedisConnect(arg1) {
|
||||
return window['go']['app']['App']['RedisConnect'](arg1);
|
||||
}
|
||||
@@ -214,6 +246,22 @@ export function RedisZSetRemove(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['RedisZSetRemove'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function RenameDatabase(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['RenameDatabase'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function RenameTable(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['RenameTable'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function RenameView(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['RenameView'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function SetWindowTranslucency(arg1, arg2) {
|
||||
return window['go']['app']['App']['SetWindowTranslucency'](arg1, arg2);
|
||||
}
|
||||
|
||||
export function TestConnection(arg1) {
|
||||
return window['go']['app']['App']['TestConnection'](arg1);
|
||||
}
|
||||
|
||||
@@ -74,6 +74,7 @@ export namespace connection {
|
||||
port: number;
|
||||
user: string;
|
||||
password: string;
|
||||
savePassword?: boolean;
|
||||
database: string;
|
||||
useSSH: boolean;
|
||||
ssh: SSHConfig;
|
||||
@@ -81,6 +82,18 @@ export namespace connection {
|
||||
dsn?: string;
|
||||
timeout?: number;
|
||||
redisDB?: number;
|
||||
uri?: string;
|
||||
hosts?: string[];
|
||||
topology?: string;
|
||||
mysqlReplicaUser?: string;
|
||||
mysqlReplicaPassword?: string;
|
||||
replicaSet?: string;
|
||||
authSource?: string;
|
||||
readPreference?: string;
|
||||
mongoSrv?: boolean;
|
||||
mongoAuthMechanism?: string;
|
||||
mongoReplicaUser?: string;
|
||||
mongoReplicaPassword?: string;
|
||||
|
||||
static createFrom(source: any = {}) {
|
||||
return new ConnectionConfig(source);
|
||||
@@ -93,6 +106,7 @@ export namespace connection {
|
||||
this.port = source["port"];
|
||||
this.user = source["user"];
|
||||
this.password = source["password"];
|
||||
this.savePassword = source["savePassword"];
|
||||
this.database = source["database"];
|
||||
this.useSSH = source["useSSH"];
|
||||
this.ssh = this.convertValues(source["ssh"], SSHConfig);
|
||||
@@ -100,6 +114,18 @@ export namespace connection {
|
||||
this.dsn = source["dsn"];
|
||||
this.timeout = source["timeout"];
|
||||
this.redisDB = source["redisDB"];
|
||||
this.uri = source["uri"];
|
||||
this.hosts = source["hosts"];
|
||||
this.topology = source["topology"];
|
||||
this.mysqlReplicaUser = source["mysqlReplicaUser"];
|
||||
this.mysqlReplicaPassword = source["mysqlReplicaPassword"];
|
||||
this.replicaSet = source["replicaSet"];
|
||||
this.authSource = source["authSource"];
|
||||
this.readPreference = source["readPreference"];
|
||||
this.mongoSrv = source["mongoSrv"];
|
||||
this.mongoAuthMechanism = source["mongoAuthMechanism"];
|
||||
this.mongoReplicaUser = source["mongoReplicaUser"];
|
||||
this.mongoReplicaPassword = source["mongoReplicaPassword"];
|
||||
}
|
||||
|
||||
convertValues(a: any, classs: any, asMap: boolean = false): any {
|
||||
|
||||
27
go.mod
27
go.mod
@@ -6,11 +6,17 @@ require (
|
||||
gitea.com/kingbase/gokb v0.0.0-20201021123113-29bd62a876c3
|
||||
gitee.com/chunanyong/dm v1.8.22
|
||||
github.com/go-sql-driver/mysql v1.9.3
|
||||
github.com/highgo/pq-sm3 v0.0.0
|
||||
github.com/lib/pq v1.11.1
|
||||
github.com/microsoft/go-mssqldb v1.9.6
|
||||
github.com/redis/go-redis/v9 v9.17.3
|
||||
github.com/sijms/go-ora/v2 v2.9.0
|
||||
github.com/taosdata/driver-go/v3 v3.7.8
|
||||
github.com/wailsapp/wails/v2 v2.11.0
|
||||
github.com/xuri/excelize/v2 v2.10.0
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0
|
||||
golang.org/x/crypto v0.47.0
|
||||
golang.org/x/text v0.33.0
|
||||
modernc.org/sqlite v1.44.3
|
||||
)
|
||||
|
||||
@@ -22,10 +28,15 @@ require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/go-ole/go-ole v1.3.0 // indirect
|
||||
github.com/godbus/dbus/v5 v5.1.0 // indirect
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/golang/snappy v0.0.4 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/hashicorp/go-version v1.7.0 // indirect
|
||||
github.com/jchv/go-winloader v0.0.0-20210711035445-715c2860da7e // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/compress v1.17.6 // indirect
|
||||
github.com/labstack/echo/v4 v4.13.3 // indirect
|
||||
github.com/labstack/gommon v0.4.2 // indirect
|
||||
github.com/leaanthony/go-ansi-parser v1.6.1 // indirect
|
||||
@@ -34,22 +45,36 @@ require (
|
||||
github.com/leaanthony/u v1.1.1 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/richardlehane/mscfb v1.0.4 // indirect
|
||||
github.com/richardlehane/msoleps v1.0.4 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/samber/lo v1.49.1 // indirect
|
||||
github.com/shopspring/decimal v1.4.0 // indirect
|
||||
github.com/tiendc/go-deepcopy v1.7.1 // indirect
|
||||
github.com/tkrajina/go-reflector v0.5.8 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasttemplate v1.2.2 // indirect
|
||||
github.com/wailsapp/go-webview2 v1.0.22 // indirect
|
||||
github.com/wailsapp/mimetype v1.4.1 // indirect
|
||||
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
|
||||
github.com/xdg-go/scram v1.2.0 // indirect
|
||||
github.com/xdg-go/stringprep v1.0.4 // indirect
|
||||
github.com/xuri/efp v0.0.1 // indirect
|
||||
github.com/xuri/nfp v0.0.2-0.20250530014748-2ddeb826f9a9 // indirect
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/net v0.48.0 // indirect
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/sys v0.40.0 // indirect
|
||||
golang.org/x/text v0.33.0 // indirect
|
||||
modernc.org/libc v1.67.6 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
)
|
||||
|
||||
replace github.com/highgo/pq-sm3 => ./third_party/highgo-pq
|
||||
|
||||
101
go.sum
101
go.sum
@@ -4,6 +4,18 @@ gitea.com/kingbase/gokb v0.0.0-20201021123113-29bd62a876c3 h1:QjslQNaH5Nuap5i4ni
|
||||
gitea.com/kingbase/gokb v0.0.0-20201021123113-29bd62a876c3/go.mod h1:7lH5A1jzCXD9Nl16DzaBUOfDAT8NPrDmZwKu1p5wf94=
|
||||
gitee.com/chunanyong/dm v1.8.22 h1:H7fsrnUIvEA0jlDWew7vwELry1ff+tLMIu2Fk2cIBSg=
|
||||
gitee.com/chunanyong/dm v1.8.22/go.mod h1:EPRJnuPFgbyOFgJ0TRYCTGzhq+ZT4wdyaj/GW/LLcNg=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 h1:Gt0j3wceWMwPmiazCa8MzMA0MfhmPIz0Qp0FJ6qcM0U=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0/go.mod h1:Ot/6aikWnKWi4l9QB7qVSwa8iMphQNqkWALMoNT3rzM=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 h1:FPKJS1T+clwv+OLGt13a8UjqeRuh0O4SJ3lUriThc+4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1/go.mod h1:j2chePtV91HrC22tGoRX3sGY42uF13WzmmV80/OdVAA=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1 h1:Wgf5rZba3YZqeTNJPtvqZoBu1sBN/L4sry+u2U3Y75w=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1/go.mod h1:xxCBG/f/4Vbmh2XQJBsOmNdxWUY5j/s27jujKPbQf14=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1 h1:bFWuoEKg+gImo7pvkiQEFAc8ocibADgXeiLAxWhWmkI=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1/go.mod h1:Vih/3yc6yac2JzU4hzpaDupBJP0Flaia9rXXrU8xyww=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
|
||||
github.com/bep/debounce v1.2.1 h1:v67fRdBA9UQu2NhLFXrSg0Brw7CexQekrBwDMM8bzeY=
|
||||
github.com/bep/debounce v1.2.1/go.mod h1:H8yggRPQKLUhUoqrJC1bO2xNya7vanpDl7xR3ISbCJ0=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
@@ -12,6 +24,7 @@ github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
@@ -24,19 +37,37 @@ github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1
|
||||
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
|
||||
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
|
||||
github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
|
||||
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
|
||||
github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
|
||||
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
|
||||
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY=
|
||||
github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/jchv/go-winloader v0.0.0-20210711035445-715c2860da7e h1:Q3+PugElBCf4PFpxhErSzU3/PY5sFL5Z6rfv4AbGAck=
|
||||
github.com/jchv/go-winloader v0.0.0-20210711035445-715c2860da7e/go.mod h1:alcuEEnZsY1WQsagKhZDsoPCRoOijYqhZvPwLG0kzVs=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI=
|
||||
github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/labstack/echo/v4 v4.13.3 h1:pwhpCPrTl5qry5HRdM5FwdXnhXSLSY+WE+YQSeCaafY=
|
||||
github.com/labstack/echo/v4 v4.13.3/go.mod h1:o90YNEeQWjDozo584l7AwhJMHN0bOC4tAfg+Xox9q5g=
|
||||
github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0=
|
||||
@@ -61,6 +92,12 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/microsoft/go-mssqldb v1.9.6 h1:1MNQg5UiSsokiPz3++K2KPx4moKrwIqly1wv+RyCKTw=
|
||||
github.com/microsoft/go-mssqldb v1.9.6/go.mod h1:yYMPDufyoF2vVuVCUGtZARr06DKFIhMrluTcgWlXpr4=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
@@ -73,15 +110,34 @@ github.com/redis/go-redis/v9 v9.17.3 h1:fN29NdNrE17KttK5Ndf20buqfDZwGNgoUr9qjl1D
|
||||
github.com/redis/go-redis/v9 v9.17.3/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/richardlehane/mscfb v1.0.4 h1:WULscsljNPConisD5hR0+OyZjwK46Pfyr6mPu5ZawpM=
|
||||
github.com/richardlehane/mscfb v1.0.4/go.mod h1:YzVpcZg9czvAuhk9T+a3avCpcFPMUWm7gK3DypaEsUk=
|
||||
github.com/richardlehane/msoleps v1.0.1/go.mod h1:BWev5JBpU9Ko2WAgmZEuiz4/u3ZYTKbjLycmwiWUfWg=
|
||||
github.com/richardlehane/msoleps v1.0.4 h1:WuESlvhX3gH2IHcd8UqyCuFY5yiq/GR/yqaSM/9/g00=
|
||||
github.com/richardlehane/msoleps v1.0.4/go.mod h1:BWev5JBpU9Ko2WAgmZEuiz4/u3ZYTKbjLycmwiWUfWg=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/samber/lo v1.49.1 h1:4BIFyVfuQSEpluc7Fua+j1NolZHiEHEpaSEKdsH0tew=
|
||||
github.com/samber/lo v1.49.1/go.mod h1:dO6KHFzUKXgP8LDhU0oI8d2hekjXnGOu0DB8Jecxd6o=
|
||||
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
||||
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
|
||||
github.com/sijms/go-ora/v2 v2.9.0 h1:+iQbUeTeCOFMb5BsOMgUhV8KWyrv9yjKpcK4x7+MFrg=
|
||||
github.com/sijms/go-ora/v2 v2.9.0/go.mod h1:QgFInVi3ZWyqAiJwzBQA+nbKYKH77tdp1PYoCqhR2dU=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/taosdata/driver-go/v3 v3.7.8 h1:N2H6HLLZH2ve2ipcoFgG9BJS+yW0XksqNYwEdSmHaJk=
|
||||
github.com/taosdata/driver-go/v3 v3.7.8/go.mod h1:gSxBEPOueMg0rTmMO1Ug6aeD7AwGdDGvUtLrsDTTpYc=
|
||||
github.com/tiendc/go-deepcopy v1.7.1 h1:LnubftI6nYaaMOcaz0LphzwraqN8jiWTwm416sitff4=
|
||||
github.com/tiendc/go-deepcopy v1.7.1/go.mod h1:4bKjNC2r7boYOkD2IOuZpYjmlDdzjbpTRyCx+goBCJQ=
|
||||
github.com/tkrajina/go-reflector v0.5.8 h1:yPADHrwmUbMq4RGEyaOUpz2H90sRsETNVpjzo3DLVQQ=
|
||||
github.com/tkrajina/go-reflector v0.5.8/go.mod h1:ECbqLgccecY5kPmPmXg1MrHW585yMcDkVl6IvJe64T4=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
@@ -94,35 +150,76 @@ github.com/wailsapp/mimetype v1.4.1 h1:pQN9ycO7uo4vsUUuPeHEYoUkLVkaRntMnHJxVwYhw
|
||||
github.com/wailsapp/mimetype v1.4.1/go.mod h1:9aV5k31bBOv5z6u+QP8TltzvNGJPmNJD4XlAL3U+j3o=
|
||||
github.com/wailsapp/wails/v2 v2.11.0 h1:seLacV8pqupq32IjS4Y7V8ucab0WZwtK6VvUVxSBtqQ=
|
||||
github.com/wailsapp/wails/v2 v2.11.0/go.mod h1:jrf0ZaM6+GBc1wRmXsM8cIvzlg0karYin3erahI4+0k=
|
||||
github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
|
||||
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
|
||||
github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs=
|
||||
github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8=
|
||||
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
|
||||
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
|
||||
github.com/xuri/efp v0.0.1 h1:fws5Rv3myXyYni8uwj2qKjVaRP30PdjeYe2Y6FDsCL8=
|
||||
github.com/xuri/efp v0.0.1/go.mod h1:ybY/Jr0T0GTCnYjKqmdwxyxn2BQf2RcQIIvex5QldPI=
|
||||
github.com/xuri/excelize/v2 v2.10.0 h1:8aKsP7JD39iKLc6dH5Tw3dgV3sPRh8uRVXu/fMstfW4=
|
||||
github.com/xuri/excelize/v2 v2.10.0/go.mod h1:SC5TzhQkaOsTWpANfm+7bJCldzcnU/jrhqkTi/iBHBU=
|
||||
github.com/xuri/nfp v0.0.2-0.20250530014748-2ddeb826f9a9 h1:+C0TIdyyYmzadGaL/HBLbf3WdLgC29pgyhTjAT/0nuE=
|
||||
github.com/xuri/nfp v0.0.2-0.20250530014748-2ddeb826f9a9/go.mod h1:WwHg+CVyzlv/TX9xqBFXEZAuxOPxn2k1GNHwG41IIUQ=
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
|
||||
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||
golang.org/x/image v0.25.0 h1:Y6uW6rH1y5y/LK1J8BPWZtr6yZ7hrsy6hFrXjgsc2fQ=
|
||||
golang.org/x/image v0.25.0/go.mod h1:tCAmOEGthTtkalusGp1g3xa2gke8J6c2N565dTyl9Rs=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
|
||||
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210505024714-0287a6fb4125/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20200810151505-1b9f1253b3ed/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
|
||||
golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
||||
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
|
||||
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
||||
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
||||
|
||||
@@ -49,6 +49,13 @@ func (a *App) Startup(ctx context.Context) {
|
||||
logger.Infof("应用启动完成")
|
||||
}
|
||||
|
||||
// SetWindowTranslucency 动态调整 macOS 窗口透明度。
|
||||
// 前端在加载用户外观设置后、以及用户修改外观时调用此方法。
|
||||
// opacity=1.0 且 blur=0 时窗口标记为 opaque,GPU 不再持续计算窗口背后的模糊合成。
|
||||
func (a *App) SetWindowTranslucency(opacity float64, blur float64) {
|
||||
setMacWindowTranslucency(opacity, blur)
|
||||
}
|
||||
|
||||
// Shutdown is called when the app terminates
|
||||
func (a *App) Shutdown(ctx context.Context) {
|
||||
logger.Infof("应用开始关闭,准备释放资源")
|
||||
@@ -103,10 +110,11 @@ type withLogHint struct {
|
||||
}
|
||||
|
||||
func (e withLogHint) Error() string {
|
||||
message := normalizeErrorMessage(e.err)
|
||||
if strings.TrimSpace(e.logPath) == "" {
|
||||
return e.err.Error()
|
||||
return message
|
||||
}
|
||||
return fmt.Sprintf("%s(详细日志:%s)", e.err.Error(), e.logPath)
|
||||
return fmt.Sprintf("%s(详细日志:%s)", message, e.logPath)
|
||||
}
|
||||
|
||||
func (e withLogHint) Unwrap() error {
|
||||
@@ -128,6 +136,33 @@ func formatConnSummary(config connection.ConnectionConfig) string {
|
||||
b.WriteString(fmt.Sprintf("类型=%s 地址=%s:%d 数据库=%s 用户=%s 超时=%ds",
|
||||
config.Type, config.Host, config.Port, dbName, config.User, timeoutSeconds))
|
||||
|
||||
if len(config.Hosts) > 0 {
|
||||
b.WriteString(fmt.Sprintf(" 节点数=%d", len(config.Hosts)))
|
||||
}
|
||||
if strings.TrimSpace(config.Topology) != "" {
|
||||
b.WriteString(fmt.Sprintf(" 拓扑=%s", strings.TrimSpace(config.Topology)))
|
||||
}
|
||||
if strings.TrimSpace(config.URI) != "" {
|
||||
b.WriteString(fmt.Sprintf(" URI=已配置(长度=%d)", len(config.URI)))
|
||||
}
|
||||
if strings.TrimSpace(config.MySQLReplicaUser) != "" {
|
||||
b.WriteString(" MySQL从库凭据=已配置")
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(config.Type), "mongodb") {
|
||||
if strings.TrimSpace(config.MongoReplicaUser) != "" {
|
||||
b.WriteString(" Mongo从库凭据=已配置")
|
||||
}
|
||||
if strings.TrimSpace(config.ReplicaSet) != "" {
|
||||
b.WriteString(fmt.Sprintf(" 副本集=%s", strings.TrimSpace(config.ReplicaSet)))
|
||||
}
|
||||
if strings.TrimSpace(config.ReadPreference) != "" {
|
||||
b.WriteString(fmt.Sprintf(" 读偏好=%s", strings.TrimSpace(config.ReadPreference)))
|
||||
}
|
||||
if strings.TrimSpace(config.AuthSource) != "" {
|
||||
b.WriteString(fmt.Sprintf(" 认证库=%s", strings.TrimSpace(config.AuthSource)))
|
||||
}
|
||||
}
|
||||
|
||||
if config.UseSSH {
|
||||
b.WriteString(fmt.Sprintf(" SSH=%s:%d 用户=%s", config.SSH.Host, config.SSH.Port, config.SSH.User))
|
||||
}
|
||||
|
||||
@@ -14,8 +14,8 @@ func normalizeRunConfig(config connection.ConnectionConfig, dbName string) conne
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(config.Type)) {
|
||||
case "mysql", "postgres", "kingbase":
|
||||
// 这些类型的 dbName 表示“数据库”,需要写入连接配置以选择目标库。
|
||||
case "mysql", "mariadb", "sphinx", "postgres", "kingbase", "highgo", "vastbase", "sqlserver", "mongodb", "tdengine":
|
||||
// 这些类型的 dbName 表示"数据库",需要写入连接配置以选择目标库。
|
||||
runConfig.Database = name
|
||||
case "dameng":
|
||||
// 达梦使用 schema 参数,沿用现有行为:dbName 表示 schema。
|
||||
@@ -45,12 +45,14 @@ func normalizeSchemaAndTable(config connection.ConnectionConfig, dbName string,
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(config.Type)) {
|
||||
case "postgres", "kingbase":
|
||||
// PG/金仓:dbName 在 UI 里是“数据库”,schema 需从 tableName 或使用默认 public。
|
||||
case "postgres", "kingbase", "highgo", "vastbase":
|
||||
// PG/金仓/瀚高/海量:dbName 在 UI 里是"数据库",schema 需从 tableName 或使用默认 public。
|
||||
return "public", rawTable
|
||||
case "sqlserver":
|
||||
// SQL Server:dbName 表示数据库,schema 默认 dbo
|
||||
return "dbo", rawTable
|
||||
default:
|
||||
// MySQL:dbName 表示数据库;Oracle/达梦:dbName 表示 schema/owner。
|
||||
return rawDB, rawTable
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
100
internal/app/error_text.go
Normal file
100
internal/app/error_text.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"golang.org/x/text/encoding/simplifiedchinese"
|
||||
"golang.org/x/text/transform"
|
||||
)
|
||||
|
||||
func normalizeErrorMessage(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
return normalizeMixedEncodingText(err.Error())
|
||||
}
|
||||
|
||||
func normalizeMixedEncodingText(text string) string {
|
||||
if text == "" {
|
||||
return text
|
||||
}
|
||||
|
||||
raw := []byte(text)
|
||||
output := make([]byte, 0, len(raw)+16)
|
||||
suspect := make([]byte, 0, 16)
|
||||
|
||||
flushSuspect := func() {
|
||||
if len(suspect) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
fallback := strings.ToValidUTF8(string(suspect), "<22>")
|
||||
decoded, _, err := transform.Bytes(simplifiedchinese.GB18030.NewDecoder(), suspect)
|
||||
if err == nil && utf8.Valid(decoded) {
|
||||
candidate := string(decoded)
|
||||
if scoreDecodedText(candidate) > scoreDecodedText(fallback) {
|
||||
output = append(output, []byte(candidate)...)
|
||||
} else {
|
||||
output = append(output, []byte(fallback)...)
|
||||
}
|
||||
} else {
|
||||
output = append(output, []byte(fallback)...)
|
||||
}
|
||||
|
||||
suspect = suspect[:0]
|
||||
}
|
||||
|
||||
for len(raw) > 0 {
|
||||
r, size := utf8.DecodeRune(raw)
|
||||
if r == utf8.RuneError && size == 1 {
|
||||
suspect = append(suspect, raw[0])
|
||||
raw = raw[1:]
|
||||
continue
|
||||
}
|
||||
|
||||
if isLikelyMojibakeRune(r) {
|
||||
suspect = append(suspect, raw[:size]...)
|
||||
} else {
|
||||
flushSuspect()
|
||||
output = append(output, raw[:size]...)
|
||||
}
|
||||
raw = raw[size:]
|
||||
}
|
||||
|
||||
flushSuspect()
|
||||
return string(output)
|
||||
}
|
||||
|
||||
func isLikelyMojibakeRune(r rune) bool {
|
||||
if r == utf8.RuneError {
|
||||
return true
|
||||
}
|
||||
if r >= 0x00C0 && r <= 0x02FF {
|
||||
return true
|
||||
}
|
||||
if unicode.In(r, unicode.Hebrew, unicode.Arabic, unicode.Cyrillic, unicode.Greek) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func scoreDecodedText(text string) int {
|
||||
score := 0
|
||||
for _, r := range text {
|
||||
switch {
|
||||
case r == '<27>':
|
||||
score -= 6
|
||||
case unicode.Is(unicode.Han, r):
|
||||
score += 4
|
||||
case isLikelyMojibakeRune(r):
|
||||
score -= 3
|
||||
case unicode.IsPrint(r):
|
||||
score += 1
|
||||
default:
|
||||
score -= 2
|
||||
}
|
||||
}
|
||||
return score
|
||||
}
|
||||
25
internal/app/error_text_test.go
Normal file
25
internal/app/error_text_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package app
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeMixedEncodingText_GBKErrorMessage(t *testing.T) {
|
||||
raw := []byte("pq: ")
|
||||
raw = append(raw, 0xD3, 0xC3, 0xBB, 0xA7) // 用户
|
||||
raw = append(raw, []byte(` "root" Password `)...)
|
||||
raw = append(raw, 0xC8, 0xCF, 0xD6, 0xA4, 0xCA, 0xA7, 0xB0, 0xDC) // 认证失败
|
||||
raw = append(raw, []byte(" (28P01)")...)
|
||||
|
||||
got := normalizeMixedEncodingText(string(raw))
|
||||
want := `pq: 用户 "root" Password 认证失败 (28P01)`
|
||||
if got != want {
|
||||
t.Fatalf("normalizeMixedEncodingText() mismatch\nwant: %q\ngot: %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeMixedEncodingText_KeepUTF8(t *testing.T) {
|
||||
input := `连接建立后验证失败:pq: password authentication failed for user "root"`
|
||||
got := normalizeMixedEncodingText(input)
|
||||
if got != input {
|
||||
t.Fatalf("expected unchanged utf8 text, got: %q", got)
|
||||
}
|
||||
}
|
||||
@@ -20,7 +20,7 @@ func (a *App) DBConnect(config connection.ConnectionConfig) connection.QueryResu
|
||||
logger.Error(err, "DBConnect 连接失败:%s", formatConnSummary(config))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
|
||||
logger.Infof("DBConnect 连接成功:%s", formatConnSummary(config))
|
||||
return connection.QueryResult{Success: true, Message: "连接成功"}
|
||||
}
|
||||
@@ -31,14 +31,49 @@ func (a *App) TestConnection(config connection.ConnectionConfig) connection.Quer
|
||||
logger.Error(err, "TestConnection 连接测试失败:%s", formatConnSummary(config))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
|
||||
logger.Infof("TestConnection 连接测试成功:%s", formatConnSummary(config))
|
||||
return connection.QueryResult{Success: true, Message: "连接成功"}
|
||||
}
|
||||
|
||||
func (a *App) MongoDiscoverMembers(config connection.ConnectionConfig) connection.QueryResult {
|
||||
config.Type = "mongodb"
|
||||
|
||||
dbInst, err := a.getDatabaseForcePing(config)
|
||||
if err != nil {
|
||||
logger.Error(err, "MongoDiscoverMembers 获取连接失败:%s", formatConnSummary(config))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
discoverable, ok := dbInst.(interface {
|
||||
DiscoverMembers() (string, []connection.MongoMemberInfo, error)
|
||||
})
|
||||
if !ok {
|
||||
return connection.QueryResult{Success: false, Message: "当前 MongoDB 驱动不支持成员发现"}
|
||||
}
|
||||
|
||||
replicaSet, members, err := discoverable.DiscoverMembers()
|
||||
if err != nil {
|
||||
logger.Error(err, "MongoDiscoverMembers 执行失败:%s", formatConnSummary(config))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"replicaSet": replicaSet,
|
||||
"members": members,
|
||||
}
|
||||
|
||||
logger.Infof("MongoDiscoverMembers 成功:%s 成员数=%d 副本集=%s", formatConnSummary(config), len(members), replicaSet)
|
||||
return connection.QueryResult{
|
||||
Success: true,
|
||||
Message: fmt.Sprintf("发现 %d 个成员", len(members)),
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) CreateDatabase(config connection.ConnectionConfig, dbName string) connection.QueryResult {
|
||||
runConfig := config
|
||||
runConfig.Database = ""
|
||||
runConfig.Database = ""
|
||||
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
@@ -47,9 +82,16 @@ func (a *App) CreateDatabase(config connection.ConnectionConfig, dbName string)
|
||||
|
||||
escapedDbName := strings.ReplaceAll(dbName, "`", "``")
|
||||
query := fmt.Sprintf("CREATE DATABASE `%s` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci", escapedDbName)
|
||||
if runConfig.Type == "postgres" {
|
||||
dbType := strings.ToLower(strings.TrimSpace(runConfig.Type))
|
||||
if dbType == "postgres" || dbType == "kingbase" || dbType == "highgo" || dbType == "vastbase" {
|
||||
escapedDbName = strings.ReplaceAll(dbName, `"`, `""`)
|
||||
query = fmt.Sprintf("CREATE DATABASE \"%s\"", escapedDbName)
|
||||
} else if dbType == "tdengine" {
|
||||
query = fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", quoteIdentByType(dbType, dbName))
|
||||
} else if dbType == "mariadb" {
|
||||
// MariaDB uses same syntax as MySQL
|
||||
} else if dbType == "sphinx" {
|
||||
return connection.QueryResult{Success: false, Message: "Sphinx 暂不支持创建数据库"}
|
||||
}
|
||||
|
||||
_, err = dbInst.Exec(query)
|
||||
@@ -60,6 +102,232 @@ func (a *App) CreateDatabase(config connection.ConnectionConfig, dbName string)
|
||||
return connection.QueryResult{Success: true, Message: "Database created successfully"}
|
||||
}
|
||||
|
||||
func resolveDDLDBType(config connection.ConnectionConfig) string {
|
||||
dbType := strings.ToLower(strings.TrimSpace(config.Type))
|
||||
if dbType != "custom" {
|
||||
return dbType
|
||||
}
|
||||
|
||||
driver := strings.ToLower(strings.TrimSpace(config.Driver))
|
||||
switch driver {
|
||||
case "postgresql":
|
||||
return "postgres"
|
||||
case "dm":
|
||||
return "dameng"
|
||||
case "sqlite3":
|
||||
return "sqlite"
|
||||
case "sphinxql":
|
||||
return "sphinx"
|
||||
default:
|
||||
return driver
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeSchemaAndTableByType(dbType string, dbName string, tableName string) (string, string) {
|
||||
rawTable := strings.TrimSpace(tableName)
|
||||
rawDB := strings.TrimSpace(dbName)
|
||||
if rawTable == "" {
|
||||
return rawDB, rawTable
|
||||
}
|
||||
|
||||
if parts := strings.SplitN(rawTable, ".", 2); len(parts) == 2 {
|
||||
schema := strings.TrimSpace(parts[0])
|
||||
table := strings.TrimSpace(parts[1])
|
||||
if schema != "" && table != "" {
|
||||
return schema, table
|
||||
}
|
||||
}
|
||||
|
||||
switch dbType {
|
||||
case "postgres", "kingbase", "highgo", "vastbase":
|
||||
return "public", rawTable
|
||||
default:
|
||||
return rawDB, rawTable
|
||||
}
|
||||
}
|
||||
|
||||
func quoteTableIdentByType(dbType string, schema string, table string) string {
|
||||
s := strings.TrimSpace(schema)
|
||||
t := strings.TrimSpace(table)
|
||||
if s == "" {
|
||||
return quoteIdentByType(dbType, t)
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", quoteIdentByType(dbType, s), quoteIdentByType(dbType, t))
|
||||
}
|
||||
|
||||
func buildRunConfigForDDL(config connection.ConnectionConfig, dbType string, dbName string) connection.ConnectionConfig {
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
if strings.EqualFold(strings.TrimSpace(config.Type), "custom") {
|
||||
// custom 连接的 dbName 语义依赖 driver,尽量在常见驱动上对齐内置类型行为。
|
||||
switch dbType {
|
||||
case "mysql", "mariadb", "sphinx", "postgres", "kingbase", "vastbase", "dameng":
|
||||
if strings.TrimSpace(dbName) != "" {
|
||||
runConfig.Database = strings.TrimSpace(dbName)
|
||||
}
|
||||
}
|
||||
}
|
||||
return runConfig
|
||||
}
|
||||
|
||||
func (a *App) RenameDatabase(config connection.ConnectionConfig, oldName string, newName string) connection.QueryResult {
|
||||
oldName = strings.TrimSpace(oldName)
|
||||
newName = strings.TrimSpace(newName)
|
||||
if oldName == "" || newName == "" {
|
||||
return connection.QueryResult{Success: false, Message: "数据库名称不能为空"}
|
||||
}
|
||||
if strings.EqualFold(oldName, newName) {
|
||||
return connection.QueryResult{Success: false, Message: "新旧数据库名称不能相同"}
|
||||
}
|
||||
|
||||
dbType := resolveDDLDBType(config)
|
||||
switch dbType {
|
||||
case "mysql", "mariadb", "sphinx":
|
||||
return connection.QueryResult{Success: false, Message: "MySQL/MariaDB/Sphinx 不支持直接重命名数据库,请新建库后迁移数据"}
|
||||
case "postgres", "kingbase", "highgo", "vastbase":
|
||||
if strings.EqualFold(strings.TrimSpace(config.Database), oldName) {
|
||||
return connection.QueryResult{Success: false, Message: "当前连接正在使用目标数据库,请先连接到其他数据库后再重命名"}
|
||||
}
|
||||
runConfig := config
|
||||
if strings.TrimSpace(runConfig.Database) == "" {
|
||||
runConfig.Database = "postgres"
|
||||
}
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
sql := fmt.Sprintf("ALTER DATABASE %s RENAME TO %s", quoteIdentByType(dbType, oldName), quoteIdentByType(dbType, newName))
|
||||
if _, err := dbInst.Exec(sql); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Message: "数据库重命名成功"}
|
||||
default:
|
||||
return connection.QueryResult{Success: false, Message: fmt.Sprintf("当前数据源(%s)暂不支持重命名数据库", dbType)}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) DropDatabase(config connection.ConnectionConfig, dbName string) connection.QueryResult {
|
||||
dbName = strings.TrimSpace(dbName)
|
||||
if dbName == "" {
|
||||
return connection.QueryResult{Success: false, Message: "数据库名称不能为空"}
|
||||
}
|
||||
|
||||
dbType := resolveDDLDBType(config)
|
||||
var (
|
||||
runConfig connection.ConnectionConfig
|
||||
sql string
|
||||
)
|
||||
switch dbType {
|
||||
case "mysql", "mariadb", "tdengine":
|
||||
runConfig = config
|
||||
runConfig.Database = ""
|
||||
sql = fmt.Sprintf("DROP DATABASE %s", quoteIdentByType(dbType, dbName))
|
||||
case "postgres", "kingbase", "highgo", "vastbase":
|
||||
if strings.EqualFold(strings.TrimSpace(config.Database), dbName) {
|
||||
return connection.QueryResult{Success: false, Message: "当前连接正在使用目标数据库,请先连接到其他数据库后再删除"}
|
||||
}
|
||||
runConfig = config
|
||||
if strings.TrimSpace(runConfig.Database) == "" {
|
||||
runConfig.Database = "postgres"
|
||||
}
|
||||
sql = fmt.Sprintf("DROP DATABASE %s", quoteIdentByType(dbType, dbName))
|
||||
default:
|
||||
return connection.QueryResult{Success: false, Message: fmt.Sprintf("当前数据源(%s)暂不支持删除数据库", dbType)}
|
||||
}
|
||||
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
if _, err := dbInst.Exec(sql); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Message: "数据库删除成功"}
|
||||
}
|
||||
|
||||
func (a *App) RenameTable(config connection.ConnectionConfig, dbName string, oldTableName string, newTableName string) connection.QueryResult {
|
||||
oldTableName = strings.TrimSpace(oldTableName)
|
||||
newTableName = strings.TrimSpace(newTableName)
|
||||
if oldTableName == "" || newTableName == "" {
|
||||
return connection.QueryResult{Success: false, Message: "表名不能为空"}
|
||||
}
|
||||
if strings.EqualFold(oldTableName, newTableName) {
|
||||
return connection.QueryResult{Success: false, Message: "新旧表名不能相同"}
|
||||
}
|
||||
if strings.Contains(newTableName, ".") {
|
||||
return connection.QueryResult{Success: false, Message: "新表名不能包含 schema 或数据库前缀"}
|
||||
}
|
||||
|
||||
dbType := resolveDDLDBType(config)
|
||||
switch dbType {
|
||||
case "mysql", "mariadb", "sphinx", "postgres", "kingbase", "sqlite", "oracle", "dameng", "highgo", "vastbase", "sqlserver":
|
||||
default:
|
||||
return connection.QueryResult{Success: false, Message: fmt.Sprintf("当前数据源(%s)暂不支持重命名表", dbType)}
|
||||
}
|
||||
|
||||
schemaName, pureOldTableName := normalizeSchemaAndTableByType(dbType, dbName, oldTableName)
|
||||
if pureOldTableName == "" {
|
||||
return connection.QueryResult{Success: false, Message: "旧表名不能为空"}
|
||||
}
|
||||
oldQualifiedTable := quoteTableIdentByType(dbType, schemaName, pureOldTableName)
|
||||
newTableQuoted := quoteIdentByType(dbType, newTableName)
|
||||
|
||||
var sql string
|
||||
switch dbType {
|
||||
case "mysql", "mariadb", "sphinx":
|
||||
newQualifiedTable := quoteTableIdentByType(dbType, schemaName, newTableName)
|
||||
sql = fmt.Sprintf("RENAME TABLE %s TO %s", oldQualifiedTable, newQualifiedTable)
|
||||
case "sqlserver":
|
||||
// SQL Server 使用 sp_rename,参数为 'schema.oldname', 'newname'
|
||||
oldFullName := schemaName + "." + pureOldTableName
|
||||
escapedOld := strings.ReplaceAll(oldFullName, "'", "''")
|
||||
escapedNew := strings.ReplaceAll(newTableName, "'", "''")
|
||||
sql = fmt.Sprintf("EXEC sp_rename '%s', '%s'", escapedOld, escapedNew)
|
||||
default:
|
||||
sql = fmt.Sprintf("ALTER TABLE %s RENAME TO %s", oldQualifiedTable, newTableQuoted)
|
||||
}
|
||||
|
||||
runConfig := buildRunConfigForDDL(config, dbType, dbName)
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
if _, err := dbInst.Exec(sql); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Message: "表重命名成功"}
|
||||
}
|
||||
|
||||
func (a *App) DropTable(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult {
|
||||
tableName = strings.TrimSpace(tableName)
|
||||
if tableName == "" {
|
||||
return connection.QueryResult{Success: false, Message: "表名不能为空"}
|
||||
}
|
||||
|
||||
dbType := resolveDDLDBType(config)
|
||||
switch dbType {
|
||||
case "mysql", "mariadb", "sphinx", "postgres", "kingbase", "sqlite", "oracle", "dameng", "highgo", "vastbase", "sqlserver", "tdengine":
|
||||
default:
|
||||
return connection.QueryResult{Success: false, Message: fmt.Sprintf("当前数据源(%s)暂不支持删除表", dbType)}
|
||||
}
|
||||
|
||||
schemaName, pureTableName := normalizeSchemaAndTableByType(dbType, dbName, tableName)
|
||||
if pureTableName == "" {
|
||||
return connection.QueryResult{Success: false, Message: "表名不能为空"}
|
||||
}
|
||||
qualifiedTable := quoteTableIdentByType(dbType, schemaName, pureTableName)
|
||||
sql := fmt.Sprintf("DROP TABLE %s", qualifiedTable)
|
||||
|
||||
runConfig := buildRunConfigForDDL(config, dbType, dbName)
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
if _, err := dbInst.Exec(sql); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Message: "表删除成功"}
|
||||
}
|
||||
|
||||
func (a *App) MySQLConnect(config connection.ConnectionConfig) connection.QueryResult {
|
||||
config.Type = "mysql"
|
||||
return a.DBConnect(config)
|
||||
@@ -103,7 +371,12 @@ func (a *App) DBQuery(config connection.ConnectionConfig, dbName string, query s
|
||||
defer cancel()
|
||||
|
||||
lowerQuery := strings.TrimSpace(strings.ToLower(query))
|
||||
if strings.HasPrefix(lowerQuery, "select") || strings.HasPrefix(lowerQuery, "show") || strings.HasPrefix(lowerQuery, "describe") || strings.HasPrefix(lowerQuery, "explain") {
|
||||
isReadQuery := strings.HasPrefix(lowerQuery, "select") || strings.HasPrefix(lowerQuery, "show") || strings.HasPrefix(lowerQuery, "describe") || strings.HasPrefix(lowerQuery, "explain")
|
||||
// MongoDB JSON 命令中的 find/count/aggregate 也属于读查询
|
||||
if !isReadQuery && strings.ToLower(strings.TrimSpace(runConfig.Type)) == "mongodb" && strings.HasPrefix(strings.TrimSpace(query), "{") {
|
||||
isReadQuery = true
|
||||
}
|
||||
if isReadQuery {
|
||||
var data []map[string]interface{}
|
||||
var columns []string
|
||||
if q, ok := dbInst.(interface {
|
||||
@@ -156,12 +429,12 @@ func (a *App) DBGetDatabases(config connection.ConnectionConfig) connection.Quer
|
||||
logger.Error(err, "DBGetDatabases 获取数据库列表失败:%s", formatConnSummary(config))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
|
||||
var resData []map[string]string
|
||||
for _, name := range dbs {
|
||||
resData = append(resData, map[string]string{"Database": name})
|
||||
}
|
||||
|
||||
|
||||
return connection.QueryResult{Success: true, Data: resData}
|
||||
}
|
||||
|
||||
@@ -275,6 +548,125 @@ func (a *App) DBGetTriggers(config connection.ConnectionConfig, dbName string, t
|
||||
return connection.QueryResult{Success: true, Data: triggers}
|
||||
}
|
||||
|
||||
func (a *App) DropView(config connection.ConnectionConfig, dbName string, viewName string) connection.QueryResult {
|
||||
viewName = strings.TrimSpace(viewName)
|
||||
if viewName == "" {
|
||||
return connection.QueryResult{Success: false, Message: "视图名称不能为空"}
|
||||
}
|
||||
|
||||
dbType := resolveDDLDBType(config)
|
||||
switch dbType {
|
||||
case "mysql", "mariadb", "sphinx", "postgres", "kingbase", "sqlite", "oracle", "dameng", "highgo", "vastbase", "sqlserver":
|
||||
default:
|
||||
return connection.QueryResult{Success: false, Message: fmt.Sprintf("当前数据源(%s)暂不支持删除视图", dbType)}
|
||||
}
|
||||
|
||||
schemaName, pureViewName := normalizeSchemaAndTableByType(dbType, dbName, viewName)
|
||||
if pureViewName == "" {
|
||||
return connection.QueryResult{Success: false, Message: "视图名称不能为空"}
|
||||
}
|
||||
qualifiedView := quoteTableIdentByType(dbType, schemaName, pureViewName)
|
||||
sql := fmt.Sprintf("DROP VIEW %s", qualifiedView)
|
||||
|
||||
runConfig := buildRunConfigForDDL(config, dbType, dbName)
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
if _, err := dbInst.Exec(sql); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Message: "视图删除成功"}
|
||||
}
|
||||
|
||||
func (a *App) DropFunction(config connection.ConnectionConfig, dbName string, routineName string, routineType string) connection.QueryResult {
|
||||
routineName = strings.TrimSpace(routineName)
|
||||
routineType = strings.TrimSpace(strings.ToUpper(routineType))
|
||||
if routineName == "" {
|
||||
return connection.QueryResult{Success: false, Message: "函数/存储过程名称不能为空"}
|
||||
}
|
||||
if routineType != "FUNCTION" && routineType != "PROCEDURE" {
|
||||
routineType = "FUNCTION"
|
||||
}
|
||||
|
||||
dbType := resolveDDLDBType(config)
|
||||
switch dbType {
|
||||
case "mysql", "mariadb", "sphinx", "postgres", "kingbase", "oracle", "dameng", "highgo", "vastbase", "sqlserver":
|
||||
default:
|
||||
return connection.QueryResult{Success: false, Message: fmt.Sprintf("当前数据源(%s)暂不支持删除函数/存储过程", dbType)}
|
||||
}
|
||||
|
||||
schemaName, pureName := normalizeSchemaAndTableByType(dbType, dbName, routineName)
|
||||
if pureName == "" {
|
||||
return connection.QueryResult{Success: false, Message: "函数/存储过程名称不能为空"}
|
||||
}
|
||||
qualifiedName := quoteTableIdentByType(dbType, schemaName, pureName)
|
||||
sql := fmt.Sprintf("DROP %s %s", routineType, qualifiedName)
|
||||
|
||||
runConfig := buildRunConfigForDDL(config, dbType, dbName)
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
if _, err := dbInst.Exec(sql); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
label := "函数"
|
||||
if routineType == "PROCEDURE" {
|
||||
label = "存储过程"
|
||||
}
|
||||
return connection.QueryResult{Success: true, Message: fmt.Sprintf("%s删除成功", label)}
|
||||
}
|
||||
|
||||
func (a *App) RenameView(config connection.ConnectionConfig, dbName string, oldName string, newName string) connection.QueryResult {
|
||||
oldName = strings.TrimSpace(oldName)
|
||||
newName = strings.TrimSpace(newName)
|
||||
if oldName == "" || newName == "" {
|
||||
return connection.QueryResult{Success: false, Message: "视图名称不能为空"}
|
||||
}
|
||||
if strings.EqualFold(oldName, newName) {
|
||||
return connection.QueryResult{Success: false, Message: "新旧视图名称不能相同"}
|
||||
}
|
||||
if strings.Contains(newName, ".") {
|
||||
return connection.QueryResult{Success: false, Message: "新视图名不能包含 schema 或数据库前缀"}
|
||||
}
|
||||
|
||||
dbType := resolveDDLDBType(config)
|
||||
schemaName, pureOldName := normalizeSchemaAndTableByType(dbType, dbName, oldName)
|
||||
if pureOldName == "" {
|
||||
return connection.QueryResult{Success: false, Message: "旧视图名不能为空"}
|
||||
}
|
||||
oldQualified := quoteTableIdentByType(dbType, schemaName, pureOldName)
|
||||
newQuoted := quoteIdentByType(dbType, newName)
|
||||
|
||||
var sql string
|
||||
switch dbType {
|
||||
case "mysql", "mariadb", "sphinx":
|
||||
newQualified := quoteTableIdentByType(dbType, schemaName, newName)
|
||||
sql = fmt.Sprintf("RENAME TABLE %s TO %s", oldQualified, newQualified)
|
||||
case "postgres", "kingbase", "highgo", "vastbase":
|
||||
sql = fmt.Sprintf("ALTER VIEW %s RENAME TO %s", oldQualified, newQuoted)
|
||||
case "sqlserver":
|
||||
oldFullName := schemaName + "." + pureOldName
|
||||
escapedOld := strings.ReplaceAll(oldFullName, "'", "''")
|
||||
escapedNew := strings.ReplaceAll(newName, "'", "''")
|
||||
sql = fmt.Sprintf("EXEC sp_rename '%s', '%s'", escapedOld, escapedNew)
|
||||
default:
|
||||
return connection.QueryResult{Success: false, Message: fmt.Sprintf("当前数据源(%s)暂不支持重命名视图", dbType)}
|
||||
}
|
||||
|
||||
runConfig := buildRunConfigForDDL(config, dbType, dbName)
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
if _, err := dbInst.Exec(sql); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Message: "视图重命名成功"}
|
||||
}
|
||||
|
||||
func (a *App) DBGetAllColumns(config connection.ConnectionConfig, dbName string) connection.QueryResult {
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
|
||||
|
||||
@@ -14,9 +14,9 @@ import (
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/db"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
|
||||
"github.com/wailsapp/wails/v2/pkg/runtime"
|
||||
"github.com/xuri/excelize/v2"
|
||||
)
|
||||
|
||||
func (a *App) OpenSQLFile() connection.QueryResult {
|
||||
@@ -77,13 +77,40 @@ func (a *App) ImportConfigFile() connection.QueryResult {
|
||||
return connection.QueryResult{Success: true, Data: string(content)}
|
||||
}
|
||||
|
||||
// PreviewImportFile 解析导入文件,返回字段列表、总行数、前 5 行预览数据
|
||||
func (a *App) PreviewImportFile(filePath string) connection.QueryResult {
|
||||
if filePath == "" {
|
||||
return connection.QueryResult{Success: false, Message: "File path required"}
|
||||
}
|
||||
|
||||
rows, columns, err := parseImportFile(filePath)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
totalRows := len(rows)
|
||||
previewRows := rows
|
||||
if len(rows) > 5 {
|
||||
previewRows = rows[:5]
|
||||
}
|
||||
|
||||
result := map[string]interface{}{
|
||||
"columns": columns,
|
||||
"totalRows": totalRows,
|
||||
"previewRows": previewRows,
|
||||
"filePath": filePath,
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Data: result}
|
||||
}
|
||||
|
||||
func (a *App) ImportData(config connection.ConnectionConfig, dbName, tableName string) connection.QueryResult {
|
||||
selection, err := runtime.OpenFileDialog(a.ctx, runtime.OpenDialogOptions{
|
||||
Title: fmt.Sprintf("Import into %s", tableName),
|
||||
Filters: []runtime.FileFilter{
|
||||
{
|
||||
DisplayName: "Data Files",
|
||||
Pattern: "*.csv;*.json",
|
||||
Pattern: "*.csv;*.json;*.xlsx;*.xls",
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -96,44 +123,249 @@ func (a *App) ImportData(config connection.ConnectionConfig, dbName, tableName s
|
||||
return connection.QueryResult{Success: false, Message: "Cancelled"}
|
||||
}
|
||||
|
||||
f, err := os.Open(selection)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
defer f.Close()
|
||||
// 返回文件路径供前端预览
|
||||
return connection.QueryResult{Success: true, Data: map[string]interface{}{"filePath": selection}}
|
||||
}
|
||||
|
||||
var rows []map[string]interface{ }
|
||||
|
||||
if strings.HasSuffix(strings.ToLower(selection), ".json") {
|
||||
// parseImportFile 解析导入文件,返回数据行和列名
|
||||
func parseImportFile(filePath string) ([]map[string]interface{}, []string, error) {
|
||||
var rows []map[string]interface{}
|
||||
var columns []string
|
||||
lower := strings.ToLower(filePath)
|
||||
|
||||
if strings.HasSuffix(lower, ".json") {
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
decoder := json.NewDecoder(f)
|
||||
if err := decoder.Decode(&rows); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: "JSON Parse Error: " + err.Error()}
|
||||
return nil, nil, fmt.Errorf("JSON Parse Error: %w", err)
|
||||
}
|
||||
} else if strings.HasSuffix(strings.ToLower(selection), ".csv") {
|
||||
if len(rows) > 0 {
|
||||
for k := range rows[0] {
|
||||
columns = append(columns, k)
|
||||
}
|
||||
}
|
||||
} else if strings.HasSuffix(lower, ".csv") {
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
reader := csv.NewReader(f)
|
||||
records, err := reader.ReadAll()
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: "CSV Parse Error: " + err.Error()}
|
||||
return nil, nil, fmt.Errorf("CSV Parse Error: %w", err)
|
||||
}
|
||||
if len(records) < 2 {
|
||||
return connection.QueryResult{Success: false, Message: "CSV empty or missing header"}
|
||||
return nil, nil, fmt.Errorf("CSV empty or missing header")
|
||||
}
|
||||
headers := records[0]
|
||||
columns = records[0]
|
||||
for _, record := range records[1:] {
|
||||
row := make(map[string]interface{ })
|
||||
row := make(map[string]interface{})
|
||||
for i, val := range record {
|
||||
if i < len(headers) {
|
||||
if i < len(columns) {
|
||||
if val == "NULL" {
|
||||
row[headers[i]] = nil
|
||||
row[columns[i]] = nil
|
||||
} else {
|
||||
row[headers[i]] = val
|
||||
row[columns[i]] = val
|
||||
}
|
||||
}
|
||||
}
|
||||
rows = append(rows, row)
|
||||
}
|
||||
} else if strings.HasSuffix(lower, ".xlsx") || strings.HasSuffix(lower, ".xls") {
|
||||
xlsx, err := excelize.OpenFile(filePath)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("Excel Parse Error: %w", err)
|
||||
}
|
||||
defer xlsx.Close()
|
||||
|
||||
sheetName := xlsx.GetSheetName(0)
|
||||
if sheetName == "" {
|
||||
return nil, nil, fmt.Errorf("Excel file has no sheets")
|
||||
}
|
||||
|
||||
xlRows, err := xlsx.GetRows(sheetName)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("Excel Read Error: %w", err)
|
||||
}
|
||||
if len(xlRows) < 2 {
|
||||
return nil, nil, fmt.Errorf("Excel empty or missing header")
|
||||
}
|
||||
|
||||
columns = xlRows[0]
|
||||
for _, record := range xlRows[1:] {
|
||||
row := make(map[string]interface{})
|
||||
for i, val := range record {
|
||||
if i < len(columns) && columns[i] != "" {
|
||||
if val == "NULL" {
|
||||
row[columns[i]] = nil
|
||||
} else {
|
||||
row[columns[i]] = val
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(row) > 0 {
|
||||
rows = append(rows, row)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return connection.QueryResult{Success: false, Message: "Unsupported file format"}
|
||||
return nil, nil, fmt.Errorf("Unsupported file format")
|
||||
}
|
||||
|
||||
return rows, columns, nil
|
||||
}
|
||||
|
||||
func normalizeColumnName(name string) string {
|
||||
return strings.ToLower(strings.TrimSpace(name))
|
||||
}
|
||||
|
||||
func buildImportColumnTypeMap(defs []connection.ColumnDefinition) map[string]string {
|
||||
result := make(map[string]string, len(defs))
|
||||
for _, def := range defs {
|
||||
key := normalizeColumnName(def.Name)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
result[key] = strings.TrimSpace(def.Type)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func isTimezoneAwareColumnType(columnType string) bool {
|
||||
typ := strings.ToLower(strings.TrimSpace(columnType))
|
||||
if typ == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(typ, "with time zone") ||
|
||||
strings.Contains(typ, "with timezone") ||
|
||||
strings.Contains(typ, "datetimeoffset") ||
|
||||
strings.Contains(typ, "timestamptz")
|
||||
}
|
||||
|
||||
func isDateTimeColumnType(columnType string) bool {
|
||||
typ := strings.ToLower(strings.TrimSpace(columnType))
|
||||
if typ == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(typ, "datetime") || strings.Contains(typ, "timestamp")
|
||||
}
|
||||
|
||||
func isTimeOnlyColumnType(columnType string) bool {
|
||||
typ := strings.ToLower(strings.TrimSpace(columnType))
|
||||
if typ == "" {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(typ, "datetime") || strings.Contains(typ, "timestamp") {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(typ, "time")
|
||||
}
|
||||
|
||||
func isDateOnlyColumnType(dbType, columnType string) bool {
|
||||
typ := strings.ToLower(strings.TrimSpace(columnType))
|
||||
if typ == "" {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(typ, "datetime") || strings.Contains(typ, "timestamp") || strings.Contains(typ, "time") {
|
||||
return false
|
||||
}
|
||||
if !strings.Contains(typ, "date") {
|
||||
return false
|
||||
}
|
||||
db := strings.ToLower(strings.TrimSpace(dbType))
|
||||
// Oracle/Dameng 的 DATE 带时间语义,不能按纯日期裁剪。
|
||||
return db != "oracle" && db != "dameng"
|
||||
}
|
||||
|
||||
func isTemporalColumnType(dbType, columnType string) bool {
|
||||
return isDateTimeColumnType(columnType) || isTimeOnlyColumnType(columnType) || isDateOnlyColumnType(dbType, columnType)
|
||||
}
|
||||
|
||||
func parseTemporalString(raw string) (time.Time, bool) {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
layouts := []string{
|
||||
"2006-01-02 15:04:05.999999999 -0700 MST",
|
||||
"2006-01-02 15:04:05 -0700 MST",
|
||||
"2006-01-02 15:04:05.999999999 -0700",
|
||||
"2006-01-02 15:04:05 -0700",
|
||||
time.RFC3339Nano,
|
||||
time.RFC3339,
|
||||
"2006-01-02 15:04:05.999999999",
|
||||
"2006-01-02 15:04:05",
|
||||
"2006-01-02",
|
||||
"15:04:05.999999999",
|
||||
"15:04:05",
|
||||
}
|
||||
|
||||
for _, layout := range layouts {
|
||||
parsed, err := time.Parse(layout, text)
|
||||
if err == nil {
|
||||
return parsed, true
|
||||
}
|
||||
}
|
||||
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
func normalizeImportTemporalValue(dbType, columnType, raw string) string {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
return text
|
||||
}
|
||||
|
||||
parsed, ok := parseTemporalString(text)
|
||||
if !ok {
|
||||
if isDateTimeColumnType(columnType) {
|
||||
candidate := strings.ReplaceAll(text, "T", " ")
|
||||
if len(candidate) >= 19 {
|
||||
prefix := candidate[:19]
|
||||
if _, err := time.Parse("2006-01-02 15:04:05", prefix); err == nil {
|
||||
return prefix
|
||||
}
|
||||
}
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
if isTimeOnlyColumnType(columnType) {
|
||||
return parsed.Format("15:04:05")
|
||||
}
|
||||
if isDateOnlyColumnType(dbType, columnType) {
|
||||
return parsed.Format("2006-01-02")
|
||||
}
|
||||
if isTimezoneAwareColumnType(columnType) {
|
||||
return parsed.Format("2006-01-02 15:04:05-07:00")
|
||||
}
|
||||
return parsed.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
func formatImportSQLValue(dbType, columnType string, value interface{}) string {
|
||||
if value == nil {
|
||||
return "NULL"
|
||||
}
|
||||
|
||||
if isTemporalColumnType(dbType, columnType) {
|
||||
normalized := normalizeImportTemporalValue(dbType, columnType, fmt.Sprintf("%v", value))
|
||||
escaped := strings.ReplaceAll(normalized, "'", "''")
|
||||
return "'" + escaped + "'"
|
||||
}
|
||||
|
||||
return formatSQLValue(dbType, value)
|
||||
}
|
||||
|
||||
// ImportDataWithProgress 执行导入并发送进度事件
|
||||
func (a *App) ImportDataWithProgress(config connection.ConnectionConfig, dbName, tableName, filePath string) connection.QueryResult {
|
||||
rows, columns, err := parseImportFile(filePath)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
if len(rows) == 0 {
|
||||
@@ -146,29 +378,27 @@ func (a *App) ImportData(config connection.ConnectionConfig, dbName, tableName s
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
successCount := 0
|
||||
errCount := 0
|
||||
firstRow := rows[0]
|
||||
var cols []string
|
||||
for k := range firstRow {
|
||||
cols = append(cols, k)
|
||||
schemaName, pureTableName := normalizeSchemaAndTable(config, dbName, tableName)
|
||||
columnTypeMap := map[string]string{}
|
||||
if defs, colErr := dbInst.GetColumns(schemaName, pureTableName); colErr == nil {
|
||||
columnTypeMap = buildImportColumnTypeMap(defs)
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
|
||||
totalRows := len(rows)
|
||||
successCount := 0
|
||||
var errorLogs []string
|
||||
|
||||
quotedCols := make([]string, len(columns))
|
||||
for i, c := range columns {
|
||||
quotedCols[i] = quoteIdentByType(runConfig.Type, c)
|
||||
}
|
||||
|
||||
for idx, row := range rows {
|
||||
var values []string
|
||||
for _, col := range cols {
|
||||
for _, col := range columns {
|
||||
val := row[col]
|
||||
if val == nil {
|
||||
values = append(values, "NULL")
|
||||
} else {
|
||||
vStr := fmt.Sprintf("%v", val)
|
||||
vStr = strings.ReplaceAll(vStr, "'", "''")
|
||||
values = append(values, fmt.Sprintf("'%s'", vStr))
|
||||
}
|
||||
}
|
||||
quotedCols := make([]string, len(cols))
|
||||
for i, c := range cols {
|
||||
quotedCols[i] = quoteIdentByType(runConfig.Type, c)
|
||||
colType := columnTypeMap[normalizeColumnName(col)]
|
||||
values = append(values, formatImportSQLValue(runConfig.Type, colType, val))
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
|
||||
@@ -178,14 +408,31 @@ func (a *App) ImportData(config connection.ConnectionConfig, dbName, tableName s
|
||||
|
||||
_, err := dbInst.Exec(query)
|
||||
if err != nil {
|
||||
errCount++
|
||||
logger.Error(err, "导入数据失败:表=%s", tableName)
|
||||
errorLogs = append(errorLogs, fmt.Sprintf("Row %d: %s", idx+1, err.Error()))
|
||||
} else {
|
||||
successCount++
|
||||
}
|
||||
|
||||
// 每 10 行发送一次进度事件
|
||||
if (idx+1)%10 == 0 || idx == totalRows-1 {
|
||||
runtime.EventsEmit(a.ctx, "import:progress", map[string]interface{}{
|
||||
"current": idx + 1,
|
||||
"total": totalRows,
|
||||
"success": successCount,
|
||||
"errors": len(errorLogs),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: fmt.Sprintf("Imported: %d, Failed: %d", successCount, errCount)}
|
||||
result := map[string]interface{}{
|
||||
"success": successCount,
|
||||
"failed": len(errorLogs),
|
||||
"total": totalRows,
|
||||
"errorLogs": errorLogs,
|
||||
"errorSummary": fmt.Sprintf("Imported: %d, Failed: %d", successCount, len(errorLogs)),
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Data: result, Message: fmt.Sprintf("Imported: %d, Failed: %d", successCount, len(errorLogs))}
|
||||
}
|
||||
|
||||
func (a *App) ApplyChanges(config connection.ConnectionConfig, dbName, tableName string, changes connection.ChangeSet) connection.QueryResult {
|
||||
@@ -195,7 +442,7 @@ func (a *App) ApplyChanges(config connection.ConnectionConfig, dbName, tableName
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
|
||||
if applier, ok := dbInst.(db.BatchApplier); ok {
|
||||
err := applier.ApplyChanges(tableName, changes)
|
||||
if err != nil {
|
||||
@@ -219,7 +466,7 @@ func (a *App) ExportTable(config connection.ConnectionConfig, dbName string, tab
|
||||
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
@@ -238,7 +485,7 @@ dbInst, err := a.getDatabase(runConfig)
|
||||
if err := writeSQLHeader(w, runConfig, dbName); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
if err := dumpTableSQL(w, dbInst, runConfig, dbName, tableName, true); err != nil {
|
||||
if err := dumpTableSQL(w, dbInst, runConfig, dbName, tableName, true, true); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
if err := writeSQLFooter(w, runConfig); err != nil {
|
||||
@@ -249,8 +496,8 @@ dbInst, err := a.getDatabase(runConfig)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(runConfig.Type, tableName))
|
||||
|
||||
data, columns, err := dbInst.Query(query)
|
||||
|
||||
data, columns, err := dbInst.Query(query)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
@@ -268,13 +515,27 @@ data, columns, err := dbInst.Query(query)
|
||||
}
|
||||
|
||||
func (a *App) ExportTablesSQL(config connection.ConnectionConfig, dbName string, tableNames []string, includeData bool) connection.QueryResult {
|
||||
return a.exportTablesSQL(config, dbName, tableNames, true, includeData)
|
||||
}
|
||||
|
||||
func (a *App) ExportTablesDataSQL(config connection.ConnectionConfig, dbName string, tableNames []string) connection.QueryResult {
|
||||
return a.exportTablesSQL(config, dbName, tableNames, false, true)
|
||||
}
|
||||
|
||||
func (a *App) exportTablesSQL(config connection.ConnectionConfig, dbName string, tableNames []string, includeSchema bool, includeData bool) connection.QueryResult {
|
||||
if !includeSchema && !includeData {
|
||||
return connection.QueryResult{Success: false, Message: "invalid export mode"}
|
||||
}
|
||||
|
||||
safeDbName := strings.TrimSpace(dbName)
|
||||
if safeDbName == "" {
|
||||
safeDbName = "export"
|
||||
}
|
||||
suffix := "schema"
|
||||
if includeData {
|
||||
if includeSchema && includeData {
|
||||
suffix = "backup"
|
||||
} else if !includeSchema && includeData {
|
||||
suffix = "data"
|
||||
}
|
||||
defaultFilename := fmt.Sprintf("%s_%s_%dtables.sql", safeDbName, suffix, len(tableNames))
|
||||
if len(tableNames) == 1 && strings.TrimSpace(tableNames[0]) != "" {
|
||||
@@ -323,7 +584,7 @@ func (a *App) ExportTablesSQL(config connection.ConnectionConfig, dbName string,
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
for _, t := range tables {
|
||||
if err := dumpTableSQL(w, dbInst, runConfig, dbName, t, includeData); err != nil {
|
||||
if err := dumpTableSQL(w, dbInst, runConfig, dbName, t, includeSchema, includeData); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
}
|
||||
@@ -377,7 +638,7 @@ func (a *App) ExportDatabaseSQL(config connection.ConnectionConfig, dbName strin
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
for _, t := range tables {
|
||||
if err := dumpTableSQL(w, dbInst, runConfig, dbName, t, includeData); err != nil {
|
||||
if err := dumpTableSQL(w, dbInst, runConfig, dbName, t, true, includeData); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
}
|
||||
@@ -394,8 +655,11 @@ func quoteIdentByType(dbType string, ident string) string {
|
||||
}
|
||||
|
||||
switch dbType {
|
||||
case "mysql":
|
||||
case "mysql", "mariadb", "sphinx", "tdengine":
|
||||
return "`" + strings.ReplaceAll(ident, "`", "``") + "`"
|
||||
case "sqlserver":
|
||||
escaped := strings.ReplaceAll(ident, "]", "]]")
|
||||
return "[" + escaped + "]"
|
||||
default:
|
||||
return `"` + strings.ReplaceAll(ident, `"`, `""`) + `"`
|
||||
}
|
||||
@@ -534,7 +798,7 @@ func formatSQLValue(dbType string, v interface{}) string {
|
||||
}
|
||||
}
|
||||
|
||||
func dumpTableSQL(w *bufio.Writer, dbInst db.Database, config connection.ConnectionConfig, dbName, tableName string, includeData bool) error {
|
||||
func dumpTableSQL(w *bufio.Writer, dbInst db.Database, config connection.ConnectionConfig, dbName, tableName string, includeSchema bool, includeData bool) error {
|
||||
schemaName, pureTableName := normalizeSchemaAndTable(config, dbName, tableName)
|
||||
|
||||
if _, err := w.WriteString("\n-- ----------------------------\n"); err != nil {
|
||||
@@ -547,15 +811,17 @@ func dumpTableSQL(w *bufio.Writer, dbInst db.Database, config connection.Connect
|
||||
return err
|
||||
}
|
||||
|
||||
createSQL, err := dbInst.GetCreateStatement(schemaName, pureTableName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.WriteString(ensureSQLTerminator(createSQL)); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.WriteString("\n\n"); err != nil {
|
||||
return err
|
||||
if includeSchema {
|
||||
createSQL, err := dbInst.GetCreateStatement(schemaName, pureTableName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.WriteString(ensureSQLTerminator(createSQL)); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.WriteString("\n\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if !includeData {
|
||||
@@ -676,12 +942,17 @@ func writeRowsToFile(f *os.File, data []map[string]interface{}, columns []string
|
||||
return fmt.Errorf("file required")
|
||||
}
|
||||
|
||||
// xlsx 使用 excelize 写入真正的 Excel 格式
|
||||
if format == "xlsx" {
|
||||
return writeRowsToXlsx(f.Name(), data, columns)
|
||||
}
|
||||
|
||||
var csvWriter *csv.Writer
|
||||
var jsonEncoder *json.Encoder
|
||||
isJsonFirstRow := true
|
||||
|
||||
switch format {
|
||||
case "csv", "xlsx":
|
||||
case "csv":
|
||||
if _, err := f.Write([]byte{0xEF, 0xBB, 0xBF}); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -719,7 +990,7 @@ func writeRowsToFile(f *os.File, data []map[string]interface{}, columns []string
|
||||
continue
|
||||
}
|
||||
|
||||
s := fmt.Sprintf("%v", val)
|
||||
s := formatExportCellText(val)
|
||||
if format == "md" {
|
||||
s = strings.ReplaceAll(s, "|", "\\|")
|
||||
s = strings.ReplaceAll(s, "\n", "<br>")
|
||||
@@ -728,7 +999,7 @@ func writeRowsToFile(f *os.File, data []map[string]interface{}, columns []string
|
||||
}
|
||||
|
||||
switch format {
|
||||
case "csv", "xlsx":
|
||||
case "csv":
|
||||
if err := csvWriter.Write(record); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -749,7 +1020,7 @@ func writeRowsToFile(f *os.File, data []map[string]interface{}, columns []string
|
||||
}
|
||||
}
|
||||
|
||||
if format == "csv" || format == "xlsx" {
|
||||
if format == "csv" {
|
||||
csvWriter.Flush()
|
||||
if err := csvWriter.Error(); err != nil {
|
||||
return err
|
||||
@@ -764,3 +1035,50 @@ func writeRowsToFile(f *os.File, data []map[string]interface{}, columns []string
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func formatExportCellText(val interface{}) string {
|
||||
if val == nil {
|
||||
return "NULL"
|
||||
}
|
||||
|
||||
switch v := val.(type) {
|
||||
case time.Time:
|
||||
return v.Format("2006-01-02 15:04:05")
|
||||
case *time.Time:
|
||||
if v == nil {
|
||||
return "NULL"
|
||||
}
|
||||
return v.Format("2006-01-02 15:04:05")
|
||||
default:
|
||||
return fmt.Sprintf("%v", val)
|
||||
}
|
||||
}
|
||||
|
||||
// writeRowsToXlsx 使用 excelize 写入真正的 xlsx 格式文件
|
||||
func writeRowsToXlsx(filename string, data []map[string]interface{}, columns []string) error {
|
||||
xlsx := excelize.NewFile()
|
||||
defer xlsx.Close()
|
||||
|
||||
sheet := "Sheet1"
|
||||
|
||||
// 写入表头
|
||||
for i, col := range columns {
|
||||
cell, _ := excelize.CoordinatesToCellName(i+1, 1)
|
||||
xlsx.SetCellValue(sheet, cell, col)
|
||||
}
|
||||
|
||||
// 写入数据行
|
||||
for rowIdx, rowMap := range data {
|
||||
for colIdx, col := range columns {
|
||||
cell, _ := excelize.CoordinatesToCellName(colIdx+1, rowIdx+2)
|
||||
val := rowMap[col]
|
||||
if val == nil {
|
||||
xlsx.SetCellValue(sheet, cell, "NULL")
|
||||
} else {
|
||||
xlsx.SetCellValue(sheet, cell, formatExportCellText(val))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return xlsx.SaveAs(filename)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -22,9 +23,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
updateRepo = "Syngnat/GoNavi"
|
||||
updateAPIURL = "https://api.github.com/repos/" + updateRepo + "/releases/latest"
|
||||
updateChecksumAsset = "SHA256SUMS"
|
||||
updateRepo = "Syngnat/GoNavi"
|
||||
updateAPIURL = "https://api.github.com/repos/" + updateRepo + "/releases/latest"
|
||||
updateChecksumAsset = "SHA256SUMS"
|
||||
updateDownloadProgressEvent = "update:download-progress"
|
||||
)
|
||||
|
||||
type updateState struct {
|
||||
@@ -54,11 +56,29 @@ type AppInfo struct {
|
||||
BuildTime string `json:"buildTime,omitempty"`
|
||||
}
|
||||
|
||||
type updateDownloadResult struct {
|
||||
Info UpdateInfo `json:"info"`
|
||||
DownloadPath string `json:"downloadPath,omitempty"`
|
||||
InstallLogPath string `json:"installLogPath,omitempty"`
|
||||
InstallTarget string `json:"installTarget,omitempty"`
|
||||
Platform string `json:"platform"`
|
||||
AutoRelaunch bool `json:"autoRelaunch"`
|
||||
}
|
||||
|
||||
type updateDownloadProgressPayload struct {
|
||||
Status string `json:"status"`
|
||||
Percent float64 `json:"percent"`
|
||||
Downloaded int64 `json:"downloaded"`
|
||||
Total int64 `json:"total"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
type stagedUpdate struct {
|
||||
Version string
|
||||
AssetName string
|
||||
FilePath string
|
||||
StagedDir string
|
||||
Version string
|
||||
AssetName string
|
||||
FilePath string
|
||||
StagedDir string
|
||||
InstallLogPath string
|
||||
}
|
||||
|
||||
type githubRelease struct {
|
||||
@@ -124,13 +144,15 @@ func (a *App) DownloadUpdate() connection.QueryResult {
|
||||
a.updateMu.Unlock()
|
||||
return connection.QueryResult{Success: false, Message: "未找到可用的更新包"}
|
||||
}
|
||||
if a.updateState.staged != nil && a.updateState.staged.Version == info.LatestVersion {
|
||||
staged := a.updateState.staged
|
||||
if staged != nil && staged.Version == info.LatestVersion {
|
||||
a.updateMu.Unlock()
|
||||
return connection.QueryResult{Success: true, Message: "更新包已下载完成", Data: info}
|
||||
return connection.QueryResult{Success: true, Message: "更新包已下载完成", Data: buildUpdateDownloadResult(*info, staged)}
|
||||
}
|
||||
a.updateState.downloading = true
|
||||
a.updateMu.Unlock()
|
||||
|
||||
a.emitUpdateDownloadProgress("start", 0, info.AssetSize, "")
|
||||
result := a.downloadAndStageUpdate(*info)
|
||||
|
||||
a.updateMu.Lock()
|
||||
@@ -143,6 +165,9 @@ func (a *App) DownloadUpdate() connection.QueryResult {
|
||||
func (a *App) InstallUpdateAndRestart() connection.QueryResult {
|
||||
a.updateMu.Lock()
|
||||
staged := a.updateState.staged
|
||||
if staged != nil && strings.TrimSpace(staged.InstallLogPath) == "" {
|
||||
staged.InstallLogPath = buildUpdateInstallLogPath(filepath.Dir(staged.FilePath))
|
||||
}
|
||||
a.updateMu.Unlock()
|
||||
if staged == nil {
|
||||
return connection.QueryResult{Success: false, Message: "未找到已下载的更新包"}
|
||||
@@ -150,7 +175,17 @@ func (a *App) InstallUpdateAndRestart() connection.QueryResult {
|
||||
|
||||
if err := launchUpdateScript(staged); err != nil {
|
||||
logger.Error(err, "启动更新脚本失败")
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
msg := err.Error()
|
||||
if staged.InstallLogPath != "" {
|
||||
msg = fmt.Sprintf("%s(更新日志:%s)", msg, staged.InstallLogPath)
|
||||
}
|
||||
return connection.QueryResult{
|
||||
Success: false,
|
||||
Message: msg,
|
||||
Data: map[string]any{
|
||||
"logPath": staged.InstallLogPath,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
@@ -161,41 +196,95 @@ func (a *App) InstallUpdateAndRestart() connection.QueryResult {
|
||||
os.Exit(0)
|
||||
}()
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "更新已开始安装"}
|
||||
msg := "更新已开始安装"
|
||||
if staged.InstallLogPath != "" {
|
||||
msg = fmt.Sprintf("更新已开始安装,日志路径:%s", staged.InstallLogPath)
|
||||
}
|
||||
return connection.QueryResult{
|
||||
Success: true,
|
||||
Message: msg,
|
||||
Data: map[string]any{
|
||||
"logPath": staged.InstallLogPath,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) downloadAndStageUpdate(info UpdateInfo) connection.QueryResult {
|
||||
stagedDir, err := os.MkdirTemp("", "gonavi-update-")
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: "创建临时目录失败"}
|
||||
workspaceDir := strings.TrimSpace(resolveUpdateWorkspaceDir())
|
||||
if workspaceDir == "" {
|
||||
a.emitUpdateDownloadProgress("error", 0, info.AssetSize, "无法确定当前应用目录")
|
||||
return connection.QueryResult{Success: false, Message: "无法确定当前应用目录,无法下载更新"}
|
||||
}
|
||||
if err := os.MkdirAll(workspaceDir, 0o755); err != nil {
|
||||
errMsg := fmt.Sprintf("无法访问应用目录:%s", workspaceDir)
|
||||
a.emitUpdateDownloadProgress("error", 0, info.AssetSize, errMsg)
|
||||
return connection.QueryResult{Success: false, Message: errMsg}
|
||||
}
|
||||
|
||||
// 使用版本号命名的工作目录,便于识别和调试
|
||||
stagedDir := filepath.Join(workspaceDir, fmt.Sprintf(".gonavi-update-%s-%s", stdRuntime.GOOS, info.LatestVersion))
|
||||
// 清理可能残留的旧目录(上次下载失败后未清理)
|
||||
// Windows 上文件可能被杀毒软件/索引服务占用,需要重试
|
||||
for retry := 0; retry < 5; retry++ {
|
||||
err := os.RemoveAll(stagedDir)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if retry < 4 {
|
||||
time.Sleep(time.Duration(retry+1) * 500 * time.Millisecond)
|
||||
} else {
|
||||
// 最后一次仍然失败,换一个带时间戳的目录名避免冲突
|
||||
stagedDir = filepath.Join(workspaceDir, fmt.Sprintf(".gonavi-update-%s-%s-%d", stdRuntime.GOOS, info.LatestVersion, time.Now().UnixNano()))
|
||||
}
|
||||
}
|
||||
if err := os.MkdirAll(stagedDir, 0o755); err != nil {
|
||||
errMsg := fmt.Sprintf("无法在应用目录创建更新工作目录:%s", stagedDir)
|
||||
a.emitUpdateDownloadProgress("error", 0, info.AssetSize, errMsg)
|
||||
return connection.QueryResult{Success: false, Message: errMsg}
|
||||
}
|
||||
|
||||
// 下载到 staging 目录,避免覆盖正在运行的可执行文件
|
||||
assetPath := filepath.Join(stagedDir, info.AssetName)
|
||||
actualHash, err := downloadFileWithHash(info.AssetURL, assetPath)
|
||||
actualHash, err := downloadFileWithHash(info.AssetURL, assetPath, func(downloaded, total int64) {
|
||||
reportTotal := total
|
||||
if reportTotal <= 0 {
|
||||
reportTotal = info.AssetSize
|
||||
}
|
||||
a.emitUpdateDownloadProgress("downloading", downloaded, reportTotal, "")
|
||||
})
|
||||
if err != nil {
|
||||
_ = os.Remove(assetPath)
|
||||
_ = os.RemoveAll(stagedDir)
|
||||
a.emitUpdateDownloadProgress("error", 0, info.AssetSize, err.Error())
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
if info.SHA256 == "" {
|
||||
_ = os.Remove(assetPath)
|
||||
_ = os.RemoveAll(stagedDir)
|
||||
a.emitUpdateDownloadProgress("error", 0, info.AssetSize, "缺少更新包校验值(SHA256SUMS)")
|
||||
return connection.QueryResult{Success: false, Message: "缺少更新包校验值(SHA256SUMS)"}
|
||||
}
|
||||
if !strings.EqualFold(info.SHA256, actualHash) {
|
||||
_ = os.Remove(assetPath)
|
||||
_ = os.RemoveAll(stagedDir)
|
||||
a.emitUpdateDownloadProgress("error", 0, info.AssetSize, "更新包校验失败,请重试")
|
||||
return connection.QueryResult{Success: false, Message: "更新包校验失败,请重试"}
|
||||
}
|
||||
|
||||
a.updateMu.Lock()
|
||||
a.updateState.staged = &stagedUpdate{
|
||||
Version: info.LatestVersion,
|
||||
AssetName: info.AssetName,
|
||||
FilePath: assetPath,
|
||||
StagedDir: stagedDir,
|
||||
staged := &stagedUpdate{
|
||||
Version: info.LatestVersion,
|
||||
AssetName: info.AssetName,
|
||||
FilePath: assetPath,
|
||||
StagedDir: stagedDir,
|
||||
InstallLogPath: buildUpdateInstallLogPath(workspaceDir),
|
||||
}
|
||||
a.updateMu.Lock()
|
||||
a.updateState.staged = staged
|
||||
a.updateMu.Unlock()
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "更新包下载完成", Data: info}
|
||||
a.emitUpdateDownloadProgress("done", info.AssetSize, info.AssetSize, "")
|
||||
return connection.QueryResult{Success: true, Message: "更新包下载完成", Data: buildUpdateDownloadResult(info, staged)}
|
||||
}
|
||||
|
||||
func fetchLatestUpdateInfo() (UpdateInfo, error) {
|
||||
@@ -370,7 +459,32 @@ func parseSHA256Sums(content string) map[string]string {
|
||||
return result
|
||||
}
|
||||
|
||||
func downloadFileWithHash(url, filePath string) (string, error) {
|
||||
type downloadProgressWriter struct {
|
||||
total int64
|
||||
written int64
|
||||
lastEmit time.Time
|
||||
emitEvery time.Duration
|
||||
onProgress func(downloaded, total int64)
|
||||
}
|
||||
|
||||
func (w *downloadProgressWriter) Write(p []byte) (int, error) {
|
||||
n := len(p)
|
||||
if n == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
w.written += int64(n)
|
||||
if w.onProgress == nil {
|
||||
return n, nil
|
||||
}
|
||||
now := time.Now()
|
||||
if w.lastEmit.IsZero() || now.Sub(w.lastEmit) >= w.emitEvery || (w.total > 0 && w.written >= w.total) {
|
||||
w.lastEmit = now
|
||||
w.onProgress(w.written, w.total)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func downloadFileWithHash(url, filePath string, onProgress func(downloaded, total int64)) (string, error) {
|
||||
client := &http.Client{Timeout: 10 * time.Minute}
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
@@ -388,21 +502,121 @@ func downloadFileWithHash(url, filePath string) (string, error) {
|
||||
return "", fmt.Errorf("下载更新包失败:HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
out, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
// Windows 上旧文件可能被杀毒软件/索引服务占用,先尝试删除并重试
|
||||
_ = os.Remove(filePath)
|
||||
var out *os.File
|
||||
for retry := 0; retry < 5; retry++ {
|
||||
out, err = os.Create(filePath)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if retry < 4 {
|
||||
time.Sleep(time.Duration(retry+1) * 500 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("更新下载失败,文件被占用:%w", err)
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
hasher := sha256.New()
|
||||
writer := io.MultiWriter(out, hasher)
|
||||
if _, err := io.Copy(writer, resp.Body); err != nil {
|
||||
total := resp.ContentLength
|
||||
progressWriter := &downloadProgressWriter{
|
||||
total: total,
|
||||
emitEvery: 120 * time.Millisecond,
|
||||
onProgress: onProgress,
|
||||
}
|
||||
writers := []io.Writer{out, hasher, progressWriter}
|
||||
if onProgress != nil {
|
||||
onProgress(0, total)
|
||||
}
|
||||
if _, err := io.Copy(io.MultiWriter(writers...), resp.Body); err != nil {
|
||||
out.Close()
|
||||
return "", err
|
||||
}
|
||||
if onProgress != nil {
|
||||
onProgress(progressWriter.written, total)
|
||||
}
|
||||
|
||||
// 显式 Sync + Close,确保数据落盘且文件句柄释放
|
||||
if err := out.Sync(); err != nil {
|
||||
out.Close()
|
||||
return "", err
|
||||
}
|
||||
if err := out.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return hex.EncodeToString(hasher.Sum(nil)), nil
|
||||
}
|
||||
|
||||
func buildUpdateDownloadResult(info UpdateInfo, staged *stagedUpdate) updateDownloadResult {
|
||||
result := updateDownloadResult{
|
||||
Info: info,
|
||||
Platform: stdRuntime.GOOS,
|
||||
InstallTarget: resolveUpdateInstallTarget(),
|
||||
AutoRelaunch: true,
|
||||
}
|
||||
if staged != nil {
|
||||
result.DownloadPath = staged.FilePath
|
||||
result.InstallLogPath = staged.InstallLogPath
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func buildUpdateInstallLogPath(baseDir string) string {
|
||||
platform := stdRuntime.GOOS
|
||||
if platform == "darwin" {
|
||||
platform = "macos"
|
||||
}
|
||||
logDir := strings.TrimSpace(baseDir)
|
||||
if logDir == "" {
|
||||
logDir = os.TempDir()
|
||||
}
|
||||
return filepath.Join(logDir, fmt.Sprintf("gonavi-update-%s-%d.log", platform, time.Now().UnixNano()))
|
||||
}
|
||||
|
||||
func resolveUpdateWorkspaceDir() string {
|
||||
// 使用系统临时目录作为更新工作区,避免以下问题:
|
||||
// 1. Windows: exe 所在目录可能被杀毒软件/索引服务锁定,或缺少写权限(如 Program Files)
|
||||
// 2. macOS: /Applications 需要管理员权限才能写入
|
||||
// 3. 运行中的 exe 文件锁与 staging 文件冲突
|
||||
dir := filepath.Join(os.TempDir(), "gonavi-updates")
|
||||
_ = os.MkdirAll(dir, 0o755)
|
||||
return dir
|
||||
}
|
||||
|
||||
func resolveUpdateInstallTarget() string {
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
exePath, _ = filepath.EvalSymlinks(exePath)
|
||||
if stdRuntime.GOOS == "darwin" {
|
||||
return resolveMacUpdateTarget(exePath)
|
||||
}
|
||||
return exePath
|
||||
}
|
||||
|
||||
func (a *App) emitUpdateDownloadProgress(status string, downloaded, total int64, message string) {
|
||||
if a.ctx == nil {
|
||||
return
|
||||
}
|
||||
payload := updateDownloadProgressPayload{
|
||||
Status: status,
|
||||
Percent: 0,
|
||||
Downloaded: downloaded,
|
||||
Total: total,
|
||||
Message: strings.TrimSpace(message),
|
||||
}
|
||||
if total > 0 {
|
||||
payload.Percent = math.Min(100, (float64(downloaded)/float64(total))*100)
|
||||
}
|
||||
if status == "done" && payload.Percent < 100 {
|
||||
payload.Percent = 100
|
||||
}
|
||||
wailsRuntime.EventsEmit(a.ctx, updateDownloadProgressEvent, payload)
|
||||
}
|
||||
|
||||
func launchUpdateScript(staged *stagedUpdate) error {
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
@@ -425,7 +639,11 @@ func launchUpdateScript(staged *stagedUpdate) error {
|
||||
|
||||
func launchWindowsUpdate(staged *stagedUpdate, targetExe string, pid int) error {
|
||||
scriptPath := filepath.Join(staged.StagedDir, "update.cmd")
|
||||
logPath := filepath.Join(staged.StagedDir, "update.log")
|
||||
logPath := strings.TrimSpace(staged.InstallLogPath)
|
||||
if logPath == "" {
|
||||
logPath = buildUpdateInstallLogPath(filepath.Dir(staged.FilePath))
|
||||
staged.InstallLogPath = logPath
|
||||
}
|
||||
content := buildWindowsScript(staged.FilePath, targetExe, staged.StagedDir, logPath, pid)
|
||||
if err := os.WriteFile(scriptPath, []byte(content), 0o644); err != nil {
|
||||
return err
|
||||
@@ -442,7 +660,11 @@ func launchMacUpdate(staged *stagedUpdate, targetExe string, pid int) error {
|
||||
if err := os.MkdirAll(mountDir, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
logPath := filepath.Join(staged.StagedDir, "update.log")
|
||||
logPath := strings.TrimSpace(staged.InstallLogPath)
|
||||
if logPath == "" {
|
||||
logPath = buildUpdateInstallLogPath(filepath.Dir(staged.FilePath))
|
||||
staged.InstallLogPath = logPath
|
||||
}
|
||||
|
||||
scriptPath := filepath.Join(staged.StagedDir, "update.sh")
|
||||
content := buildMacScript(staged.FilePath, targetApp, staged.StagedDir, mountDir, logPath, pid)
|
||||
@@ -509,8 +731,12 @@ exit /b 1
|
||||
:move_done
|
||||
start "" "%%TARGET%%" >> "%%LOG_FILE%%" 2>&1
|
||||
if %%ERRORLEVEL%% NEQ 0 (
|
||||
call :log relaunch failed
|
||||
exit /b 1
|
||||
call :log cmd start failed, trying powershell Start-Process
|
||||
powershell -NoProfile -ExecutionPolicy Bypass -Command "Start-Process -FilePath '%%TARGET%%'" >> "%%LOG_FILE%%" 2>&1
|
||||
if %%ERRORLEVEL%% NEQ 0 (
|
||||
call :log relaunch failed
|
||||
exit /b 1
|
||||
)
|
||||
)
|
||||
rmdir /S /Q "%%STAGED%%" >> "%%LOG_FILE%%" 2>&1
|
||||
call :log update finished
|
||||
@@ -531,30 +757,69 @@ TARGET_APP="%s"
|
||||
STAGED="%s"
|
||||
MOUNT_DIR="%s"
|
||||
LOG_FILE="%s"
|
||||
TMP_APP="${TARGET_APP}.new"
|
||||
BACKUP_APP="${TARGET_APP}.backup"
|
||||
APP_BIN_NAME=$(basename "$TARGET_APP" .app)
|
||||
APP_BIN_REL="Contents/MacOS/$APP_BIN_NAME"
|
||||
|
||||
log() {
|
||||
echo "[$(date '+%%Y-%%m-%%d %%H:%%M:%%S')] $*" >> "$LOG_FILE"
|
||||
}
|
||||
|
||||
run_admin_install() {
|
||||
/usr/bin/osascript <<'APPLESCRIPT' "$APP_SRC" "$TARGET_APP" "$LOG_FILE"
|
||||
run_admin_replace() {
|
||||
/usr/bin/osascript <<'APPLESCRIPT' "$APP_SRC" "$TARGET_APP" "$TMP_APP" "$BACKUP_APP" "$APP_BIN_REL" "$LOG_FILE"
|
||||
on run argv
|
||||
set srcPath to item 1 of argv
|
||||
set dstPath to item 2 of argv
|
||||
set logPath to item 3 of argv
|
||||
do shell script "rm -rf " & quoted form of dstPath & " && cp -R " & quoted form of srcPath & " " & quoted form of dstPath & " >> " & quoted form of logPath & " 2>&1" with administrator privileges
|
||||
set tmpPath to item 3 of argv
|
||||
set bakPath to item 4 of argv
|
||||
set binRel to item 5 of argv
|
||||
set logPath to item 6 of argv
|
||||
set cmd to "set -eu; " & ¬
|
||||
"rm -rf " & quoted form of tmpPath & " " & quoted form of bakPath & "; " & ¬
|
||||
"/usr/bin/ditto " & quoted form of srcPath & " " & quoted form of tmpPath & "; " & ¬
|
||||
"if [ ! -x " & quoted form of (tmpPath & "/" & binRel) & " ]; then echo 'tmp app binary missing' >> " & quoted form of logPath & "; exit 1; fi; " & ¬
|
||||
"xattr -rd com.apple.quarantine " & quoted form of tmpPath & " >> " & quoted form of logPath & " 2>&1 || true; " & ¬
|
||||
"if [ -d " & quoted form of dstPath & " ]; then mv " & quoted form of dstPath & " " & quoted form of bakPath & "; fi; " & ¬
|
||||
"mv " & quoted form of tmpPath & " " & quoted form of dstPath & "; " & ¬
|
||||
"rm -rf " & quoted form of bakPath & "; " & ¬
|
||||
"xattr -rd com.apple.quarantine " & quoted form of dstPath & " >> " & quoted form of logPath & " 2>&1 || true"
|
||||
do shell script cmd with administrator privileges
|
||||
end run
|
||||
APPLESCRIPT
|
||||
}
|
||||
|
||||
run_admin_xattr() {
|
||||
/usr/bin/osascript <<'APPLESCRIPT' "$TARGET_APP" "$LOG_FILE"
|
||||
on run argv
|
||||
set dstPath to item 1 of argv
|
||||
set logPath to item 2 of argv
|
||||
do shell script "xattr -rd com.apple.quarantine " & quoted form of dstPath & " >> " & quoted form of logPath & " 2>&1" with administrator privileges
|
||||
end run
|
||||
APPLESCRIPT
|
||||
replace_app_direct() {
|
||||
rm -rf "$TMP_APP" "$BACKUP_APP" >>"$LOG_FILE" 2>&1 || true
|
||||
/usr/bin/ditto "$APP_SRC" "$TMP_APP" >>"$LOG_FILE" 2>&1
|
||||
if [ ! -x "$TMP_APP/$APP_BIN_REL" ]; then
|
||||
log "tmp app binary missing: $TMP_APP/$APP_BIN_REL"
|
||||
return 1
|
||||
fi
|
||||
xattr -rd com.apple.quarantine "$TMP_APP" >>"$LOG_FILE" 2>&1 || true
|
||||
if [ -d "$TARGET_APP" ]; then
|
||||
mv "$TARGET_APP" "$BACKUP_APP" >>"$LOG_FILE" 2>&1
|
||||
fi
|
||||
if ! mv "$TMP_APP" "$TARGET_APP" >>"$LOG_FILE" 2>&1; then
|
||||
log "move new app failed, trying rollback"
|
||||
rm -rf "$TARGET_APP" >>"$LOG_FILE" 2>&1 || true
|
||||
if [ -d "$BACKUP_APP" ]; then
|
||||
mv "$BACKUP_APP" "$TARGET_APP" >>"$LOG_FILE" 2>&1 || true
|
||||
fi
|
||||
return 1
|
||||
fi
|
||||
rm -rf "$BACKUP_APP" >>"$LOG_FILE" 2>&1 || true
|
||||
xattr -rd com.apple.quarantine "$TARGET_APP" >>"$LOG_FILE" 2>&1 || true
|
||||
return 0
|
||||
}
|
||||
|
||||
relaunch_app() {
|
||||
if /usr/bin/open -n "$TARGET_APP" >>"$LOG_FILE" 2>&1; then
|
||||
return 0
|
||||
fi
|
||||
log "open -n failed, trying binary launch"
|
||||
"$TARGET_APP/$APP_BIN_REL" >>"$LOG_FILE" 2>&1 &
|
||||
return 0
|
||||
}
|
||||
|
||||
log "updater started"
|
||||
@@ -571,21 +836,22 @@ if [ -z "$APP_SRC" ]; then
|
||||
fi
|
||||
|
||||
log "install target: $TARGET_APP"
|
||||
if ! rm -rf "$TARGET_APP" >>"$LOG_FILE" 2>&1 || ! cp -R "$APP_SRC" "$TARGET_APP" >>"$LOG_FILE" 2>&1; then
|
||||
log "direct install failed, trying admin install"
|
||||
run_admin_install >>"$LOG_FILE" 2>&1
|
||||
if ! replace_app_direct; then
|
||||
log "direct replace failed, trying admin replace"
|
||||
run_admin_replace >>"$LOG_FILE" 2>&1
|
||||
fi
|
||||
|
||||
if ! xattr -rd com.apple.quarantine "$TARGET_APP" >>"$LOG_FILE" 2>&1; then
|
||||
log "direct xattr failed, trying admin xattr"
|
||||
run_admin_xattr >>"$LOG_FILE" 2>&1 || true
|
||||
if [ ! -x "$TARGET_APP/$APP_BIN_REL" ]; then
|
||||
log "target app binary missing after replace: $TARGET_APP/$APP_BIN_REL"
|
||||
hdiutil detach "$MOUNT_DIR" -quiet >>"$LOG_FILE" 2>&1 || true
|
||||
exit 1
|
||||
fi
|
||||
|
||||
hdiutil detach "$MOUNT_DIR" -quiet >>"$LOG_FILE" 2>&1 || true
|
||||
rm -rf "$MOUNT_DIR" "$DMG" "$STAGED" >>"$LOG_FILE" 2>&1 || true
|
||||
open "$TARGET_APP" >>"$LOG_FILE" 2>&1
|
||||
relaunch_app
|
||||
log "relaunch requested"
|
||||
`, pid, dmgPath, targetApp, stagedDir, mountDir, logPath)
|
||||
`, pid, dmgPath, targetApp, stagedDir, mountDir, logPath)
|
||||
}
|
||||
|
||||
func buildLinuxScript(tarPath, targetExe, stagedDir string, pid int) string {
|
||||
@@ -618,7 +884,12 @@ func detectMacAppPath(exePath string) string {
|
||||
parts := strings.Split(exePath, string(filepath.Separator))
|
||||
for i := len(parts) - 1; i >= 0; i-- {
|
||||
if strings.HasSuffix(parts[i], ".app") {
|
||||
return filepath.Join(parts[:i+1]...)
|
||||
appPath := filepath.Join(parts[:i+1]...)
|
||||
// 确保返回绝对路径
|
||||
if !filepath.IsAbs(appPath) {
|
||||
appPath = string(filepath.Separator) + appPath
|
||||
}
|
||||
return appPath
|
||||
}
|
||||
}
|
||||
return ""
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
func sanitizeSQLForPgLike(dbType string, query string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(dbType)) {
|
||||
case "postgres", "kingbase":
|
||||
case "postgres", "kingbase", "highgo", "vastbase":
|
||||
// 有些情况下会出现多层重复引用(例如 """"schema"""" 或 ""schema"""),单次修复不一定收敛。
|
||||
// 这里做有限次数的迭代,直到输出不再变化。
|
||||
out := query
|
||||
|
||||
@@ -47,14 +47,16 @@ static void gonaviTuneWindowTranslucency(NSWindow *window) {
|
||||
[effectView setMaterial:NSVisualEffectMaterialHUDWindow];
|
||||
[effectView setBlendingMode:NSVisualEffectBlendingModeBehindWindow];
|
||||
[effectView setState:NSVisualEffectStateActive];
|
||||
[effectView setAlphaValue:0.72];
|
||||
// 默认 alpha=0(不可见),由前端根据用户外观设置动态启用
|
||||
[effectView setAlphaValue:0.0];
|
||||
[effectView setWantsLayer:YES];
|
||||
[[effectView layer] setCornerRadius:cornerRadius];
|
||||
[[effectView layer] setMasksToBounds:YES];
|
||||
}
|
||||
|
||||
static void gonaviApplyWindowTranslucencyFix() {
|
||||
for (int i = 0; i < 24; i++) {
|
||||
// 启动时应用窗口透明度修复,减少重试次数以降低启动期 GPU 负载
|
||||
for (int i = 0; i < 8; i++) {
|
||||
dispatch_after(dispatch_time(DISPATCH_TIME_NOW, (int64_t)(i * 250 * NSEC_PER_MSEC)), dispatch_get_main_queue(), ^{
|
||||
for (NSWindow *window in [NSApp windows]) {
|
||||
gonaviTuneWindowTranslucency(window);
|
||||
@@ -62,9 +64,56 @@ static void gonaviApplyWindowTranslucencyFix() {
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// 动态设置 NSVisualEffectView 的透明度和窗口不透明标志。
|
||||
// alpha <= 0 时窗口标记为 opaque,GPU 不再持续计算窗口背后的模糊效果。
|
||||
static void gonaviSetEffectViewAlpha(double alpha) {
|
||||
dispatch_async(dispatch_get_main_queue(), ^{
|
||||
for (NSWindow *window in [NSApp windows]) {
|
||||
NSView *contentView = [window contentView];
|
||||
if (contentView == nil) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (NSView *subview in [contentView subviews]) {
|
||||
if ([subview isKindOfClass:[NSVisualEffectView class]]) {
|
||||
NSVisualEffectView *effectView = (NSVisualEffectView *)subview;
|
||||
[effectView setAlphaValue:alpha];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (alpha <= 0.01) {
|
||||
[window setOpaque:YES];
|
||||
} else {
|
||||
[window setOpaque:NO];
|
||||
[window setBackgroundColor:[NSColor clearColor]];
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
func applyMacWindowTranslucencyFix() {
|
||||
C.gonaviApplyWindowTranslucencyFix()
|
||||
}
|
||||
|
||||
// setMacWindowTranslucency 根据用户外观设置动态调整 macOS 窗口透明度。
|
||||
// opacity=1.0 且 blur=0 时关闭 NSVisualEffectView(alpha=0),窗口标记为 opaque,
|
||||
// GPU 不再持续计算窗口背后的模糊合成,显著降低 CPU/GPU 温度。
|
||||
func setMacWindowTranslucency(opacity float64, blur float64) {
|
||||
if opacity >= 0.999 && blur <= 0 {
|
||||
C.gonaviSetEffectViewAlpha(C.double(0.0))
|
||||
} else {
|
||||
// 半透明模式:NSVisualEffectView alpha 根据透明度动态映射
|
||||
alpha := (1.0 - opacity) * 1.2
|
||||
if alpha < 0.3 {
|
||||
alpha = 0.3
|
||||
}
|
||||
if alpha > 0.85 {
|
||||
alpha = 0.85
|
||||
}
|
||||
C.gonaviSetEffectViewAlpha(C.double(alpha))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,3 +3,5 @@
|
||||
package app
|
||||
|
||||
func applyMacWindowTranslucencyFix() {}
|
||||
|
||||
func setMacWindowTranslucency(opacity float64, blur float64) {}
|
||||
|
||||
@@ -11,18 +11,31 @@ type SSHConfig struct {
|
||||
|
||||
// ConnectionConfig holds database connection details including SSH
|
||||
type ConnectionConfig struct {
|
||||
Type string `json:"type"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
User string `json:"user"`
|
||||
Password string `json:"password"`
|
||||
Database string `json:"database"`
|
||||
UseSSH bool `json:"useSSH"`
|
||||
SSH SSHConfig `json:"ssh"`
|
||||
Driver string `json:"driver,omitempty"` // For custom connection
|
||||
DSN string `json:"dsn,omitempty"` // For custom connection
|
||||
Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds (default: 30)
|
||||
RedisDB int `json:"redisDB,omitempty"` // Redis database index (0-15)
|
||||
Type string `json:"type"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
User string `json:"user"`
|
||||
Password string `json:"password"`
|
||||
SavePassword bool `json:"savePassword,omitempty"` // Persist password in saved connection
|
||||
Database string `json:"database"`
|
||||
UseSSH bool `json:"useSSH"`
|
||||
SSH SSHConfig `json:"ssh"`
|
||||
Driver string `json:"driver,omitempty"` // For custom connection
|
||||
DSN string `json:"dsn,omitempty"` // For custom connection
|
||||
Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds (default: 30)
|
||||
RedisDB int `json:"redisDB,omitempty"` // Redis database index (0-15)
|
||||
URI string `json:"uri,omitempty"` // Connection URI for copy/paste
|
||||
Hosts []string `json:"hosts,omitempty"` // Multi-host addresses: host:port
|
||||
Topology string `json:"topology,omitempty"` // single | replica
|
||||
MySQLReplicaUser string `json:"mysqlReplicaUser,omitempty"` // MySQL replica auth user
|
||||
MySQLReplicaPassword string `json:"mysqlReplicaPassword,omitempty"` // MySQL replica auth password
|
||||
ReplicaSet string `json:"replicaSet,omitempty"` // MongoDB replica set name
|
||||
AuthSource string `json:"authSource,omitempty"` // MongoDB authSource
|
||||
ReadPreference string `json:"readPreference,omitempty"` // MongoDB readPreference
|
||||
MongoSRV bool `json:"mongoSrv,omitempty"` // MongoDB use mongodb+srv URI scheme
|
||||
MongoAuthMechanism string `json:"mongoAuthMechanism,omitempty"` // MongoDB authMechanism
|
||||
MongoReplicaUser string `json:"mongoReplicaUser,omitempty"` // MongoDB replica auth user
|
||||
MongoReplicaPassword string `json:"mongoReplicaPassword,omitempty"` // MongoDB replica auth password
|
||||
}
|
||||
|
||||
// QueryResult is the standard response format for Wails methods
|
||||
@@ -89,3 +102,12 @@ type ChangeSet struct {
|
||||
Updates []UpdateRow `json:"updates"`
|
||||
Deletes []map[string]interface{} `json:"deletes"`
|
||||
}
|
||||
|
||||
type MongoMemberInfo struct {
|
||||
Host string `json:"host"`
|
||||
Role string `json:"role"`
|
||||
State string `json:"state"`
|
||||
StateCode int `json:"stateCode,omitempty"`
|
||||
Healthy bool `json:"healthy"`
|
||||
IsSelf bool `json:"isSelf,omitempty"`
|
||||
}
|
||||
|
||||
@@ -40,6 +40,20 @@ func NewDatabase(dbType string) (Database, error) {
|
||||
return &DamengDB{}, nil
|
||||
case "kingbase":
|
||||
return &KingbaseDB{}, nil
|
||||
case "mongodb":
|
||||
return &MongoDB{}, nil
|
||||
case "sqlserver":
|
||||
return &SqlServerDB{}, nil
|
||||
case "highgo":
|
||||
return &HighGoDB{}, nil
|
||||
case "mariadb":
|
||||
return &MariaDB{}, nil
|
||||
case "sphinx":
|
||||
return &SphinxDB{}, nil
|
||||
case "vastbase":
|
||||
return &VastbaseDB{}, nil
|
||||
case "tdengine":
|
||||
return &TDengineDB{}, nil
|
||||
case "custom":
|
||||
return &CustomDB{}, nil
|
||||
default:
|
||||
|
||||
@@ -95,3 +95,20 @@ func TestKingbaseDSN_QuotesPasswordWithSpaces(t *testing.T) {
|
||||
t.Fatalf("dsn 未对包含空格的密码进行引号包裹:%s", dsn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTDengineDSN_UsesWebSocketFormat(t *testing.T) {
|
||||
td := &TDengineDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "tdengine",
|
||||
Host: "127.0.0.1",
|
||||
Port: 6041,
|
||||
User: "root",
|
||||
Password: "taosdata",
|
||||
Database: "power",
|
||||
}
|
||||
|
||||
dsn := td.getDSN(cfg)
|
||||
if !strings.HasPrefix(dsn, "root:taosdata@ws(127.0.0.1:6041)/power") {
|
||||
t.Fatalf("tdengine dsn 格式不正确:%s", dsn)
|
||||
}
|
||||
}
|
||||
|
||||
628
internal/db/highgo_impl.go
Normal file
628
internal/db/highgo_impl.go
Normal file
@@ -0,0 +1,628 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
|
||||
_ "github.com/highgo/pq-sm3" // HighGo uses dedicated SM3-capable driver
|
||||
)
|
||||
|
||||
// HighGoDB implements Database interface for HighGo (瀚高) database
|
||||
// HighGo is a PostgreSQL-compatible database, so we reuse PostgreSQL driver
|
||||
type HighGoDB struct {
|
||||
conn *sql.DB
|
||||
pingTimeout time.Duration
|
||||
forwarder *ssh.LocalForwarder
|
||||
}
|
||||
|
||||
func (h *HighGoDB) getDSN(config connection.ConnectionConfig) string {
|
||||
// postgres://user:password@host:port/dbname?sslmode=disable
|
||||
dbname := config.Database
|
||||
if dbname == "" {
|
||||
dbname = "highgo" // HighGo default database
|
||||
}
|
||||
|
||||
u := &url.URL{
|
||||
Scheme: "postgres",
|
||||
Host: net.JoinHostPort(config.Host, strconv.Itoa(config.Port)),
|
||||
Path: "/" + dbname,
|
||||
}
|
||||
u.User = url.UserPassword(config.User, config.Password)
|
||||
q := url.Values{}
|
||||
q.Set("sslmode", "disable")
|
||||
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func (h *HighGoDB) Connect(config connection.ConnectionConfig) error {
|
||||
var dsn string
|
||||
|
||||
if config.UseSSH {
|
||||
logger.Infof("HighGo 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
|
||||
|
||||
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||||
}
|
||||
h.forwarder = forwarder
|
||||
|
||||
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地转发地址失败:%w", err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地端口失败:%w", err)
|
||||
}
|
||||
|
||||
localConfig := config
|
||||
localConfig.Host = host
|
||||
localConfig.Port = port
|
||||
localConfig.UseSSH = false
|
||||
|
||||
dsn = h.getDSN(localConfig)
|
||||
logger.Infof("HighGo 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||||
} else {
|
||||
dsn = h.getDSN(config)
|
||||
}
|
||||
|
||||
db, err := sql.Open("highgo", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
}
|
||||
h.conn = db
|
||||
h.pingTimeout = getConnectTimeout(config)
|
||||
|
||||
if err := h.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HighGoDB) Close() error {
|
||||
if h.forwarder != nil {
|
||||
if err := h.forwarder.Close(); err != nil {
|
||||
logger.Warnf("关闭 HighGo SSH 端口转发失败:%v", err)
|
||||
}
|
||||
h.forwarder = nil
|
||||
}
|
||||
|
||||
if h.conn != nil {
|
||||
return h.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HighGoDB) Ping() error {
|
||||
if h.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
}
|
||||
timeout := h.pingTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||||
defer cancel()
|
||||
return h.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (h *HighGoDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if h.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := h.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (h *HighGoDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if h.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := h.conn.Query(query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (h *HighGoDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if h.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
}
|
||||
res, err := h.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (h *HighGoDB) Exec(query string) (int64, error) {
|
||||
if h.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
}
|
||||
res, err := h.conn.Exec(query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (h *HighGoDB) GetDatabases() ([]string, error) {
|
||||
data, _, err := h.Query("SELECT datname FROM pg_database WHERE datistemplate = false")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var dbs []string
|
||||
for _, row := range data {
|
||||
if val, ok := row["datname"]; ok {
|
||||
dbs = append(dbs, fmt.Sprintf("%v", val))
|
||||
}
|
||||
}
|
||||
return dbs, nil
|
||||
}
|
||||
|
||||
func (h *HighGoDB) GetTables(dbName string) ([]string, error) {
|
||||
query := "SELECT schemaname, tablename FROM pg_catalog.pg_tables WHERE schemaname != 'information_schema' AND schemaname NOT LIKE 'pg_%' ORDER BY schemaname, tablename"
|
||||
data, _, err := h.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tables []string
|
||||
for _, row := range data {
|
||||
schema, okSchema := row["schemaname"]
|
||||
name, okName := row["tablename"]
|
||||
if okSchema && okName {
|
||||
tables = append(tables, fmt.Sprintf("%v.%v", schema, name))
|
||||
continue
|
||||
}
|
||||
if okName {
|
||||
tables = append(tables, fmt.Sprintf("%v", name))
|
||||
}
|
||||
}
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
func (h *HighGoDB) GetCreateStatement(dbName, tableName string) (string, error) {
|
||||
return fmt.Sprintf("-- SHOW CREATE TABLE not fully supported for HighGo in this version.\n-- Table: %s", tableName), nil
|
||||
}
|
||||
|
||||
func (h *HighGoDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
schema := strings.TrimSpace(dbName)
|
||||
if schema == "" {
|
||||
schema = "public"
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
a.attname AS column_name,
|
||||
pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
|
||||
CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
|
||||
pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
|
||||
col_description(a.attrelid, a.attnum) AS comment,
|
||||
CASE WHEN pk.attname IS NOT NULL THEN 'PRI' ELSE '' END AS column_key
|
||||
FROM pg_class c
|
||||
JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||
JOIN pg_attribute a ON a.attrelid = c.oid
|
||||
LEFT JOIN pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum
|
||||
LEFT JOIN (
|
||||
SELECT i.indrelid, a3.attname
|
||||
FROM pg_index i
|
||||
JOIN pg_attribute a3 ON a3.attrelid = i.indrelid AND a3.attnum = ANY(i.indkey)
|
||||
WHERE i.indisprimary
|
||||
) pk ON pk.indrelid = c.oid AND pk.attname = a.attname
|
||||
WHERE c.relkind IN ('r', 'p')
|
||||
AND n.nspname = '%s'
|
||||
AND c.relname = '%s'
|
||||
AND a.attnum > 0
|
||||
AND NOT a.attisdropped
|
||||
ORDER BY a.attnum`, esc(schema), esc(table))
|
||||
|
||||
data, _, err := h.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var columns []connection.ColumnDefinition
|
||||
for _, row := range data {
|
||||
col := connection.ColumnDefinition{
|
||||
Name: fmt.Sprintf("%v", row["column_name"]),
|
||||
Type: fmt.Sprintf("%v", row["data_type"]),
|
||||
Nullable: fmt.Sprintf("%v", row["is_nullable"]),
|
||||
Key: fmt.Sprintf("%v", row["column_key"]),
|
||||
Extra: "",
|
||||
Comment: "",
|
||||
}
|
||||
|
||||
if v, ok := row["comment"]; ok && v != nil {
|
||||
col.Comment = fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
if v, ok := row["column_default"]; ok && v != nil {
|
||||
def := fmt.Sprintf("%v", v)
|
||||
col.Default = &def
|
||||
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(def)), "nextval(") {
|
||||
col.Extra = "auto_increment"
|
||||
}
|
||||
}
|
||||
|
||||
columns = append(columns, col)
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (h *HighGoDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||||
schema := strings.TrimSpace(dbName)
|
||||
if schema == "" {
|
||||
schema = "public"
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
i.relname AS index_name,
|
||||
a.attname AS column_name,
|
||||
ix.indisunique AS is_unique,
|
||||
x.ordinality AS seq_in_index,
|
||||
am.amname AS index_type
|
||||
FROM pg_class t
|
||||
JOIN pg_namespace n ON n.oid = t.relnamespace
|
||||
JOIN pg_index ix ON t.oid = ix.indrelid
|
||||
JOIN pg_class i ON i.oid = ix.indexrelid
|
||||
JOIN pg_am am ON i.relam = am.oid
|
||||
JOIN unnest(ix.indkey) WITH ORDINALITY AS x(attnum, ordinality) ON TRUE
|
||||
JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = x.attnum
|
||||
WHERE t.relkind IN ('r', 'p')
|
||||
AND t.relname = '%s'
|
||||
AND n.nspname = '%s'
|
||||
ORDER BY i.relname, x.ordinality`, esc(table), esc(schema))
|
||||
|
||||
data, _, err := h.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parseBool := func(v interface{}) bool {
|
||||
switch val := v.(type) {
|
||||
case bool:
|
||||
return val
|
||||
case string:
|
||||
s := strings.ToLower(strings.TrimSpace(val))
|
||||
return s == "t" || s == "true" || s == "1" || s == "y" || s == "yes"
|
||||
default:
|
||||
s := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
|
||||
return s == "t" || s == "true" || s == "1" || s == "y" || s == "yes"
|
||||
}
|
||||
}
|
||||
|
||||
parseInt := func(v interface{}) int {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val
|
||||
case int64:
|
||||
return int(val)
|
||||
case float64:
|
||||
return int(val)
|
||||
case string:
|
||||
var n int
|
||||
_, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n)
|
||||
return n
|
||||
default:
|
||||
var n int
|
||||
_, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n)
|
||||
return n
|
||||
}
|
||||
}
|
||||
|
||||
var indexes []connection.IndexDefinition
|
||||
for _, row := range data {
|
||||
isUnique := false
|
||||
if v, ok := row["is_unique"]; ok && v != nil {
|
||||
isUnique = parseBool(v)
|
||||
}
|
||||
|
||||
nonUnique := 1
|
||||
if isUnique {
|
||||
nonUnique = 0
|
||||
}
|
||||
|
||||
seq := 0
|
||||
if v, ok := row["seq_in_index"]; ok && v != nil {
|
||||
seq = parseInt(v)
|
||||
}
|
||||
|
||||
indexType := ""
|
||||
if v, ok := row["index_type"]; ok && v != nil {
|
||||
indexType = strings.ToUpper(fmt.Sprintf("%v", v))
|
||||
}
|
||||
if indexType == "" {
|
||||
indexType = "BTREE"
|
||||
}
|
||||
|
||||
idx := connection.IndexDefinition{
|
||||
Name: fmt.Sprintf("%v", row["index_name"]),
|
||||
ColumnName: fmt.Sprintf("%v", row["column_name"]),
|
||||
NonUnique: nonUnique,
|
||||
SeqInIndex: seq,
|
||||
IndexType: indexType,
|
||||
}
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
return indexes, nil
|
||||
}
|
||||
|
||||
func (h *HighGoDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
|
||||
schema := strings.TrimSpace(dbName)
|
||||
if schema == "" {
|
||||
schema = "public"
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
tc.constraint_name AS constraint_name,
|
||||
kcu.column_name AS column_name,
|
||||
ccu.table_schema AS foreign_table_schema,
|
||||
ccu.table_name AS foreign_table_name,
|
||||
ccu.column_name AS foreign_column_name
|
||||
FROM information_schema.table_constraints AS tc
|
||||
JOIN information_schema.key_column_usage AS kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
JOIN information_schema.constraint_column_usage AS ccu
|
||||
ON ccu.constraint_name = tc.constraint_name
|
||||
AND ccu.table_schema = tc.table_schema
|
||||
WHERE tc.constraint_type = 'FOREIGN KEY'
|
||||
AND tc.table_name = '%s'
|
||||
AND tc.table_schema = '%s'
|
||||
ORDER BY tc.constraint_name, kcu.ordinal_position`, esc(table), esc(schema))
|
||||
|
||||
data, _, err := h.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var fks []connection.ForeignKeyDefinition
|
||||
for _, row := range data {
|
||||
refSchema := ""
|
||||
if v, ok := row["foreign_table_schema"]; ok && v != nil {
|
||||
refSchema = fmt.Sprintf("%v", v)
|
||||
}
|
||||
refTable := fmt.Sprintf("%v", row["foreign_table_name"])
|
||||
refTableName := refTable
|
||||
if strings.TrimSpace(refSchema) != "" {
|
||||
refTableName = fmt.Sprintf("%s.%s", refSchema, refTable)
|
||||
}
|
||||
|
||||
fk := connection.ForeignKeyDefinition{
|
||||
Name: fmt.Sprintf("%v", row["constraint_name"]),
|
||||
ColumnName: fmt.Sprintf("%v", row["column_name"]),
|
||||
RefTableName: refTableName,
|
||||
RefColumnName: fmt.Sprintf("%v", row["foreign_column_name"]),
|
||||
ConstraintName: fmt.Sprintf("%v", row["constraint_name"]),
|
||||
}
|
||||
fks = append(fks, fk)
|
||||
}
|
||||
return fks, nil
|
||||
}
|
||||
|
||||
func (h *HighGoDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||||
schema := strings.TrimSpace(dbName)
|
||||
if schema == "" {
|
||||
schema = "public"
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT trigger_name, action_timing, event_manipulation, action_statement
|
||||
FROM information_schema.triggers
|
||||
WHERE event_object_table = '%s'
|
||||
AND event_object_schema = '%s'
|
||||
ORDER BY trigger_name, event_manipulation`, esc(table), esc(schema))
|
||||
|
||||
data, _, err := h.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var triggers []connection.TriggerDefinition
|
||||
for _, row := range data {
|
||||
trig := connection.TriggerDefinition{
|
||||
Name: fmt.Sprintf("%v", row["trigger_name"]),
|
||||
Timing: fmt.Sprintf("%v", row["action_timing"]),
|
||||
Event: fmt.Sprintf("%v", row["event_manipulation"]),
|
||||
Statement: fmt.Sprintf("%v", row["action_statement"]),
|
||||
}
|
||||
triggers = append(triggers, trig)
|
||||
}
|
||||
return triggers, nil
|
||||
}
|
||||
|
||||
func (h *HighGoDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||||
query := `
|
||||
SELECT table_schema, table_name, column_name, data_type
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
|
||||
AND table_schema NOT LIKE 'pg_%'
|
||||
ORDER BY table_schema, table_name, ordinal_position`
|
||||
|
||||
data, _, err := h.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cols []connection.ColumnDefinitionWithTable
|
||||
for _, row := range data {
|
||||
schema := fmt.Sprintf("%v", row["table_schema"])
|
||||
table := fmt.Sprintf("%v", row["table_name"])
|
||||
tableName := table
|
||||
if strings.TrimSpace(schema) != "" {
|
||||
tableName = fmt.Sprintf("%s.%s", schema, table)
|
||||
}
|
||||
|
||||
col := connection.ColumnDefinitionWithTable{
|
||||
TableName: tableName,
|
||||
Name: fmt.Sprintf("%v", row["column_name"]),
|
||||
Type: fmt.Sprintf("%v", row["data_type"]),
|
||||
}
|
||||
cols = append(cols, col)
|
||||
}
|
||||
return cols, nil
|
||||
}
|
||||
|
||||
func (h *HighGoDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
if h.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
tx, err := h.conn.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
quoteIdent := func(name string) string {
|
||||
n := strings.TrimSpace(name)
|
||||
n = strings.Trim(n, "\"")
|
||||
n = strings.ReplaceAll(n, "\"", "\"\"")
|
||||
if n == "" {
|
||||
return "\"\""
|
||||
}
|
||||
return `"` + n + `"`
|
||||
}
|
||||
|
||||
schema := ""
|
||||
table := strings.TrimSpace(tableName)
|
||||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||||
schema = strings.TrimSpace(parts[0])
|
||||
table = strings.TrimSpace(parts[1])
|
||||
}
|
||||
|
||||
qualifiedTable := ""
|
||||
if schema != "" {
|
||||
qualifiedTable = fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table))
|
||||
} else {
|
||||
qualifiedTable = quoteIdent(table)
|
||||
}
|
||||
|
||||
// 1. Deletes
|
||||
for _, pk := range changes.Deletes {
|
||||
var wheres []string
|
||||
var args []interface{}
|
||||
idx := 0
|
||||
for k, v := range pk {
|
||||
idx++
|
||||
wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteIdent(k), idx))
|
||||
args = append(args, v)
|
||||
}
|
||||
if len(wheres) == 0 {
|
||||
continue
|
||||
}
|
||||
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("delete error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Updates
|
||||
for _, update := range changes.Updates {
|
||||
var sets []string
|
||||
var args []interface{}
|
||||
idx := 0
|
||||
|
||||
for k, v := range update.Values {
|
||||
idx++
|
||||
sets = append(sets, fmt.Sprintf("%s = $%d", quoteIdent(k), idx))
|
||||
args = append(args, v)
|
||||
}
|
||||
|
||||
if len(sets) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var wheres []string
|
||||
for k, v := range update.Keys {
|
||||
idx++
|
||||
wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteIdent(k), idx))
|
||||
args = append(args, v)
|
||||
}
|
||||
|
||||
if len(wheres) == 0 {
|
||||
return fmt.Errorf("update requires keys")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("update error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Inserts
|
||||
for _, row := range changes.Inserts {
|
||||
var cols []string
|
||||
var placeholders []string
|
||||
var args []interface{}
|
||||
idx := 0
|
||||
|
||||
for k, v := range row {
|
||||
idx++
|
||||
cols = append(cols, quoteIdent(k))
|
||||
placeholders = append(placeholders, fmt.Sprintf("$%d", idx))
|
||||
args = append(args, v)
|
||||
}
|
||||
|
||||
if len(cols) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("insert error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
409
internal/db/mariadb_impl.go
Normal file
409
internal/db/mariadb_impl.go
Normal file
@@ -0,0 +1,409 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
// MariaDB implements Database interface for MariaDB
|
||||
// MariaDB is MySQL-compatible, so we reuse the MySQL driver
|
||||
type MariaDB struct {
|
||||
conn *sql.DB
|
||||
pingTimeout time.Duration
|
||||
}
|
||||
|
||||
func (m *MariaDB) getDSN(config connection.ConnectionConfig) string {
|
||||
database := config.Database
|
||||
protocol := "tcp"
|
||||
address := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
|
||||
if config.UseSSH {
|
||||
netName, err := ssh.RegisterSSHNetwork(config.SSH)
|
||||
if err == nil {
|
||||
protocol = netName
|
||||
address = fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
} else {
|
||||
logger.Warnf("注册 SSH 网络失败,将尝试直连:地址=%s:%d 用户=%s,原因:%v", config.Host, config.Port, config.User, err)
|
||||
}
|
||||
}
|
||||
|
||||
timeout := getConnectTimeoutSeconds(config)
|
||||
|
||||
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds",
|
||||
config.User, config.Password, protocol, address, database, timeout)
|
||||
}
|
||||
|
||||
func (m *MariaDB) Connect(config connection.ConnectionConfig) error {
|
||||
dsn := m.getDSN(config)
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
}
|
||||
m.conn = db
|
||||
m.pingTimeout = getConnectTimeout(config)
|
||||
|
||||
if err := m.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MariaDB) Close() error {
|
||||
if m.conn != nil {
|
||||
return m.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MariaDB) Ping() error {
|
||||
if m.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
}
|
||||
timeout := m.pingTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||||
defer cancel()
|
||||
return m.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (m *MariaDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if m.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := m.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (m *MariaDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if m.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := m.conn.Query(query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (m *MariaDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if m.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
}
|
||||
res, err := m.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (m *MariaDB) Exec(query string) (int64, error) {
|
||||
if m.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
}
|
||||
res, err := m.conn.Exec(query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (m *MariaDB) GetDatabases() ([]string, error) {
|
||||
data, _, err := m.Query("SHOW DATABASES")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var dbs []string
|
||||
for _, row := range data {
|
||||
if val, ok := row["Database"]; ok {
|
||||
dbs = append(dbs, fmt.Sprintf("%v", val))
|
||||
} else if val, ok := row["database"]; ok {
|
||||
dbs = append(dbs, fmt.Sprintf("%v", val))
|
||||
}
|
||||
}
|
||||
return dbs, nil
|
||||
}
|
||||
|
||||
func (m *MariaDB) GetTables(dbName string) ([]string, error) {
|
||||
query := "SHOW TABLES"
|
||||
if dbName != "" {
|
||||
query = fmt.Sprintf("SHOW TABLES FROM `%s`", dbName)
|
||||
}
|
||||
|
||||
data, _, err := m.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tables []string
|
||||
for _, row := range data {
|
||||
for _, v := range row {
|
||||
tables = append(tables, fmt.Sprintf("%v", v))
|
||||
break
|
||||
}
|
||||
}
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
func (m *MariaDB) GetCreateStatement(dbName, tableName string) (string, error) {
|
||||
query := fmt.Sprintf("SHOW CREATE TABLE `%s`.`%s`", dbName, tableName)
|
||||
if dbName == "" {
|
||||
query = fmt.Sprintf("SHOW CREATE TABLE `%s`", tableName)
|
||||
}
|
||||
|
||||
data, _, err := m.Query(query)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(data) > 0 {
|
||||
if val, ok := data[0]["Create Table"]; ok {
|
||||
return fmt.Sprintf("%v", val), nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("create statement not found")
|
||||
}
|
||||
|
||||
func (m *MariaDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
query := fmt.Sprintf("SHOW FULL COLUMNS FROM `%s`.`%s`", dbName, tableName)
|
||||
if dbName == "" {
|
||||
query = fmt.Sprintf("SHOW FULL COLUMNS FROM `%s`", tableName)
|
||||
}
|
||||
|
||||
data, _, err := m.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var columns []connection.ColumnDefinition
|
||||
for _, row := range data {
|
||||
col := connection.ColumnDefinition{
|
||||
Name: fmt.Sprintf("%v", row["Field"]),
|
||||
Type: fmt.Sprintf("%v", row["Type"]),
|
||||
Nullable: fmt.Sprintf("%v", row["Null"]),
|
||||
Key: fmt.Sprintf("%v", row["Key"]),
|
||||
Extra: fmt.Sprintf("%v", row["Extra"]),
|
||||
Comment: fmt.Sprintf("%v", row["Comment"]),
|
||||
}
|
||||
|
||||
if row["Default"] != nil {
|
||||
d := fmt.Sprintf("%v", row["Default"])
|
||||
col.Default = &d
|
||||
}
|
||||
|
||||
columns = append(columns, col)
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (m *MariaDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||||
query := fmt.Sprintf("SHOW INDEX FROM `%s`.`%s`", dbName, tableName)
|
||||
if dbName == "" {
|
||||
query = fmt.Sprintf("SHOW INDEX FROM `%s`", tableName)
|
||||
}
|
||||
|
||||
data, _, err := m.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var indexes []connection.IndexDefinition
|
||||
for _, row := range data {
|
||||
nonUnique := 0
|
||||
if val, ok := row["Non_unique"]; ok {
|
||||
if f, ok := val.(float64); ok {
|
||||
nonUnique = int(f)
|
||||
} else if i, ok := val.(int64); ok {
|
||||
nonUnique = int(i)
|
||||
}
|
||||
}
|
||||
|
||||
seq := 0
|
||||
if val, ok := row["Seq_in_index"]; ok {
|
||||
if f, ok := val.(float64); ok {
|
||||
seq = int(f)
|
||||
} else if i, ok := val.(int64); ok {
|
||||
seq = int(i)
|
||||
}
|
||||
}
|
||||
|
||||
idx := connection.IndexDefinition{
|
||||
Name: fmt.Sprintf("%v", row["Key_name"]),
|
||||
ColumnName: fmt.Sprintf("%v", row["Column_name"]),
|
||||
NonUnique: nonUnique,
|
||||
SeqInIndex: seq,
|
||||
IndexType: fmt.Sprintf("%v", row["Index_type"]),
|
||||
}
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
return indexes, nil
|
||||
}
|
||||
|
||||
func (m *MariaDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
|
||||
query := fmt.Sprintf(`SELECT CONSTRAINT_NAME, COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME
|
||||
FROM information_schema.KEY_COLUMN_USAGE
|
||||
WHERE TABLE_SCHEMA = '%s' AND TABLE_NAME = '%s' AND REFERENCED_TABLE_NAME IS NOT NULL`, dbName, tableName)
|
||||
|
||||
data, _, err := m.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var fks []connection.ForeignKeyDefinition
|
||||
for _, row := range data {
|
||||
fk := connection.ForeignKeyDefinition{
|
||||
Name: fmt.Sprintf("%v", row["CONSTRAINT_NAME"]),
|
||||
ColumnName: fmt.Sprintf("%v", row["COLUMN_NAME"]),
|
||||
RefTableName: fmt.Sprintf("%v", row["REFERENCED_TABLE_NAME"]),
|
||||
RefColumnName: fmt.Sprintf("%v", row["REFERENCED_COLUMN_NAME"]),
|
||||
ConstraintName: fmt.Sprintf("%v", row["CONSTRAINT_NAME"]),
|
||||
}
|
||||
fks = append(fks, fk)
|
||||
}
|
||||
return fks, nil
|
||||
}
|
||||
|
||||
func (m *MariaDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||||
query := fmt.Sprintf("SHOW TRIGGERS FROM `%s` WHERE `Table` = '%s'", dbName, tableName)
|
||||
data, _, err := m.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var triggers []connection.TriggerDefinition
|
||||
for _, row := range data {
|
||||
trig := connection.TriggerDefinition{
|
||||
Name: fmt.Sprintf("%v", row["Trigger"]),
|
||||
Timing: fmt.Sprintf("%v", row["Timing"]),
|
||||
Event: fmt.Sprintf("%v", row["Event"]),
|
||||
Statement: fmt.Sprintf("%v", row["Statement"]),
|
||||
}
|
||||
triggers = append(triggers, trig)
|
||||
}
|
||||
return triggers, nil
|
||||
}
|
||||
|
||||
func (m *MariaDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
if m.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
tx, err := m.conn.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// 1. Deletes
|
||||
for _, pk := range changes.Deletes {
|
||||
var wheres []string
|
||||
var args []interface{}
|
||||
for k, v := range pk {
|
||||
wheres = append(wheres, fmt.Sprintf("`%s` = ?", k))
|
||||
args = append(args, normalizeMySQLDateTimeValue(v))
|
||||
}
|
||||
if len(wheres) == 0 {
|
||||
continue
|
||||
}
|
||||
query := fmt.Sprintf("DELETE FROM `%s` WHERE %s", tableName, strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("delete error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Updates
|
||||
for _, update := range changes.Updates {
|
||||
var sets []string
|
||||
var args []interface{}
|
||||
|
||||
for k, v := range update.Values {
|
||||
sets = append(sets, fmt.Sprintf("`%s` = ?", k))
|
||||
args = append(args, normalizeMySQLDateTimeValue(v))
|
||||
}
|
||||
|
||||
if len(sets) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var wheres []string
|
||||
for k, v := range update.Keys {
|
||||
wheres = append(wheres, fmt.Sprintf("`%s` = ?", k))
|
||||
args = append(args, normalizeMySQLDateTimeValue(v))
|
||||
}
|
||||
|
||||
if len(wheres) == 0 {
|
||||
return fmt.Errorf("update requires keys")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("UPDATE `%s` SET %s WHERE %s", tableName, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("update error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Inserts
|
||||
for _, row := range changes.Inserts {
|
||||
var cols []string
|
||||
var placeholders []string
|
||||
var args []interface{}
|
||||
|
||||
for k, v := range row {
|
||||
cols = append(cols, fmt.Sprintf("`%s`", k))
|
||||
placeholders = append(placeholders, "?")
|
||||
args = append(args, normalizeMySQLDateTimeValue(v))
|
||||
}
|
||||
|
||||
if len(cols) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO `%s` (%s) VALUES (%s)", tableName, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("insert error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (m *MariaDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||||
query := fmt.Sprintf("SELECT TABLE_NAME, COLUMN_NAME, COLUMN_TYPE FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = '%s'", dbName)
|
||||
if dbName == "" {
|
||||
return nil, fmt.Errorf("database name required for GetAllColumns")
|
||||
}
|
||||
|
||||
data, _, err := m.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cols []connection.ColumnDefinitionWithTable
|
||||
for _, row := range data {
|
||||
col := connection.ColumnDefinitionWithTable{
|
||||
TableName: fmt.Sprintf("%v", row["TABLE_NAME"]),
|
||||
Name: fmt.Sprintf("%v", row["COLUMN_NAME"]),
|
||||
Type: fmt.Sprintf("%v", row["COLUMN_TYPE"]),
|
||||
}
|
||||
cols = append(cols, col)
|
||||
}
|
||||
return cols, nil
|
||||
}
|
||||
1144
internal/db/mongodb_impl.go
Normal file
1144
internal/db/mongodb_impl.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -20,16 +22,161 @@ type MySQLDB struct {
|
||||
pingTimeout time.Duration
|
||||
}
|
||||
|
||||
const defaultMySQLPort = 3306
|
||||
|
||||
func parseHostPortWithDefault(raw string, defaultPort int) (string, int, bool) {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
if strings.HasPrefix(text, "[") {
|
||||
end := strings.Index(text, "]")
|
||||
if end < 0 {
|
||||
return text, defaultPort, true
|
||||
}
|
||||
host := text[1:end]
|
||||
portText := strings.TrimSpace(text[end+1:])
|
||||
if strings.HasPrefix(portText, ":") {
|
||||
if p, err := strconv.Atoi(strings.TrimSpace(strings.TrimPrefix(portText, ":"))); err == nil && p > 0 {
|
||||
return host, p, true
|
||||
}
|
||||
}
|
||||
return host, defaultPort, true
|
||||
}
|
||||
|
||||
lastColon := strings.LastIndex(text, ":")
|
||||
if lastColon > 0 && strings.Count(text, ":") == 1 {
|
||||
host := strings.TrimSpace(text[:lastColon])
|
||||
portText := strings.TrimSpace(text[lastColon+1:])
|
||||
if host != "" {
|
||||
if p, err := strconv.Atoi(portText); err == nil && p > 0 {
|
||||
return host, p, true
|
||||
}
|
||||
return host, defaultPort, true
|
||||
}
|
||||
}
|
||||
|
||||
return text, defaultPort, true
|
||||
}
|
||||
|
||||
func normalizeMySQLAddress(host string, port int) string {
|
||||
h := strings.TrimSpace(host)
|
||||
if h == "" {
|
||||
h = "localhost"
|
||||
}
|
||||
p := port
|
||||
if p <= 0 {
|
||||
p = defaultMySQLPort
|
||||
}
|
||||
return fmt.Sprintf("%s:%d", h, p)
|
||||
}
|
||||
|
||||
func applyMySQLURI(config connection.ConnectionConfig) connection.ConnectionConfig {
|
||||
uriText := strings.TrimSpace(config.URI)
|
||||
if uriText == "" {
|
||||
return config
|
||||
}
|
||||
if !strings.HasPrefix(strings.ToLower(uriText), "mysql://") {
|
||||
return config
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(uriText)
|
||||
if err != nil {
|
||||
return config
|
||||
}
|
||||
|
||||
if parsed.User != nil {
|
||||
if config.User == "" {
|
||||
config.User = parsed.User.Username()
|
||||
}
|
||||
if pass, ok := parsed.User.Password(); ok && config.Password == "" {
|
||||
config.Password = pass
|
||||
}
|
||||
}
|
||||
|
||||
if dbName := strings.TrimPrefix(parsed.Path, "/"); dbName != "" && config.Database == "" {
|
||||
config.Database = dbName
|
||||
}
|
||||
|
||||
defaultPort := config.Port
|
||||
if defaultPort <= 0 {
|
||||
defaultPort = defaultMySQLPort
|
||||
}
|
||||
|
||||
hostsFromURI := make([]string, 0, 4)
|
||||
hostText := strings.TrimSpace(parsed.Host)
|
||||
if hostText != "" {
|
||||
for _, entry := range strings.Split(hostText, ",") {
|
||||
host, port, ok := parseHostPortWithDefault(entry, defaultPort)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
hostsFromURI = append(hostsFromURI, normalizeMySQLAddress(host, port))
|
||||
}
|
||||
}
|
||||
|
||||
if len(config.Hosts) == 0 && len(hostsFromURI) > 0 {
|
||||
config.Hosts = hostsFromURI
|
||||
}
|
||||
if strings.TrimSpace(config.Host) == "" && len(hostsFromURI) > 0 {
|
||||
host, port, ok := parseHostPortWithDefault(hostsFromURI[0], defaultPort)
|
||||
if ok {
|
||||
config.Host = host
|
||||
config.Port = port
|
||||
}
|
||||
}
|
||||
|
||||
if config.Topology == "" {
|
||||
topology := strings.TrimSpace(parsed.Query().Get("topology"))
|
||||
if topology != "" {
|
||||
config.Topology = strings.ToLower(topology)
|
||||
}
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func collectMySQLAddresses(config connection.ConnectionConfig) []string {
|
||||
defaultPort := config.Port
|
||||
if defaultPort <= 0 {
|
||||
defaultPort = defaultMySQLPort
|
||||
}
|
||||
|
||||
candidates := make([]string, 0, len(config.Hosts)+1)
|
||||
if len(config.Hosts) > 0 {
|
||||
candidates = append(candidates, config.Hosts...)
|
||||
} else {
|
||||
candidates = append(candidates, normalizeMySQLAddress(config.Host, defaultPort))
|
||||
}
|
||||
|
||||
result := make([]string, 0, len(candidates))
|
||||
seen := make(map[string]struct{}, len(candidates))
|
||||
for _, entry := range candidates {
|
||||
host, port, ok := parseHostPortWithDefault(entry, defaultPort)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
normalized := normalizeMySQLAddress(host, port)
|
||||
if _, exists := seen[normalized]; exists {
|
||||
continue
|
||||
}
|
||||
seen[normalized] = struct{}{}
|
||||
result = append(result, normalized)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *MySQLDB) getDSN(config connection.ConnectionConfig) string {
|
||||
database := config.Database
|
||||
protocol := "tcp"
|
||||
address := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
address := normalizeMySQLAddress(config.Host, config.Port)
|
||||
|
||||
if config.UseSSH {
|
||||
netName, err := ssh.RegisterSSHNetwork(config.SSH)
|
||||
if err == nil {
|
||||
protocol = netName
|
||||
address = fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
address = normalizeMySQLAddress(config.Host, config.Port)
|
||||
} else {
|
||||
logger.Warnf("注册 SSH 网络失败,将尝试直连:地址=%s:%d 用户=%s,原因:%v", config.Host, config.Port, config.User, err)
|
||||
}
|
||||
@@ -41,20 +188,67 @@ func (m *MySQLDB) getDSN(config connection.ConnectionConfig) string {
|
||||
config.User, config.Password, protocol, address, database, timeout)
|
||||
}
|
||||
|
||||
func (m *MySQLDB) Connect(config connection.ConnectionConfig) error {
|
||||
dsn := m.getDSN(config)
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
}
|
||||
m.conn = db
|
||||
m.pingTimeout = getConnectTimeout(config)
|
||||
func resolveMySQLCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {
|
||||
primaryUser := strings.TrimSpace(config.User)
|
||||
primaryPassword := config.Password
|
||||
replicaUser := strings.TrimSpace(config.MySQLReplicaUser)
|
||||
replicaPassword := config.MySQLReplicaPassword
|
||||
|
||||
// Force verification
|
||||
if err := m.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
if addressIndex > 0 && replicaUser != "" {
|
||||
return replicaUser, replicaPassword
|
||||
}
|
||||
return nil
|
||||
|
||||
if primaryUser == "" && replicaUser != "" {
|
||||
return replicaUser, replicaPassword
|
||||
}
|
||||
|
||||
return config.User, primaryPassword
|
||||
}
|
||||
|
||||
func (m *MySQLDB) Connect(config connection.ConnectionConfig) error {
|
||||
runConfig := applyMySQLURI(config)
|
||||
addresses := collectMySQLAddresses(runConfig)
|
||||
if len(addresses) == 0 {
|
||||
return fmt.Errorf("连接建立后验证失败:未找到可用的 MySQL 地址")
|
||||
}
|
||||
|
||||
var errorDetails []string
|
||||
for index, address := range addresses {
|
||||
candidateConfig := runConfig
|
||||
host, port, ok := parseHostPortWithDefault(address, defaultMySQLPort)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
candidateConfig.Host = host
|
||||
candidateConfig.Port = port
|
||||
candidateConfig.User, candidateConfig.Password = resolveMySQLCredential(runConfig, index)
|
||||
|
||||
dsn := m.getDSN(candidateConfig)
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败: %v", address, err))
|
||||
continue
|
||||
}
|
||||
|
||||
timeout := getConnectTimeout(candidateConfig)
|
||||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||||
pingErr := db.PingContext(ctx)
|
||||
cancel()
|
||||
if pingErr != nil {
|
||||
_ = db.Close()
|
||||
errorDetails = append(errorDetails, fmt.Sprintf("%s 验证失败: %v", address, pingErr))
|
||||
continue
|
||||
}
|
||||
|
||||
m.conn = db
|
||||
m.pingTimeout = timeout
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(errorDetails) == 0 {
|
||||
return fmt.Errorf("连接建立后验证失败:未找到可用的 MySQL 地址")
|
||||
}
|
||||
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(errorDetails, ";"))
|
||||
}
|
||||
|
||||
func (m *MySQLDB) Close() error {
|
||||
|
||||
103
internal/db/sphinx_impl.go
Normal file
103
internal/db/sphinx_impl.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
const sphinxDefaultDatabaseName = "default"
|
||||
|
||||
// SphinxDB 复用 MySQL 协议实现,并在数据库列表不可用时提供兜底。
|
||||
type SphinxDB struct {
|
||||
MySQLDB
|
||||
fallbackDatabase string
|
||||
}
|
||||
|
||||
func isSphinxUnsupportedFeatureError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
text := strings.ToLower(strings.TrimSpace(err.Error()))
|
||||
if text == "" {
|
||||
return false
|
||||
}
|
||||
keywords := []string{
|
||||
"not supported",
|
||||
"unsupported",
|
||||
"syntax error",
|
||||
"unknown table",
|
||||
"unknown column",
|
||||
"doesn't exist",
|
||||
}
|
||||
for _, keyword := range keywords {
|
||||
if strings.Contains(text, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *SphinxDB) Connect(config connection.ConnectionConfig) error {
|
||||
runConfig := applyMySQLURI(config)
|
||||
s.fallbackDatabase = strings.TrimSpace(runConfig.Database)
|
||||
return s.MySQLDB.Connect(config)
|
||||
}
|
||||
|
||||
func (s *SphinxDB) resolveDatabaseName(dbName string) string {
|
||||
name := strings.TrimSpace(dbName)
|
||||
if name == "" {
|
||||
return s.fallbackDatabase
|
||||
}
|
||||
if strings.EqualFold(name, sphinxDefaultDatabaseName) && s.fallbackDatabase == "" {
|
||||
return ""
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func (s *SphinxDB) GetDatabases() ([]string, error) {
|
||||
dbs, err := s.MySQLDB.GetDatabases()
|
||||
if err == nil && len(dbs) > 0 {
|
||||
return dbs, nil
|
||||
}
|
||||
if s.fallbackDatabase != "" {
|
||||
return []string{s.fallbackDatabase}, nil
|
||||
}
|
||||
return []string{sphinxDefaultDatabaseName}, nil
|
||||
}
|
||||
|
||||
func (s *SphinxDB) GetTables(dbName string) ([]string, error) {
|
||||
return s.MySQLDB.GetTables(s.resolveDatabaseName(dbName))
|
||||
}
|
||||
|
||||
func (s *SphinxDB) GetCreateStatement(dbName, tableName string) (string, error) {
|
||||
return s.MySQLDB.GetCreateStatement(s.resolveDatabaseName(dbName), tableName)
|
||||
}
|
||||
|
||||
func (s *SphinxDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
return s.MySQLDB.GetColumns(s.resolveDatabaseName(dbName), tableName)
|
||||
}
|
||||
|
||||
func (s *SphinxDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||||
return s.MySQLDB.GetAllColumns(s.resolveDatabaseName(dbName))
|
||||
}
|
||||
|
||||
func (s *SphinxDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||||
return s.MySQLDB.GetIndexes(s.resolveDatabaseName(dbName), tableName)
|
||||
}
|
||||
|
||||
func (s *SphinxDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
|
||||
fks, err := s.MySQLDB.GetForeignKeys(s.resolveDatabaseName(dbName), tableName)
|
||||
if err != nil && isSphinxUnsupportedFeatureError(err) {
|
||||
return []connection.ForeignKeyDefinition{}, nil
|
||||
}
|
||||
return fks, err
|
||||
}
|
||||
|
||||
func (s *SphinxDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||||
triggers, err := s.MySQLDB.GetTriggers(s.resolveDatabaseName(dbName), tableName)
|
||||
if err != nil && isSphinxUnsupportedFeatureError(err) {
|
||||
return []connection.TriggerDefinition{}, nil
|
||||
}
|
||||
return triggers, err
|
||||
}
|
||||
635
internal/db/sqlserver_impl.go
Normal file
635
internal/db/sqlserver_impl.go
Normal file
@@ -0,0 +1,635 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
|
||||
_ "github.com/microsoft/go-mssqldb"
|
||||
)
|
||||
|
||||
type SqlServerDB struct {
|
||||
conn *sql.DB
|
||||
pingTimeout time.Duration
|
||||
forwarder *ssh.LocalForwarder
|
||||
}
|
||||
|
||||
// quoteBracket escapes ] in identifiers for safe use in SQL Server [bracket] notation
|
||||
func quoteBracket(name string) string {
|
||||
return strings.ReplaceAll(name, "]", "]]")
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) getDSN(config connection.ConnectionConfig) string {
|
||||
// sqlserver://user:password@host:port?database=dbname
|
||||
dbname := config.Database
|
||||
if dbname == "" {
|
||||
dbname = "master"
|
||||
}
|
||||
|
||||
u := &url.URL{
|
||||
Scheme: "sqlserver",
|
||||
Host: net.JoinHostPort(config.Host, strconv.Itoa(config.Port)),
|
||||
}
|
||||
u.User = url.UserPassword(config.User, config.Password)
|
||||
|
||||
q := url.Values{}
|
||||
q.Set("database", dbname)
|
||||
q.Set("connection timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
|
||||
q.Set("encrypt", "disable")
|
||||
q.Set("TrustServerCertificate", "true")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) Connect(config connection.ConnectionConfig) error {
|
||||
var dsn string
|
||||
|
||||
if config.UseSSH {
|
||||
logger.Infof("SQL Server 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
|
||||
|
||||
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||||
}
|
||||
s.forwarder = forwarder
|
||||
|
||||
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地转发地址失败:%w", err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地端口失败:%w", err)
|
||||
}
|
||||
|
||||
localConfig := config
|
||||
localConfig.Host = host
|
||||
localConfig.Port = port
|
||||
localConfig.UseSSH = false
|
||||
|
||||
dsn = s.getDSN(localConfig)
|
||||
logger.Infof("SQL Server 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||||
} else {
|
||||
dsn = s.getDSN(config)
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlserver", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
}
|
||||
s.conn = db
|
||||
s.pingTimeout = getConnectTimeout(config)
|
||||
|
||||
if err := s.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) Close() error {
|
||||
if s.forwarder != nil {
|
||||
if err := s.forwarder.Close(); err != nil {
|
||||
logger.Warnf("关闭 SQL Server SSH 端口转发失败:%v", err)
|
||||
}
|
||||
s.forwarder = nil
|
||||
}
|
||||
|
||||
if s.conn != nil {
|
||||
return s.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) Ping() error {
|
||||
if s.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
}
|
||||
timeout := s.pingTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||||
defer cancel()
|
||||
return s.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if s.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := s.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if s.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := s.conn.Query(query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if s.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
}
|
||||
res, err := s.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) Exec(query string) (int64, error) {
|
||||
if s.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
}
|
||||
res, err := s.conn.Exec(query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) GetDatabases() ([]string, error) {
|
||||
query := "SELECT name FROM sys.databases WHERE state_desc = 'ONLINE' ORDER BY name"
|
||||
data, _, err := s.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var dbs []string
|
||||
for _, row := range data {
|
||||
if val, ok := row["name"]; ok {
|
||||
dbs = append(dbs, fmt.Sprintf("%v", val))
|
||||
}
|
||||
}
|
||||
return dbs, nil
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) GetTables(dbName string) ([]string, error) {
|
||||
// SQL Server uses schema.table format, default schema is dbo
|
||||
safeDB := quoteBracket(dbName)
|
||||
query := fmt.Sprintf(`
|
||||
SELECT s.name AS schema_name, t.name AS table_name
|
||||
FROM [%s].sys.tables t
|
||||
JOIN [%s].sys.schemas s ON t.schema_id = s.schema_id
|
||||
WHERE t.type = 'U'
|
||||
ORDER BY s.name, t.name`, safeDB, safeDB)
|
||||
|
||||
data, _, err := s.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tables []string
|
||||
for _, row := range data {
|
||||
schema, okSchema := row["schema_name"]
|
||||
name, okName := row["table_name"]
|
||||
if okSchema && okName {
|
||||
tables = append(tables, fmt.Sprintf("%v.%v", schema, name))
|
||||
continue
|
||||
}
|
||||
if okName {
|
||||
tables = append(tables, fmt.Sprintf("%v", name))
|
||||
}
|
||||
}
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) GetCreateStatement(dbName, tableName string) (string, error) {
|
||||
return fmt.Sprintf("-- SHOW CREATE TABLE not supported for SQL Server in this version.\n-- Table: %s.%s", dbName, tableName), nil
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
schema := "dbo"
|
||||
table := strings.TrimSpace(tableName)
|
||||
|
||||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||||
schema = strings.TrimSpace(parts[0])
|
||||
table = strings.TrimSpace(parts[1])
|
||||
}
|
||||
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
safeDB := quoteBracket(dbName)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
c.name AS column_name,
|
||||
t.name + CASE
|
||||
WHEN t.name IN ('varchar', 'nvarchar', 'char', 'nchar') THEN '(' + CASE WHEN c.max_length = -1 THEN 'MAX' ELSE CAST(CASE WHEN t.name IN ('nvarchar', 'nchar') THEN c.max_length / 2 ELSE c.max_length END AS VARCHAR) END + ')'
|
||||
WHEN t.name IN ('decimal', 'numeric') THEN '(' + CAST(c.precision AS VARCHAR) + ',' + CAST(c.scale AS VARCHAR) + ')'
|
||||
ELSE ''
|
||||
END AS data_type,
|
||||
CASE WHEN c.is_nullable = 1 THEN 'YES' ELSE 'NO' END AS is_nullable,
|
||||
dc.definition AS column_default,
|
||||
ep.value AS comment,
|
||||
CASE WHEN pk.column_id IS NOT NULL THEN 'PRI' ELSE '' END AS column_key,
|
||||
CASE WHEN c.is_identity = 1 THEN 'auto_increment' ELSE '' END AS extra
|
||||
FROM [%s].sys.columns c
|
||||
JOIN [%s].sys.types t ON c.user_type_id = t.user_type_id
|
||||
JOIN [%s].sys.tables tb ON c.object_id = tb.object_id
|
||||
JOIN [%s].sys.schemas s ON tb.schema_id = s.schema_id
|
||||
LEFT JOIN [%s].sys.default_constraints dc ON c.default_object_id = dc.object_id
|
||||
LEFT JOIN [%s].sys.extended_properties ep ON ep.major_id = c.object_id AND ep.minor_id = c.column_id AND ep.name = 'MS_Description'
|
||||
LEFT JOIN (
|
||||
SELECT ic.object_id, ic.column_id
|
||||
FROM [%s].sys.index_columns ic
|
||||
JOIN [%s].sys.indexes i ON ic.object_id = i.object_id AND ic.index_id = i.index_id
|
||||
WHERE i.is_primary_key = 1
|
||||
) pk ON pk.object_id = c.object_id AND pk.column_id = c.column_id
|
||||
WHERE s.name = '%s' AND tb.name = '%s'
|
||||
ORDER BY c.column_id`,
|
||||
safeDB, safeDB, safeDB, safeDB, safeDB, safeDB, safeDB, safeDB,
|
||||
esc(schema), esc(table))
|
||||
|
||||
data, _, err := s.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var columns []connection.ColumnDefinition
|
||||
for _, row := range data {
|
||||
col := connection.ColumnDefinition{
|
||||
Name: fmt.Sprintf("%v", row["column_name"]),
|
||||
Type: fmt.Sprintf("%v", row["data_type"]),
|
||||
Nullable: fmt.Sprintf("%v", row["is_nullable"]),
|
||||
Key: fmt.Sprintf("%v", row["column_key"]),
|
||||
Extra: fmt.Sprintf("%v", row["extra"]),
|
||||
Comment: "",
|
||||
}
|
||||
|
||||
if v, ok := row["comment"]; ok && v != nil {
|
||||
col.Comment = fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
if v, ok := row["column_default"]; ok && v != nil {
|
||||
def := fmt.Sprintf("%v", v)
|
||||
col.Default = &def
|
||||
}
|
||||
|
||||
columns = append(columns, col)
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||||
safeDB := quoteBracket(dbName)
|
||||
query := fmt.Sprintf(`
|
||||
SELECT s.name AS schema_name, t.name AS table_name, c.name AS column_name, tp.name AS data_type
|
||||
FROM [%s].sys.columns c
|
||||
JOIN [%s].sys.tables t ON c.object_id = t.object_id
|
||||
JOIN [%s].sys.schemas s ON t.schema_id = s.schema_id
|
||||
JOIN [%s].sys.types tp ON c.user_type_id = tp.user_type_id
|
||||
WHERE t.type = 'U'
|
||||
ORDER BY s.name, t.name, c.column_id`, safeDB, safeDB, safeDB, safeDB)
|
||||
|
||||
data, _, err := s.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cols []connection.ColumnDefinitionWithTable
|
||||
for _, row := range data {
|
||||
schema := fmt.Sprintf("%v", row["schema_name"])
|
||||
table := fmt.Sprintf("%v", row["table_name"])
|
||||
tableName := fmt.Sprintf("%s.%s", schema, table)
|
||||
|
||||
col := connection.ColumnDefinitionWithTable{
|
||||
TableName: tableName,
|
||||
Name: fmt.Sprintf("%v", row["column_name"]),
|
||||
Type: fmt.Sprintf("%v", row["data_type"]),
|
||||
}
|
||||
cols = append(cols, col)
|
||||
}
|
||||
return cols, nil
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||||
schema := "dbo"
|
||||
table := strings.TrimSpace(tableName)
|
||||
|
||||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||||
schema = strings.TrimSpace(parts[0])
|
||||
table = strings.TrimSpace(parts[1])
|
||||
}
|
||||
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
safeDB := quoteBracket(dbName)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
i.name AS index_name,
|
||||
c.name AS column_name,
|
||||
i.is_unique,
|
||||
ic.key_ordinal AS seq_in_index,
|
||||
i.type_desc AS index_type
|
||||
FROM [%s].sys.indexes i
|
||||
JOIN [%s].sys.index_columns ic ON i.object_id = ic.object_id AND i.index_id = ic.index_id
|
||||
JOIN [%s].sys.columns c ON ic.object_id = c.object_id AND ic.column_id = c.column_id
|
||||
JOIN [%s].sys.tables t ON i.object_id = t.object_id
|
||||
JOIN [%s].sys.schemas s ON t.schema_id = s.schema_id
|
||||
WHERE s.name = '%s' AND t.name = '%s' AND i.name IS NOT NULL
|
||||
ORDER BY i.name, ic.key_ordinal`,
|
||||
safeDB, safeDB, safeDB, safeDB, safeDB, esc(schema), esc(table))
|
||||
|
||||
data, _, err := s.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var indexes []connection.IndexDefinition
|
||||
for _, row := range data {
|
||||
isUnique := false
|
||||
if v, ok := row["is_unique"]; ok && v != nil {
|
||||
switch val := v.(type) {
|
||||
case bool:
|
||||
isUnique = val
|
||||
case int64:
|
||||
isUnique = val == 1
|
||||
}
|
||||
}
|
||||
|
||||
nonUnique := 1
|
||||
if isUnique {
|
||||
nonUnique = 0
|
||||
}
|
||||
|
||||
seq := 0
|
||||
if v, ok := row["seq_in_index"]; ok && v != nil {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
seq = val
|
||||
case int64:
|
||||
seq = int(val)
|
||||
}
|
||||
}
|
||||
|
||||
indexType := "NONCLUSTERED"
|
||||
if v, ok := row["index_type"]; ok && v != nil {
|
||||
indexType = strings.ToUpper(fmt.Sprintf("%v", v))
|
||||
}
|
||||
|
||||
idx := connection.IndexDefinition{
|
||||
Name: fmt.Sprintf("%v", row["index_name"]),
|
||||
ColumnName: fmt.Sprintf("%v", row["column_name"]),
|
||||
NonUnique: nonUnique,
|
||||
SeqInIndex: seq,
|
||||
IndexType: indexType,
|
||||
}
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
return indexes, nil
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
|
||||
schema := "dbo"
|
||||
table := strings.TrimSpace(tableName)
|
||||
|
||||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||||
schema = strings.TrimSpace(parts[0])
|
||||
table = strings.TrimSpace(parts[1])
|
||||
}
|
||||
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
safeDB := quoteBracket(dbName)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
fk.name AS constraint_name,
|
||||
c.name AS column_name,
|
||||
rs.name AS foreign_schema,
|
||||
rt.name AS foreign_table,
|
||||
rc.name AS foreign_column
|
||||
FROM [%s].sys.foreign_keys fk
|
||||
JOIN [%s].sys.foreign_key_columns fkc ON fk.object_id = fkc.constraint_object_id
|
||||
JOIN [%s].sys.columns c ON fkc.parent_object_id = c.object_id AND fkc.parent_column_id = c.column_id
|
||||
JOIN [%s].sys.tables t ON fk.parent_object_id = t.object_id
|
||||
JOIN [%s].sys.schemas s ON t.schema_id = s.schema_id
|
||||
JOIN [%s].sys.tables rt ON fk.referenced_object_id = rt.object_id
|
||||
JOIN [%s].sys.schemas rs ON rt.schema_id = rs.schema_id
|
||||
JOIN [%s].sys.columns rc ON fkc.referenced_object_id = rc.object_id AND fkc.referenced_column_id = rc.column_id
|
||||
WHERE s.name = '%s' AND t.name = '%s'
|
||||
ORDER BY fk.name`,
|
||||
safeDB, safeDB, safeDB, safeDB, safeDB, safeDB, safeDB, safeDB, esc(schema), esc(table))
|
||||
|
||||
data, _, err := s.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var fks []connection.ForeignKeyDefinition
|
||||
for _, row := range data {
|
||||
refSchema := fmt.Sprintf("%v", row["foreign_schema"])
|
||||
refTable := fmt.Sprintf("%v", row["foreign_table"])
|
||||
refTableName := fmt.Sprintf("%s.%s", refSchema, refTable)
|
||||
|
||||
fk := connection.ForeignKeyDefinition{
|
||||
Name: fmt.Sprintf("%v", row["constraint_name"]),
|
||||
ColumnName: fmt.Sprintf("%v", row["column_name"]),
|
||||
RefTableName: refTableName,
|
||||
RefColumnName: fmt.Sprintf("%v", row["foreign_column"]),
|
||||
ConstraintName: fmt.Sprintf("%v", row["constraint_name"]),
|
||||
}
|
||||
fks = append(fks, fk)
|
||||
}
|
||||
return fks, nil
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||||
schema := "dbo"
|
||||
table := strings.TrimSpace(tableName)
|
||||
|
||||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||||
schema = strings.TrimSpace(parts[0])
|
||||
table = strings.TrimSpace(parts[1])
|
||||
}
|
||||
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
safeDB := quoteBracket(dbName)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
tr.name AS trigger_name,
|
||||
CASE WHEN tr.is_instead_of_trigger = 1 THEN 'INSTEAD OF' ELSE 'AFTER' END AS timing,
|
||||
STUFF((
|
||||
SELECT ', ' + te.type_desc
|
||||
FROM [%s].sys.trigger_events te
|
||||
WHERE te.object_id = tr.object_id
|
||||
FOR XML PATH('')
|
||||
), 1, 2, '') AS event,
|
||||
OBJECT_DEFINITION(tr.object_id) AS statement
|
||||
FROM [%s].sys.triggers tr
|
||||
JOIN [%s].sys.tables t ON tr.parent_id = t.object_id
|
||||
JOIN [%s].sys.schemas s ON t.schema_id = s.schema_id
|
||||
WHERE s.name = '%s' AND t.name = '%s'
|
||||
ORDER BY tr.name`,
|
||||
safeDB, safeDB, safeDB, safeDB, esc(schema), esc(table))
|
||||
|
||||
data, _, err := s.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var triggers []connection.TriggerDefinition
|
||||
for _, row := range data {
|
||||
trig := connection.TriggerDefinition{
|
||||
Name: fmt.Sprintf("%v", row["trigger_name"]),
|
||||
Timing: fmt.Sprintf("%v", row["timing"]),
|
||||
Event: fmt.Sprintf("%v", row["event"]),
|
||||
Statement: "",
|
||||
}
|
||||
if v, ok := row["statement"]; ok && v != nil {
|
||||
trig.Statement = fmt.Sprintf("%v", v)
|
||||
}
|
||||
triggers = append(triggers, trig)
|
||||
}
|
||||
return triggers, nil
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
if s.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
tx, err := s.conn.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
quoteIdent := func(name string) string {
|
||||
n := strings.TrimSpace(name)
|
||||
n = strings.Trim(n, "[]")
|
||||
n = strings.ReplaceAll(n, "]", "]]")
|
||||
if n == "" {
|
||||
return "[]"
|
||||
}
|
||||
return "[" + n + "]"
|
||||
}
|
||||
|
||||
schema := "dbo"
|
||||
table := strings.TrimSpace(tableName)
|
||||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||||
schema = strings.TrimSpace(parts[0])
|
||||
table = strings.TrimSpace(parts[1])
|
||||
}
|
||||
|
||||
qualifiedTable := fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table))
|
||||
|
||||
// 1. Deletes
|
||||
for _, pk := range changes.Deletes {
|
||||
var wheres []string
|
||||
var args []interface{}
|
||||
idx := 0
|
||||
for k, v := range pk {
|
||||
idx++
|
||||
wheres = append(wheres, fmt.Sprintf("%s = @p%d", quoteIdent(k), idx))
|
||||
args = append(args, sql.Named(fmt.Sprintf("p%d", idx), v))
|
||||
}
|
||||
if len(wheres) == 0 {
|
||||
continue
|
||||
}
|
||||
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("delete error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Updates
|
||||
for _, update := range changes.Updates {
|
||||
var sets []string
|
||||
var args []interface{}
|
||||
idx := 0
|
||||
|
||||
for k, v := range update.Values {
|
||||
idx++
|
||||
sets = append(sets, fmt.Sprintf("%s = @p%d", quoteIdent(k), idx))
|
||||
args = append(args, sql.Named(fmt.Sprintf("p%d", idx), v))
|
||||
}
|
||||
|
||||
if len(sets) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var wheres []string
|
||||
for k, v := range update.Keys {
|
||||
idx++
|
||||
wheres = append(wheres, fmt.Sprintf("%s = @p%d", quoteIdent(k), idx))
|
||||
args = append(args, sql.Named(fmt.Sprintf("p%d", idx), v))
|
||||
}
|
||||
|
||||
if len(wheres) == 0 {
|
||||
return fmt.Errorf("update requires keys")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("update error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Inserts
|
||||
for _, row := range changes.Inserts {
|
||||
var cols []string
|
||||
var placeholders []string
|
||||
var args []interface{}
|
||||
idx := 0
|
||||
|
||||
for k, v := range row {
|
||||
idx++
|
||||
cols = append(cols, quoteIdent(k))
|
||||
placeholders = append(placeholders, fmt.Sprintf("@p%d", idx))
|
||||
args = append(args, sql.Named(fmt.Sprintf("p%d", idx), v))
|
||||
}
|
||||
|
||||
if len(cols) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("insert error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
398
internal/db/tdengine_impl.go
Normal file
398
internal/db/tdengine_impl.go
Normal file
@@ -0,0 +1,398 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
|
||||
_ "github.com/taosdata/driver-go/v3/taosWS"
|
||||
)
|
||||
|
||||
// TDengineDB implements Database interface for TDengine.
|
||||
// Uses taosWS driver via WebSocket (通常通过 taosAdapter 提供服务)。
|
||||
type TDengineDB struct {
|
||||
conn *sql.DB
|
||||
pingTimeout time.Duration
|
||||
forwarder *ssh.LocalForwarder
|
||||
}
|
||||
|
||||
func (t *TDengineDB) getDSN(config connection.ConnectionConfig) string {
|
||||
user := strings.TrimSpace(config.User)
|
||||
if user == "" {
|
||||
user = "root"
|
||||
}
|
||||
|
||||
pass := config.Password
|
||||
dbName := strings.TrimSpace(config.Database)
|
||||
path := "/"
|
||||
if dbName != "" {
|
||||
path = "/" + dbName
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s:%s@ws(%s)%s", user, pass, net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), path)
|
||||
}
|
||||
|
||||
func (t *TDengineDB) Connect(config connection.ConnectionConfig) error {
|
||||
var dsn string
|
||||
|
||||
if config.UseSSH {
|
||||
logger.Infof("TDengine 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
|
||||
|
||||
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||||
}
|
||||
t.forwarder = forwarder
|
||||
|
||||
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地转发地址失败:%w", err)
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地端口失败:%w", err)
|
||||
}
|
||||
|
||||
localConfig := config
|
||||
localConfig.Host = host
|
||||
localConfig.Port = port
|
||||
localConfig.UseSSH = false
|
||||
dsn = t.getDSN(localConfig)
|
||||
logger.Infof("TDengine 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||||
} else {
|
||||
dsn = t.getDSN(config)
|
||||
}
|
||||
|
||||
db, err := sql.Open("taosWS", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
}
|
||||
t.conn = db
|
||||
t.pingTimeout = getConnectTimeout(config)
|
||||
|
||||
if err := t.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TDengineDB) Close() error {
|
||||
if t.forwarder != nil {
|
||||
if err := t.forwarder.Close(); err != nil {
|
||||
logger.Warnf("关闭 TDengine SSH 端口转发失败:%v", err)
|
||||
}
|
||||
t.forwarder = nil
|
||||
}
|
||||
|
||||
if t.conn != nil {
|
||||
return t.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TDengineDB) Ping() error {
|
||||
if t.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
}
|
||||
timeout := t.pingTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||||
defer cancel()
|
||||
return t.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (t *TDengineDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if t.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := t.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (t *TDengineDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if t.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := t.conn.Query(query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (t *TDengineDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if t.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
}
|
||||
res, err := t.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (t *TDengineDB) Exec(query string) (int64, error) {
|
||||
if t.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
}
|
||||
res, err := t.conn.Exec(query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (t *TDengineDB) GetDatabases() ([]string, error) {
|
||||
data, _, err := t.Query("SHOW DATABASES")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var dbs []string
|
||||
for _, row := range data {
|
||||
if val, ok := getValueFromRow(row, "name", "database", "Database", "db_name"); ok {
|
||||
dbs = append(dbs, fmt.Sprintf("%v", val))
|
||||
continue
|
||||
}
|
||||
for _, val := range row {
|
||||
dbs = append(dbs, fmt.Sprintf("%v", val))
|
||||
break
|
||||
}
|
||||
}
|
||||
return dbs, nil
|
||||
}
|
||||
|
||||
func (t *TDengineDB) GetTables(dbName string) ([]string, error) {
|
||||
queries := make([]string, 0, 2)
|
||||
if strings.TrimSpace(dbName) != "" {
|
||||
queries = append(queries, fmt.Sprintf("SHOW TABLES FROM `%s`", escapeBacktickIdent(dbName)))
|
||||
}
|
||||
queries = append(queries, "SHOW TABLES")
|
||||
|
||||
var lastErr error
|
||||
for _, query := range queries {
|
||||
data, _, err := t.Query(query)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
var tables []string
|
||||
for _, row := range data {
|
||||
if val, ok := getValueFromRow(row, "table_name", "tablename", "name", "Table", "table"); ok {
|
||||
tables = append(tables, fmt.Sprintf("%v", val))
|
||||
continue
|
||||
}
|
||||
for _, val := range row {
|
||||
tables = append(tables, fmt.Sprintf("%v", val))
|
||||
break
|
||||
}
|
||||
}
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
func (t *TDengineDB) GetCreateStatement(dbName, tableName string) (string, error) {
|
||||
qualified := quoteTDengineTable(dbName, tableName)
|
||||
queries := []string{
|
||||
fmt.Sprintf("SHOW CREATE TABLE %s", qualified),
|
||||
fmt.Sprintf("SHOW CREATE STABLE %s", qualified),
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, query := range queries {
|
||||
data, _, err := t.Query(query)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
row := data[0]
|
||||
if val, ok := getValueFromRow(row, "Create Table", "create table", "Create Stable", "create stable", "SQL", "sql"); ok {
|
||||
return fmt.Sprintf("%v", val), nil
|
||||
}
|
||||
|
||||
longest := ""
|
||||
for _, val := range row {
|
||||
text := fmt.Sprintf("%v", val)
|
||||
if strings.Contains(strings.ToUpper(text), "CREATE ") && len(text) > len(longest) {
|
||||
longest = text
|
||||
}
|
||||
}
|
||||
if longest != "" {
|
||||
return longest, nil
|
||||
}
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return "", lastErr
|
||||
}
|
||||
return "", fmt.Errorf("create statement not found")
|
||||
}
|
||||
|
||||
func (t *TDengineDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
query := fmt.Sprintf("DESCRIBE %s", quoteTDengineTable(dbName, tableName))
|
||||
data, _, err := t.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := make([]connection.ColumnDefinition, 0, len(data))
|
||||
for _, row := range data {
|
||||
name, _ := getValueFromRow(row, "Field", "field", "col_name", "column_name", "name")
|
||||
colType, _ := getValueFromRow(row, "Type", "type", "data_type")
|
||||
note, _ := getValueFromRow(row, "Note", "note", "Extra", "extra")
|
||||
nullable, okNull := getValueFromRow(row, "Null", "null", "nullable")
|
||||
comment, _ := getValueFromRow(row, "Comment", "comment")
|
||||
defaultVal, hasDefault := getValueFromRow(row, "Default", "default")
|
||||
|
||||
col := connection.ColumnDefinition{
|
||||
Name: fmt.Sprintf("%v", name),
|
||||
Type: fmt.Sprintf("%v", colType),
|
||||
Nullable: "YES",
|
||||
Key: "",
|
||||
Extra: fmt.Sprintf("%v", note),
|
||||
Comment: fmt.Sprintf("%v", comment),
|
||||
}
|
||||
|
||||
if okNull {
|
||||
col.Nullable = strings.ToUpper(fmt.Sprintf("%v", nullable))
|
||||
}
|
||||
|
||||
noteUpper := strings.ToUpper(fmt.Sprintf("%v", note))
|
||||
if strings.Contains(noteUpper, "TAG") {
|
||||
col.Key = "TAG"
|
||||
}
|
||||
|
||||
if hasDefault && defaultVal != nil {
|
||||
def := fmt.Sprintf("%v", defaultVal)
|
||||
if def != "<nil>" {
|
||||
col.Default = &def
|
||||
}
|
||||
}
|
||||
|
||||
columns = append(columns, col)
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (t *TDengineDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||||
if strings.TrimSpace(dbName) == "" {
|
||||
return nil, fmt.Errorf("database name required for GetAllColumns")
|
||||
}
|
||||
|
||||
tables, err := t.GetTables(dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cols := make([]connection.ColumnDefinitionWithTable, 0)
|
||||
for _, table := range tables {
|
||||
tableCols, err := t.GetColumns(dbName, table)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, col := range tableCols {
|
||||
cols = append(cols, connection.ColumnDefinitionWithTable{
|
||||
TableName: table,
|
||||
Name: col.Name,
|
||||
Type: col.Type,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return cols, nil
|
||||
}
|
||||
|
||||
func (t *TDengineDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||||
return []connection.IndexDefinition{}, nil
|
||||
}
|
||||
|
||||
func (t *TDengineDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
|
||||
return []connection.ForeignKeyDefinition{}, nil
|
||||
}
|
||||
|
||||
func (t *TDengineDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||||
return []connection.TriggerDefinition{}, nil
|
||||
}
|
||||
|
||||
func getValueFromRow(row map[string]interface{}, keys ...string) (interface{}, bool) {
|
||||
if len(row) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
if val, ok := row[key]; ok {
|
||||
return val, true
|
||||
}
|
||||
}
|
||||
|
||||
for existingKey, val := range row {
|
||||
for _, key := range keys {
|
||||
if strings.EqualFold(existingKey, key) {
|
||||
return val, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func escapeBacktickIdent(ident string) string {
|
||||
return strings.ReplaceAll(strings.TrimSpace(ident), "`", "``")
|
||||
}
|
||||
|
||||
func quoteTDengineTable(dbName, tableName string) string {
|
||||
t := escapeBacktickIdent(tableName)
|
||||
if t == "" {
|
||||
return "``"
|
||||
}
|
||||
if strings.Contains(t, ".") {
|
||||
parts := strings.Split(t, ".")
|
||||
quoted := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
quoted = append(quoted, fmt.Sprintf("`%s`", escapeBacktickIdent(part)))
|
||||
}
|
||||
if len(quoted) > 0 {
|
||||
return strings.Join(quoted, ".")
|
||||
}
|
||||
}
|
||||
|
||||
db := escapeBacktickIdent(dbName)
|
||||
if db == "" {
|
||||
return fmt.Sprintf("`%s`", t)
|
||||
}
|
||||
return fmt.Sprintf("`%s`.`%s`", db, t)
|
||||
}
|
||||
627
internal/db/vastbase_impl.go
Normal file
627
internal/db/vastbase_impl.go
Normal file
@@ -0,0 +1,627 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
|
||||
_ "github.com/lib/pq" // Vastbase is PostgreSQL compatible
|
||||
)
|
||||
|
||||
// VastbaseDB implements Database interface for Vastbase (海量) database
|
||||
// Vastbase is a PostgreSQL-compatible database, so we reuse PostgreSQL driver
|
||||
type VastbaseDB struct {
|
||||
conn *sql.DB
|
||||
pingTimeout time.Duration
|
||||
forwarder *ssh.LocalForwarder
|
||||
}
|
||||
|
||||
func (v *VastbaseDB) getDSN(config connection.ConnectionConfig) string {
|
||||
dbname := config.Database
|
||||
if dbname == "" {
|
||||
dbname = "vastbase" // Vastbase default database
|
||||
}
|
||||
|
||||
u := &url.URL{
|
||||
Scheme: "postgres",
|
||||
Host: net.JoinHostPort(config.Host, strconv.Itoa(config.Port)),
|
||||
Path: "/" + dbname,
|
||||
}
|
||||
u.User = url.UserPassword(config.User, config.Password)
|
||||
q := url.Values{}
|
||||
q.Set("sslmode", "disable")
|
||||
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func (v *VastbaseDB) Connect(config connection.ConnectionConfig) error {
|
||||
var dsn string
|
||||
|
||||
if config.UseSSH {
|
||||
logger.Infof("Vastbase 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
|
||||
|
||||
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||||
}
|
||||
v.forwarder = forwarder
|
||||
|
||||
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地转发地址失败:%w", err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地端口失败:%w", err)
|
||||
}
|
||||
|
||||
localConfig := config
|
||||
localConfig.Host = host
|
||||
localConfig.Port = port
|
||||
localConfig.UseSSH = false
|
||||
|
||||
dsn = v.getDSN(localConfig)
|
||||
logger.Infof("Vastbase 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||||
} else {
|
||||
dsn = v.getDSN(config)
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
}
|
||||
v.conn = db
|
||||
v.pingTimeout = getConnectTimeout(config)
|
||||
|
||||
if err := v.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *VastbaseDB) Close() error {
|
||||
if v.forwarder != nil {
|
||||
if err := v.forwarder.Close(); err != nil {
|
||||
logger.Warnf("关闭 Vastbase SSH 端口转发失败:%v", err)
|
||||
}
|
||||
v.forwarder = nil
|
||||
}
|
||||
|
||||
if v.conn != nil {
|
||||
return v.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *VastbaseDB) Ping() error {
|
||||
if v.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
}
|
||||
timeout := v.pingTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||||
defer cancel()
|
||||
return v.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (v *VastbaseDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if v.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := v.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (v *VastbaseDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if v.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := v.conn.Query(query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (v *VastbaseDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if v.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
}
|
||||
res, err := v.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (v *VastbaseDB) Exec(query string) (int64, error) {
|
||||
if v.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
}
|
||||
res, err := v.conn.Exec(query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (v *VastbaseDB) GetDatabases() ([]string, error) {
|
||||
data, _, err := v.Query("SELECT datname FROM pg_database WHERE datistemplate = false")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var dbs []string
|
||||
for _, row := range data {
|
||||
if val, ok := row["datname"]; ok {
|
||||
dbs = append(dbs, fmt.Sprintf("%v", val))
|
||||
}
|
||||
}
|
||||
return dbs, nil
|
||||
}
|
||||
|
||||
func (v *VastbaseDB) GetTables(dbName string) ([]string, error) {
|
||||
query := "SELECT schemaname, tablename FROM pg_catalog.pg_tables WHERE schemaname != 'information_schema' AND schemaname NOT LIKE 'pg_%' ORDER BY schemaname, tablename"
|
||||
data, _, err := v.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tables []string
|
||||
for _, row := range data {
|
||||
schema, okSchema := row["schemaname"]
|
||||
name, okName := row["tablename"]
|
||||
if okSchema && okName {
|
||||
tables = append(tables, fmt.Sprintf("%v.%v", schema, name))
|
||||
continue
|
||||
}
|
||||
if okName {
|
||||
tables = append(tables, fmt.Sprintf("%v", name))
|
||||
}
|
||||
}
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
func (v *VastbaseDB) GetCreateStatement(dbName, tableName string) (string, error) {
|
||||
return fmt.Sprintf("-- SHOW CREATE TABLE not fully supported for Vastbase in this version.\n-- Table: %s", tableName), nil
|
||||
}
|
||||
|
||||
func (v *VastbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
schema := strings.TrimSpace(dbName)
|
||||
if schema == "" {
|
||||
schema = "public"
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
a.attname AS column_name,
|
||||
pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
|
||||
CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
|
||||
pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
|
||||
col_description(a.attrelid, a.attnum) AS comment,
|
||||
CASE WHEN pk.attname IS NOT NULL THEN 'PRI' ELSE '' END AS column_key
|
||||
FROM pg_class c
|
||||
JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||
JOIN pg_attribute a ON a.attrelid = c.oid
|
||||
LEFT JOIN pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum
|
||||
LEFT JOIN (
|
||||
SELECT i.indrelid, a3.attname
|
||||
FROM pg_index i
|
||||
JOIN pg_attribute a3 ON a3.attrelid = i.indrelid AND a3.attnum = ANY(i.indkey)
|
||||
WHERE i.indisprimary
|
||||
) pk ON pk.indrelid = c.oid AND pk.attname = a.attname
|
||||
WHERE c.relkind IN ('r', 'p')
|
||||
AND n.nspname = '%s'
|
||||
AND c.relname = '%s'
|
||||
AND a.attnum > 0
|
||||
AND NOT a.attisdropped
|
||||
ORDER BY a.attnum`, esc(schema), esc(table))
|
||||
|
||||
data, _, err := v.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var columns []connection.ColumnDefinition
|
||||
for _, row := range data {
|
||||
col := connection.ColumnDefinition{
|
||||
Name: fmt.Sprintf("%v", row["column_name"]),
|
||||
Type: fmt.Sprintf("%v", row["data_type"]),
|
||||
Nullable: fmt.Sprintf("%v", row["is_nullable"]),
|
||||
Key: fmt.Sprintf("%v", row["column_key"]),
|
||||
Extra: "",
|
||||
Comment: "",
|
||||
}
|
||||
|
||||
if val, ok := row["comment"]; ok && val != nil {
|
||||
col.Comment = fmt.Sprintf("%v", val)
|
||||
}
|
||||
|
||||
if val, ok := row["column_default"]; ok && val != nil {
|
||||
def := fmt.Sprintf("%v", val)
|
||||
col.Default = &def
|
||||
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(def)), "nextval(") {
|
||||
col.Extra = "auto_increment"
|
||||
}
|
||||
}
|
||||
|
||||
columns = append(columns, col)
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (v *VastbaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||||
schema := strings.TrimSpace(dbName)
|
||||
if schema == "" {
|
||||
schema = "public"
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
i.relname AS index_name,
|
||||
a.attname AS column_name,
|
||||
ix.indisunique AS is_unique,
|
||||
x.ordinality AS seq_in_index,
|
||||
am.amname AS index_type
|
||||
FROM pg_class t
|
||||
JOIN pg_namespace n ON n.oid = t.relnamespace
|
||||
JOIN pg_index ix ON t.oid = ix.indrelid
|
||||
JOIN pg_class i ON i.oid = ix.indexrelid
|
||||
JOIN pg_am am ON i.relam = am.oid
|
||||
JOIN unnest(ix.indkey) WITH ORDINALITY AS x(attnum, ordinality) ON TRUE
|
||||
JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = x.attnum
|
||||
WHERE t.relkind IN ('r', 'p')
|
||||
AND t.relname = '%s'
|
||||
AND n.nspname = '%s'
|
||||
ORDER BY i.relname, x.ordinality`, esc(table), esc(schema))
|
||||
|
||||
data, _, err := v.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parseBool := func(val interface{}) bool {
|
||||
switch v := val.(type) {
|
||||
case bool:
|
||||
return v
|
||||
case string:
|
||||
s := strings.ToLower(strings.TrimSpace(v))
|
||||
return s == "t" || s == "true" || s == "1" || s == "y" || s == "yes"
|
||||
default:
|
||||
s := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", val)))
|
||||
return s == "t" || s == "true" || s == "1" || s == "y" || s == "yes"
|
||||
}
|
||||
}
|
||||
|
||||
parseInt := func(val interface{}) int {
|
||||
switch v := val.(type) {
|
||||
case int:
|
||||
return v
|
||||
case int64:
|
||||
return int(v)
|
||||
case float64:
|
||||
return int(v)
|
||||
case string:
|
||||
var n int
|
||||
_, _ = fmt.Sscanf(strings.TrimSpace(v), "%d", &n)
|
||||
return n
|
||||
default:
|
||||
var n int
|
||||
_, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", val)), "%d", &n)
|
||||
return n
|
||||
}
|
||||
}
|
||||
|
||||
var indexes []connection.IndexDefinition
|
||||
for _, row := range data {
|
||||
isUnique := false
|
||||
if val, ok := row["is_unique"]; ok && val != nil {
|
||||
isUnique = parseBool(val)
|
||||
}
|
||||
|
||||
nonUnique := 1
|
||||
if isUnique {
|
||||
nonUnique = 0
|
||||
}
|
||||
|
||||
seq := 0
|
||||
if val, ok := row["seq_in_index"]; ok && val != nil {
|
||||
seq = parseInt(val)
|
||||
}
|
||||
|
||||
indexType := ""
|
||||
if val, ok := row["index_type"]; ok && val != nil {
|
||||
indexType = strings.ToUpper(fmt.Sprintf("%v", val))
|
||||
}
|
||||
if indexType == "" {
|
||||
indexType = "BTREE"
|
||||
}
|
||||
|
||||
idx := connection.IndexDefinition{
|
||||
Name: fmt.Sprintf("%v", row["index_name"]),
|
||||
ColumnName: fmt.Sprintf("%v", row["column_name"]),
|
||||
NonUnique: nonUnique,
|
||||
SeqInIndex: seq,
|
||||
IndexType: indexType,
|
||||
}
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
return indexes, nil
|
||||
}
|
||||
|
||||
func (v *VastbaseDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
|
||||
schema := strings.TrimSpace(dbName)
|
||||
if schema == "" {
|
||||
schema = "public"
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
tc.constraint_name AS constraint_name,
|
||||
kcu.column_name AS column_name,
|
||||
ccu.table_schema AS foreign_table_schema,
|
||||
ccu.table_name AS foreign_table_name,
|
||||
ccu.column_name AS foreign_column_name
|
||||
FROM information_schema.table_constraints AS tc
|
||||
JOIN information_schema.key_column_usage AS kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
JOIN information_schema.constraint_column_usage AS ccu
|
||||
ON ccu.constraint_name = tc.constraint_name
|
||||
AND ccu.table_schema = tc.table_schema
|
||||
WHERE tc.constraint_type = 'FOREIGN KEY'
|
||||
AND tc.table_name = '%s'
|
||||
AND tc.table_schema = '%s'
|
||||
ORDER BY tc.constraint_name, kcu.ordinal_position`, esc(table), esc(schema))
|
||||
|
||||
data, _, err := v.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var fks []connection.ForeignKeyDefinition
|
||||
for _, row := range data {
|
||||
refSchema := ""
|
||||
if val, ok := row["foreign_table_schema"]; ok && val != nil {
|
||||
refSchema = fmt.Sprintf("%v", val)
|
||||
}
|
||||
refTable := fmt.Sprintf("%v", row["foreign_table_name"])
|
||||
refTableName := refTable
|
||||
if strings.TrimSpace(refSchema) != "" {
|
||||
refTableName = fmt.Sprintf("%s.%s", refSchema, refTable)
|
||||
}
|
||||
|
||||
fk := connection.ForeignKeyDefinition{
|
||||
Name: fmt.Sprintf("%v", row["constraint_name"]),
|
||||
ColumnName: fmt.Sprintf("%v", row["column_name"]),
|
||||
RefTableName: refTableName,
|
||||
RefColumnName: fmt.Sprintf("%v", row["foreign_column_name"]),
|
||||
ConstraintName: fmt.Sprintf("%v", row["constraint_name"]),
|
||||
}
|
||||
fks = append(fks, fk)
|
||||
}
|
||||
return fks, nil
|
||||
}
|
||||
|
||||
func (v *VastbaseDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||||
schema := strings.TrimSpace(dbName)
|
||||
if schema == "" {
|
||||
schema = "public"
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT trigger_name, action_timing, event_manipulation, action_statement
|
||||
FROM information_schema.triggers
|
||||
WHERE event_object_table = '%s'
|
||||
AND event_object_schema = '%s'
|
||||
ORDER BY trigger_name, event_manipulation`, esc(table), esc(schema))
|
||||
|
||||
data, _, err := v.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var triggers []connection.TriggerDefinition
|
||||
for _, row := range data {
|
||||
trig := connection.TriggerDefinition{
|
||||
Name: fmt.Sprintf("%v", row["trigger_name"]),
|
||||
Timing: fmt.Sprintf("%v", row["action_timing"]),
|
||||
Event: fmt.Sprintf("%v", row["event_manipulation"]),
|
||||
Statement: fmt.Sprintf("%v", row["action_statement"]),
|
||||
}
|
||||
triggers = append(triggers, trig)
|
||||
}
|
||||
return triggers, nil
|
||||
}
|
||||
|
||||
func (v *VastbaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||||
query := `
|
||||
SELECT table_schema, table_name, column_name, data_type
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
|
||||
AND table_schema NOT LIKE 'pg_%'
|
||||
ORDER BY table_schema, table_name, ordinal_position`
|
||||
|
||||
data, _, err := v.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cols []connection.ColumnDefinitionWithTable
|
||||
for _, row := range data {
|
||||
schema := fmt.Sprintf("%v", row["table_schema"])
|
||||
table := fmt.Sprintf("%v", row["table_name"])
|
||||
tableName := table
|
||||
if strings.TrimSpace(schema) != "" {
|
||||
tableName = fmt.Sprintf("%s.%s", schema, table)
|
||||
}
|
||||
|
||||
col := connection.ColumnDefinitionWithTable{
|
||||
TableName: tableName,
|
||||
Name: fmt.Sprintf("%v", row["column_name"]),
|
||||
Type: fmt.Sprintf("%v", row["data_type"]),
|
||||
}
|
||||
cols = append(cols, col)
|
||||
}
|
||||
return cols, nil
|
||||
}
|
||||
|
||||
func (v *VastbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
if v.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
tx, err := v.conn.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
quoteIdent := func(name string) string {
|
||||
n := strings.TrimSpace(name)
|
||||
n = strings.Trim(n, "\"")
|
||||
n = strings.ReplaceAll(n, "\"", "\"\"")
|
||||
if n == "" {
|
||||
return "\"\""
|
||||
}
|
||||
return `"` + n + `"`
|
||||
}
|
||||
|
||||
schema := ""
|
||||
table := strings.TrimSpace(tableName)
|
||||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||||
schema = strings.TrimSpace(parts[0])
|
||||
table = strings.TrimSpace(parts[1])
|
||||
}
|
||||
|
||||
qualifiedTable := ""
|
||||
if schema != "" {
|
||||
qualifiedTable = fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table))
|
||||
} else {
|
||||
qualifiedTable = quoteIdent(table)
|
||||
}
|
||||
|
||||
// 1. Deletes
|
||||
for _, pk := range changes.Deletes {
|
||||
var wheres []string
|
||||
var args []interface{}
|
||||
idx := 0
|
||||
for k, val := range pk {
|
||||
idx++
|
||||
wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteIdent(k), idx))
|
||||
args = append(args, val)
|
||||
}
|
||||
if len(wheres) == 0 {
|
||||
continue
|
||||
}
|
||||
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("delete error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Updates
|
||||
for _, update := range changes.Updates {
|
||||
var sets []string
|
||||
var args []interface{}
|
||||
idx := 0
|
||||
|
||||
for k, val := range update.Values {
|
||||
idx++
|
||||
sets = append(sets, fmt.Sprintf("%s = $%d", quoteIdent(k), idx))
|
||||
args = append(args, val)
|
||||
}
|
||||
|
||||
if len(sets) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var wheres []string
|
||||
for k, val := range update.Keys {
|
||||
idx++
|
||||
wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteIdent(k), idx))
|
||||
args = append(args, val)
|
||||
}
|
||||
|
||||
if len(wheres) == 0 {
|
||||
return fmt.Errorf("update requires keys")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("update error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Inserts
|
||||
for _, row := range changes.Inserts {
|
||||
var cols []string
|
||||
var placeholders []string
|
||||
var args []interface{}
|
||||
idx := 0
|
||||
|
||||
for k, val := range row {
|
||||
idx++
|
||||
cols = append(cols, quoteIdent(k))
|
||||
placeholders = append(placeholders, fmt.Sprintf("$%d", idx))
|
||||
args = append(args, val)
|
||||
}
|
||||
|
||||
if len(cols) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
|
||||
if _, err := tx.Exec(query, args...); err != nil {
|
||||
return fmt.Errorf("insert error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
@@ -22,8 +22,11 @@ func quoteIdentByType(dbType string, ident string) string {
|
||||
}
|
||||
|
||||
switch dbType {
|
||||
case "mysql":
|
||||
case "mysql", "mariadb", "sphinx":
|
||||
return "`" + strings.ReplaceAll(ident, "`", "``") + "`"
|
||||
case "sqlserver":
|
||||
escaped := strings.ReplaceAll(ident, "]", "]]")
|
||||
return "[" + escaped + "]"
|
||||
default:
|
||||
return `"` + strings.ReplaceAll(ident, `"`, `""`) + `"`
|
||||
}
|
||||
@@ -71,7 +74,7 @@ func normalizeSchemaAndTable(dbType string, dbName string, tableName string) (st
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(dbType)) {
|
||||
case "postgres", "kingbase":
|
||||
case "postgres", "kingbase", "vastbase":
|
||||
return "public", rawTable
|
||||
default:
|
||||
return rawDB, rawTable
|
||||
@@ -88,7 +91,7 @@ func qualifiedNameForQuery(dbType string, schema string, table string, original
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(dbType)) {
|
||||
case "postgres", "kingbase":
|
||||
case "postgres", "kingbase", "vastbase":
|
||||
s := strings.TrimSpace(schema)
|
||||
if s == "" {
|
||||
s = "public"
|
||||
@@ -97,7 +100,7 @@ func qualifiedNameForQuery(dbType string, schema string, table string, original
|
||||
return raw
|
||||
}
|
||||
return s + "." + table
|
||||
case "mysql":
|
||||
case "mysql", "mariadb", "sphinx":
|
||||
s := strings.TrimSpace(schema)
|
||||
if s == "" || table == "" {
|
||||
return table
|
||||
|
||||
6
third_party/highgo-pq/.gitignore
vendored
Normal file
6
third_party/highgo-pq/.gitignore
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
.db
|
||||
*.test
|
||||
*~
|
||||
*.swp
|
||||
.idea
|
||||
.vscode
|
||||
8
third_party/highgo-pq/LICENSE.md
vendored
Normal file
8
third_party/highgo-pq/LICENSE.md
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
Copyright (c) 2011-2013, 'pq' Contributors
|
||||
Portions Copyright (C) 2011 Blake Mizerany
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
36
third_party/highgo-pq/README.md
vendored
Normal file
36
third_party/highgo-pq/README.md
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
# pq - A pure Go postgres driver for Go's database/sql package
|
||||
|
||||
[](https://pkg.go.dev/github.com/lib/pq?tab=doc)
|
||||
|
||||
## Install
|
||||
|
||||
go get github.com/lib/pq
|
||||
|
||||
## Features
|
||||
|
||||
* SSL
|
||||
* Handles bad connections for `database/sql`
|
||||
* Scan `time.Time` correctly (i.e. `timestamp[tz]`, `time[tz]`, `date`)
|
||||
* Scan binary blobs correctly (i.e. `bytea`)
|
||||
* Package for `hstore` support
|
||||
* COPY FROM support
|
||||
* pq.ParseURL for converting urls to connection strings for sql.Open.
|
||||
* Many libpq compatible environment variables
|
||||
* Unix socket support
|
||||
* Notifications: `LISTEN`/`NOTIFY`
|
||||
* pgpass support
|
||||
* GSS (Kerberos) auth
|
||||
|
||||
## Tests
|
||||
|
||||
`go test` is used for testing. See [TESTS.md](TESTS.md) for more details.
|
||||
|
||||
## Status
|
||||
|
||||
This package is currently in maintenance mode, which means:
|
||||
1. It generally does not accept new features.
|
||||
2. It does accept bug fixes and version compatability changes provided by the community.
|
||||
3. Maintainers usually do not resolve reported issues.
|
||||
4. Community members are encouraged to help each other with reported issues.
|
||||
|
||||
For users that require new features or reliable resolution of reported bugs, we recommend using [pgx](https://github.com/jackc/pgx) which is under active development.
|
||||
33
third_party/highgo-pq/TESTS.md
vendored
Normal file
33
third_party/highgo-pq/TESTS.md
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
# Tests
|
||||
|
||||
## Running Tests
|
||||
|
||||
`go test` is used for testing. A running PostgreSQL
|
||||
server is required, with the ability to log in. The
|
||||
database to connect to test with is "pqgotest," on
|
||||
"localhost" but these can be overridden using [environment
|
||||
variables](https://www.postgresql.org/docs/9.3/static/libpq-envars.html).
|
||||
|
||||
Example:
|
||||
|
||||
PGHOST=/run/postgresql go test
|
||||
|
||||
## Benchmarks
|
||||
|
||||
A benchmark suite can be run as part of the tests:
|
||||
|
||||
go test -bench .
|
||||
|
||||
## Example setup (Docker)
|
||||
|
||||
Run a postgres container:
|
||||
|
||||
```
|
||||
docker run --expose 5432:5432 postgres
|
||||
```
|
||||
|
||||
Run tests:
|
||||
|
||||
```
|
||||
PGHOST=localhost PGPORT=5432 PGUSER=postgres PGSSLMODE=disable PGDATABASE=postgres go test
|
||||
```
|
||||
895
third_party/highgo-pq/array.go
vendored
Normal file
895
third_party/highgo-pq/array.go
vendored
Normal file
@@ -0,0 +1,895 @@
|
||||
package pq
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var typeByteSlice = reflect.TypeOf([]byte{})
|
||||
var typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
|
||||
var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
|
||||
|
||||
// Array returns the optimal driver.Valuer and sql.Scanner for an array or
|
||||
// slice of any dimension.
|
||||
//
|
||||
// For example:
|
||||
// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401}))
|
||||
//
|
||||
// var x []sql.NullInt64
|
||||
// db.QueryRow(`SELECT ARRAY[235, 401]`).Scan(pq.Array(&x))
|
||||
//
|
||||
// Scanning multi-dimensional arrays is not supported. Arrays where the lower
|
||||
// bound is not one (such as `[0:0]={1}') are not supported.
|
||||
func Array(a interface{}) interface {
|
||||
driver.Valuer
|
||||
sql.Scanner
|
||||
} {
|
||||
switch a := a.(type) {
|
||||
case []bool:
|
||||
return (*BoolArray)(&a)
|
||||
case []float64:
|
||||
return (*Float64Array)(&a)
|
||||
case []float32:
|
||||
return (*Float32Array)(&a)
|
||||
case []int64:
|
||||
return (*Int64Array)(&a)
|
||||
case []int32:
|
||||
return (*Int32Array)(&a)
|
||||
case []string:
|
||||
return (*StringArray)(&a)
|
||||
case [][]byte:
|
||||
return (*ByteaArray)(&a)
|
||||
|
||||
case *[]bool:
|
||||
return (*BoolArray)(a)
|
||||
case *[]float64:
|
||||
return (*Float64Array)(a)
|
||||
case *[]float32:
|
||||
return (*Float32Array)(a)
|
||||
case *[]int64:
|
||||
return (*Int64Array)(a)
|
||||
case *[]int32:
|
||||
return (*Int32Array)(a)
|
||||
case *[]string:
|
||||
return (*StringArray)(a)
|
||||
case *[][]byte:
|
||||
return (*ByteaArray)(a)
|
||||
}
|
||||
|
||||
return GenericArray{a}
|
||||
}
|
||||
|
||||
// ArrayDelimiter may be optionally implemented by driver.Valuer or sql.Scanner
|
||||
// to override the array delimiter used by GenericArray.
|
||||
type ArrayDelimiter interface {
|
||||
// ArrayDelimiter returns the delimiter character(s) for this element's type.
|
||||
ArrayDelimiter() string
|
||||
}
|
||||
|
||||
// BoolArray represents a one-dimensional array of the PostgreSQL boolean type.
|
||||
type BoolArray []bool
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (a *BoolArray) Scan(src interface{}) error {
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
return a.scanBytes(src)
|
||||
case string:
|
||||
return a.scanBytes([]byte(src))
|
||||
case nil:
|
||||
*a = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("pq: cannot convert %T to BoolArray", src)
|
||||
}
|
||||
|
||||
func (a *BoolArray) scanBytes(src []byte) error {
|
||||
elems, err := scanLinearArray(src, []byte{','}, "BoolArray")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if *a != nil && len(elems) == 0 {
|
||||
*a = (*a)[:0]
|
||||
} else {
|
||||
b := make(BoolArray, len(elems))
|
||||
for i, v := range elems {
|
||||
if len(v) != 1 {
|
||||
return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v)
|
||||
}
|
||||
switch v[0] {
|
||||
case 't':
|
||||
b[i] = true
|
||||
case 'f':
|
||||
b[i] = false
|
||||
default:
|
||||
return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v)
|
||||
}
|
||||
}
|
||||
*a = b
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
func (a BoolArray) Value() (driver.Value, error) {
|
||||
if a == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if n := len(a); n > 0 {
|
||||
// There will be exactly two curly brackets, N bytes of values,
|
||||
// and N-1 bytes of delimiters.
|
||||
b := make([]byte, 1+2*n)
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
b[2*i] = ','
|
||||
if a[i] {
|
||||
b[1+2*i] = 't'
|
||||
} else {
|
||||
b[1+2*i] = 'f'
|
||||
}
|
||||
}
|
||||
|
||||
b[0] = '{'
|
||||
b[2*n] = '}'
|
||||
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
// ByteaArray represents a one-dimensional array of the PostgreSQL bytea type.
|
||||
type ByteaArray [][]byte
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (a *ByteaArray) Scan(src interface{}) error {
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
return a.scanBytes(src)
|
||||
case string:
|
||||
return a.scanBytes([]byte(src))
|
||||
case nil:
|
||||
*a = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("pq: cannot convert %T to ByteaArray", src)
|
||||
}
|
||||
|
||||
func (a *ByteaArray) scanBytes(src []byte) error {
|
||||
elems, err := scanLinearArray(src, []byte{','}, "ByteaArray")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if *a != nil && len(elems) == 0 {
|
||||
*a = (*a)[:0]
|
||||
} else {
|
||||
b := make(ByteaArray, len(elems))
|
||||
for i, v := range elems {
|
||||
b[i], err = parseBytea(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error())
|
||||
}
|
||||
}
|
||||
*a = b
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface. It uses the "hex" format which
|
||||
// is only supported on PostgreSQL 9.0 or newer.
|
||||
func (a ByteaArray) Value() (driver.Value, error) {
|
||||
if a == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if n := len(a); n > 0 {
|
||||
// There will be at least two curly brackets, 2*N bytes of quotes,
|
||||
// 3*N bytes of hex formatting, and N-1 bytes of delimiters.
|
||||
size := 1 + 6*n
|
||||
for _, x := range a {
|
||||
size += hex.EncodedLen(len(x))
|
||||
}
|
||||
|
||||
b := make([]byte, size)
|
||||
|
||||
for i, s := 0, b; i < n; i++ {
|
||||
o := copy(s, `,"\\x`)
|
||||
o += hex.Encode(s[o:], a[i])
|
||||
s[o] = '"'
|
||||
s = s[o+1:]
|
||||
}
|
||||
|
||||
b[0] = '{'
|
||||
b[size-1] = '}'
|
||||
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
// Float64Array represents a one-dimensional array of the PostgreSQL double
|
||||
// precision type.
|
||||
type Float64Array []float64
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (a *Float64Array) Scan(src interface{}) error {
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
return a.scanBytes(src)
|
||||
case string:
|
||||
return a.scanBytes([]byte(src))
|
||||
case nil:
|
||||
*a = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("pq: cannot convert %T to Float64Array", src)
|
||||
}
|
||||
|
||||
func (a *Float64Array) scanBytes(src []byte) error {
|
||||
elems, err := scanLinearArray(src, []byte{','}, "Float64Array")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if *a != nil && len(elems) == 0 {
|
||||
*a = (*a)[:0]
|
||||
} else {
|
||||
b := make(Float64Array, len(elems))
|
||||
for i, v := range elems {
|
||||
if b[i], err = strconv.ParseFloat(string(v), 64); err != nil {
|
||||
return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
*a = b
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
func (a Float64Array) Value() (driver.Value, error) {
|
||||
if a == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if n := len(a); n > 0 {
|
||||
// There will be at least two curly brackets, N bytes of values,
|
||||
// and N-1 bytes of delimiters.
|
||||
b := make([]byte, 1, 1+2*n)
|
||||
b[0] = '{'
|
||||
|
||||
b = strconv.AppendFloat(b, a[0], 'f', -1, 64)
|
||||
for i := 1; i < n; i++ {
|
||||
b = append(b, ',')
|
||||
b = strconv.AppendFloat(b, a[i], 'f', -1, 64)
|
||||
}
|
||||
|
||||
return string(append(b, '}')), nil
|
||||
}
|
||||
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
// Float32Array represents a one-dimensional array of the PostgreSQL double
|
||||
// precision type.
|
||||
type Float32Array []float32
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (a *Float32Array) Scan(src interface{}) error {
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
return a.scanBytes(src)
|
||||
case string:
|
||||
return a.scanBytes([]byte(src))
|
||||
case nil:
|
||||
*a = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("pq: cannot convert %T to Float32Array", src)
|
||||
}
|
||||
|
||||
func (a *Float32Array) scanBytes(src []byte) error {
|
||||
elems, err := scanLinearArray(src, []byte{','}, "Float32Array")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if *a != nil && len(elems) == 0 {
|
||||
*a = (*a)[:0]
|
||||
} else {
|
||||
b := make(Float32Array, len(elems))
|
||||
for i, v := range elems {
|
||||
var x float64
|
||||
if x, err = strconv.ParseFloat(string(v), 32); err != nil {
|
||||
return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
|
||||
}
|
||||
b[i] = float32(x)
|
||||
}
|
||||
*a = b
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
func (a Float32Array) Value() (driver.Value, error) {
|
||||
if a == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if n := len(a); n > 0 {
|
||||
// There will be at least two curly brackets, N bytes of values,
|
||||
// and N-1 bytes of delimiters.
|
||||
b := make([]byte, 1, 1+2*n)
|
||||
b[0] = '{'
|
||||
|
||||
b = strconv.AppendFloat(b, float64(a[0]), 'f', -1, 32)
|
||||
for i := 1; i < n; i++ {
|
||||
b = append(b, ',')
|
||||
b = strconv.AppendFloat(b, float64(a[i]), 'f', -1, 32)
|
||||
}
|
||||
|
||||
return string(append(b, '}')), nil
|
||||
}
|
||||
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
// GenericArray implements the driver.Valuer and sql.Scanner interfaces for
|
||||
// an array or slice of any dimension.
|
||||
type GenericArray struct{ A interface{} }
|
||||
|
||||
func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]byte, reflect.Value) error, string) {
|
||||
var assign func([]byte, reflect.Value) error
|
||||
var del = ","
|
||||
|
||||
// TODO calculate the assign function for other types
|
||||
// TODO repeat this section on the element type of arrays or slices (multidimensional)
|
||||
{
|
||||
if reflect.PtrTo(rt).Implements(typeSQLScanner) {
|
||||
// dest is always addressable because it is an element of a slice.
|
||||
assign = func(src []byte, dest reflect.Value) (err error) {
|
||||
ss := dest.Addr().Interface().(sql.Scanner)
|
||||
if src == nil {
|
||||
err = ss.Scan(nil)
|
||||
} else {
|
||||
err = ss.Scan(src)
|
||||
}
|
||||
return
|
||||
}
|
||||
goto FoundType
|
||||
}
|
||||
|
||||
assign = func([]byte, reflect.Value) error {
|
||||
return fmt.Errorf("pq: scanning to %s is not implemented; only sql.Scanner", rt)
|
||||
}
|
||||
}
|
||||
|
||||
FoundType:
|
||||
|
||||
if ad, ok := reflect.Zero(rt).Interface().(ArrayDelimiter); ok {
|
||||
del = ad.ArrayDelimiter()
|
||||
}
|
||||
|
||||
return rt, assign, del
|
||||
}
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (a GenericArray) Scan(src interface{}) error {
|
||||
dpv := reflect.ValueOf(a.A)
|
||||
switch {
|
||||
case dpv.Kind() != reflect.Ptr:
|
||||
return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A)
|
||||
case dpv.IsNil():
|
||||
return fmt.Errorf("pq: destination %T is nil", a.A)
|
||||
}
|
||||
|
||||
dv := dpv.Elem()
|
||||
switch dv.Kind() {
|
||||
case reflect.Slice:
|
||||
case reflect.Array:
|
||||
default:
|
||||
return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A)
|
||||
}
|
||||
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
return a.scanBytes(src, dv)
|
||||
case string:
|
||||
return a.scanBytes([]byte(src), dv)
|
||||
case nil:
|
||||
if dv.Kind() == reflect.Slice {
|
||||
dv.Set(reflect.Zero(dv.Type()))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("pq: cannot convert %T to %s", src, dv.Type())
|
||||
}
|
||||
|
||||
func (a GenericArray) scanBytes(src []byte, dv reflect.Value) error {
|
||||
dtype, assign, del := a.evaluateDestination(dv.Type().Elem())
|
||||
dims, elems, err := parseArray(src, []byte(del))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO allow multidimensional
|
||||
|
||||
if len(dims) > 1 {
|
||||
return fmt.Errorf("pq: scanning from multidimensional ARRAY%s is not implemented",
|
||||
strings.Replace(fmt.Sprint(dims), " ", "][", -1))
|
||||
}
|
||||
|
||||
// Treat a zero-dimensional array like an array with a single dimension of zero.
|
||||
if len(dims) == 0 {
|
||||
dims = append(dims, 0)
|
||||
}
|
||||
|
||||
for i, rt := 0, dv.Type(); i < len(dims); i, rt = i+1, rt.Elem() {
|
||||
switch rt.Kind() {
|
||||
case reflect.Slice:
|
||||
case reflect.Array:
|
||||
if rt.Len() != dims[i] {
|
||||
return fmt.Errorf("pq: cannot convert ARRAY%s to %s",
|
||||
strings.Replace(fmt.Sprint(dims), " ", "][", -1), dv.Type())
|
||||
}
|
||||
default:
|
||||
// TODO handle multidimensional
|
||||
}
|
||||
}
|
||||
|
||||
values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems))
|
||||
for i, e := range elems {
|
||||
if err := assign(e, values.Index(i)); err != nil {
|
||||
return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO handle multidimensional
|
||||
|
||||
switch dv.Kind() {
|
||||
case reflect.Slice:
|
||||
dv.Set(values.Slice(0, dims[0]))
|
||||
case reflect.Array:
|
||||
for i := 0; i < dims[0]; i++ {
|
||||
dv.Index(i).Set(values.Index(i))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
func (a GenericArray) Value() (driver.Value, error) {
|
||||
if a.A == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(a.A)
|
||||
|
||||
switch rv.Kind() {
|
||||
case reflect.Slice:
|
||||
if rv.IsNil() {
|
||||
return nil, nil
|
||||
}
|
||||
case reflect.Array:
|
||||
default:
|
||||
return nil, fmt.Errorf("pq: Unable to convert %T to array", a.A)
|
||||
}
|
||||
|
||||
if n := rv.Len(); n > 0 {
|
||||
// There will be at least two curly brackets, N bytes of values,
|
||||
// and N-1 bytes of delimiters.
|
||||
b := make([]byte, 0, 1+2*n)
|
||||
|
||||
b, _, err := appendArray(b, rv, n)
|
||||
return string(b), err
|
||||
}
|
||||
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
// Int64Array represents a one-dimensional array of the PostgreSQL integer types.
|
||||
type Int64Array []int64
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (a *Int64Array) Scan(src interface{}) error {
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
return a.scanBytes(src)
|
||||
case string:
|
||||
return a.scanBytes([]byte(src))
|
||||
case nil:
|
||||
*a = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("pq: cannot convert %T to Int64Array", src)
|
||||
}
|
||||
|
||||
func (a *Int64Array) scanBytes(src []byte) error {
|
||||
elems, err := scanLinearArray(src, []byte{','}, "Int64Array")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if *a != nil && len(elems) == 0 {
|
||||
*a = (*a)[:0]
|
||||
} else {
|
||||
b := make(Int64Array, len(elems))
|
||||
for i, v := range elems {
|
||||
if b[i], err = strconv.ParseInt(string(v), 10, 64); err != nil {
|
||||
return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
*a = b
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
func (a Int64Array) Value() (driver.Value, error) {
|
||||
if a == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if n := len(a); n > 0 {
|
||||
// There will be at least two curly brackets, N bytes of values,
|
||||
// and N-1 bytes of delimiters.
|
||||
b := make([]byte, 1, 1+2*n)
|
||||
b[0] = '{'
|
||||
|
||||
b = strconv.AppendInt(b, a[0], 10)
|
||||
for i := 1; i < n; i++ {
|
||||
b = append(b, ',')
|
||||
b = strconv.AppendInt(b, a[i], 10)
|
||||
}
|
||||
|
||||
return string(append(b, '}')), nil
|
||||
}
|
||||
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
// Int32Array represents a one-dimensional array of the PostgreSQL integer types.
|
||||
type Int32Array []int32
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (a *Int32Array) Scan(src interface{}) error {
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
return a.scanBytes(src)
|
||||
case string:
|
||||
return a.scanBytes([]byte(src))
|
||||
case nil:
|
||||
*a = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("pq: cannot convert %T to Int32Array", src)
|
||||
}
|
||||
|
||||
func (a *Int32Array) scanBytes(src []byte) error {
|
||||
elems, err := scanLinearArray(src, []byte{','}, "Int32Array")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if *a != nil && len(elems) == 0 {
|
||||
*a = (*a)[:0]
|
||||
} else {
|
||||
b := make(Int32Array, len(elems))
|
||||
for i, v := range elems {
|
||||
x, err := strconv.ParseInt(string(v), 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
|
||||
}
|
||||
b[i] = int32(x)
|
||||
}
|
||||
*a = b
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
func (a Int32Array) Value() (driver.Value, error) {
|
||||
if a == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if n := len(a); n > 0 {
|
||||
// There will be at least two curly brackets, N bytes of values,
|
||||
// and N-1 bytes of delimiters.
|
||||
b := make([]byte, 1, 1+2*n)
|
||||
b[0] = '{'
|
||||
|
||||
b = strconv.AppendInt(b, int64(a[0]), 10)
|
||||
for i := 1; i < n; i++ {
|
||||
b = append(b, ',')
|
||||
b = strconv.AppendInt(b, int64(a[i]), 10)
|
||||
}
|
||||
|
||||
return string(append(b, '}')), nil
|
||||
}
|
||||
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
// StringArray represents a one-dimensional array of the PostgreSQL character types.
|
||||
type StringArray []string
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (a *StringArray) Scan(src interface{}) error {
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
return a.scanBytes(src)
|
||||
case string:
|
||||
return a.scanBytes([]byte(src))
|
||||
case nil:
|
||||
*a = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("pq: cannot convert %T to StringArray", src)
|
||||
}
|
||||
|
||||
func (a *StringArray) scanBytes(src []byte) error {
|
||||
elems, err := scanLinearArray(src, []byte{','}, "StringArray")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if *a != nil && len(elems) == 0 {
|
||||
*a = (*a)[:0]
|
||||
} else {
|
||||
b := make(StringArray, len(elems))
|
||||
for i, v := range elems {
|
||||
if b[i] = string(v); v == nil {
|
||||
return fmt.Errorf("pq: parsing array element index %d: cannot convert nil to string", i)
|
||||
}
|
||||
}
|
||||
*a = b
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
func (a StringArray) Value() (driver.Value, error) {
|
||||
if a == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if n := len(a); n > 0 {
|
||||
// There will be at least two curly brackets, 2*N bytes of quotes,
|
||||
// and N-1 bytes of delimiters.
|
||||
b := make([]byte, 1, 1+3*n)
|
||||
b[0] = '{'
|
||||
|
||||
b = appendArrayQuotedBytes(b, []byte(a[0]))
|
||||
for i := 1; i < n; i++ {
|
||||
b = append(b, ',')
|
||||
b = appendArrayQuotedBytes(b, []byte(a[i]))
|
||||
}
|
||||
|
||||
return string(append(b, '}')), nil
|
||||
}
|
||||
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
// appendArray appends rv to the buffer, returning the extended buffer and
|
||||
// the delimiter used between elements.
|
||||
//
|
||||
// It panics when n <= 0 or rv's Kind is not reflect.Array nor reflect.Slice.
|
||||
func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) {
|
||||
var del string
|
||||
var err error
|
||||
|
||||
b = append(b, '{')
|
||||
|
||||
if b, del, err = appendArrayElement(b, rv.Index(0)); err != nil {
|
||||
return b, del, err
|
||||
}
|
||||
|
||||
for i := 1; i < n; i++ {
|
||||
b = append(b, del...)
|
||||
if b, del, err = appendArrayElement(b, rv.Index(i)); err != nil {
|
||||
return b, del, err
|
||||
}
|
||||
}
|
||||
|
||||
return append(b, '}'), del, nil
|
||||
}
|
||||
|
||||
// appendArrayElement appends rv to the buffer, returning the extended buffer
|
||||
// and the delimiter to use before the next element.
|
||||
//
|
||||
// When rv's Kind is neither reflect.Array nor reflect.Slice, it is converted
|
||||
// using driver.DefaultParameterConverter and the resulting []byte or string
|
||||
// is double-quoted.
|
||||
//
|
||||
// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
|
||||
func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) {
|
||||
if k := rv.Kind(); k == reflect.Array || k == reflect.Slice {
|
||||
if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) {
|
||||
if n := rv.Len(); n > 0 {
|
||||
return appendArray(b, rv, n)
|
||||
}
|
||||
|
||||
return b, "", nil
|
||||
}
|
||||
}
|
||||
|
||||
var del = ","
|
||||
var err error
|
||||
var iv interface{} = rv.Interface()
|
||||
|
||||
if ad, ok := iv.(ArrayDelimiter); ok {
|
||||
del = ad.ArrayDelimiter()
|
||||
}
|
||||
|
||||
if iv, err = driver.DefaultParameterConverter.ConvertValue(iv); err != nil {
|
||||
return b, del, err
|
||||
}
|
||||
|
||||
switch v := iv.(type) {
|
||||
case nil:
|
||||
return append(b, "NULL"...), del, nil
|
||||
case []byte:
|
||||
return appendArrayQuotedBytes(b, v), del, nil
|
||||
case string:
|
||||
return appendArrayQuotedBytes(b, []byte(v)), del, nil
|
||||
}
|
||||
|
||||
b, err = appendValue(b, iv)
|
||||
return b, del, err
|
||||
}
|
||||
|
||||
func appendArrayQuotedBytes(b, v []byte) []byte {
|
||||
b = append(b, '"')
|
||||
for {
|
||||
i := bytes.IndexAny(v, `"\`)
|
||||
if i < 0 {
|
||||
b = append(b, v...)
|
||||
break
|
||||
}
|
||||
if i > 0 {
|
||||
b = append(b, v[:i]...)
|
||||
}
|
||||
b = append(b, '\\', v[i])
|
||||
v = v[i+1:]
|
||||
}
|
||||
return append(b, '"')
|
||||
}
|
||||
|
||||
func appendValue(b []byte, v driver.Value) ([]byte, error) {
|
||||
return append(b, encode(nil, v, 0)...), nil
|
||||
}
|
||||
|
||||
// parseArray extracts the dimensions and elements of an array represented in
|
||||
// text format. Only representations emitted by the backend are supported.
|
||||
// Notably, whitespace around brackets and delimiters is significant, and NULL
|
||||
// is case-sensitive.
|
||||
//
|
||||
// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
|
||||
func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) {
|
||||
var depth, i int
|
||||
|
||||
if len(src) < 1 || src[0] != '{' {
|
||||
return nil, nil, fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '{', 0)
|
||||
}
|
||||
|
||||
Open:
|
||||
for i < len(src) {
|
||||
switch src[i] {
|
||||
case '{':
|
||||
depth++
|
||||
i++
|
||||
case '}':
|
||||
elems = make([][]byte, 0)
|
||||
goto Close
|
||||
default:
|
||||
break Open
|
||||
}
|
||||
}
|
||||
dims = make([]int, i)
|
||||
|
||||
Element:
|
||||
for i < len(src) {
|
||||
switch src[i] {
|
||||
case '{':
|
||||
if depth == len(dims) {
|
||||
break Element
|
||||
}
|
||||
depth++
|
||||
dims[depth-1] = 0
|
||||
i++
|
||||
case '"':
|
||||
var elem = []byte{}
|
||||
var escape bool
|
||||
for i++; i < len(src); i++ {
|
||||
if escape {
|
||||
elem = append(elem, src[i])
|
||||
escape = false
|
||||
} else {
|
||||
switch src[i] {
|
||||
default:
|
||||
elem = append(elem, src[i])
|
||||
case '\\':
|
||||
escape = true
|
||||
case '"':
|
||||
elems = append(elems, elem)
|
||||
i++
|
||||
break Element
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
for start := i; i < len(src); i++ {
|
||||
if bytes.HasPrefix(src[i:], del) || src[i] == '}' {
|
||||
elem := src[start:i]
|
||||
if len(elem) == 0 {
|
||||
return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
|
||||
}
|
||||
if bytes.Equal(elem, []byte("NULL")) {
|
||||
elem = nil
|
||||
}
|
||||
elems = append(elems, elem)
|
||||
break Element
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i < len(src) {
|
||||
if bytes.HasPrefix(src[i:], del) && depth > 0 {
|
||||
dims[depth-1]++
|
||||
i += len(del)
|
||||
goto Element
|
||||
} else if src[i] == '}' && depth > 0 {
|
||||
dims[depth-1]++
|
||||
depth--
|
||||
i++
|
||||
} else {
|
||||
return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
|
||||
}
|
||||
}
|
||||
|
||||
Close:
|
||||
for i < len(src) {
|
||||
if src[i] == '}' && depth > 0 {
|
||||
depth--
|
||||
i++
|
||||
} else {
|
||||
return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
|
||||
}
|
||||
}
|
||||
if depth > 0 {
|
||||
err = fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '}', i)
|
||||
}
|
||||
if err == nil {
|
||||
for _, d := range dims {
|
||||
if (len(elems) % d) != 0 {
|
||||
err = fmt.Errorf("pq: multidimensional arrays must have elements with matching dimensions")
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) {
|
||||
dims, elems, err := parseArray(src, del)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(dims) > 1 {
|
||||
return nil, fmt.Errorf("pq: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ)
|
||||
}
|
||||
return elems, err
|
||||
}
|
||||
1652
third_party/highgo-pq/array_test.go
vendored
Normal file
1652
third_party/highgo-pq/array_test.go
vendored
Normal file
File diff suppressed because it is too large
Load Diff
8
third_party/highgo-pq/auth/kerberos/go.mod
vendored
Normal file
8
third_party/highgo-pq/auth/kerberos/go.mod
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
module github.com/lib/pq/auth/kerberos
|
||||
|
||||
go 1.13
|
||||
|
||||
require (
|
||||
github.com/alexbrainman/sspi v0.0.0-20180613141037-e580b900e9f5
|
||||
github.com/jcmturner/gokrb5/v8 v8.2.0
|
||||
)
|
||||
40
third_party/highgo-pq/auth/kerberos/go.sum
vendored
Normal file
40
third_party/highgo-pq/auth/kerberos/go.sum
vendored
Normal file
@@ -0,0 +1,40 @@
|
||||
github.com/alexbrainman/sspi v0.0.0-20180613141037-e580b900e9f5 h1:P5U+E4x5OkVEKQDklVPmzs71WM56RTTRqV4OrDC//Y4=
|
||||
github.com/alexbrainman/sspi v0.0.0-20180613141037-e580b900e9f5/go.mod h1:976q2ETgjT2snVCf2ZaBnyBbVoPERGjUz+0sofzEfro=
|
||||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
|
||||
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
|
||||
github.com/gorilla/sessions v1.2.0 h1:S7P+1Hm5V/AT9cjEcUD5uDaQSX0OE577aCXgoaKpYbQ=
|
||||
github.com/gorilla/sessions v1.2.0/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
|
||||
github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE=
|
||||
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||
github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8=
|
||||
github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs=
|
||||
github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo=
|
||||
github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM=
|
||||
github.com/jcmturner/gofork v1.0.0 h1:J7uCkflzTEhUZ64xqKnkDxq3kzc96ajM1Gli5ktUem8=
|
||||
github.com/jcmturner/gofork v1.0.0/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o=
|
||||
github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o=
|
||||
github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg=
|
||||
github.com/jcmturner/gokrb5/v8 v8.2.0 h1:lzPl/30ZLkTveYsYZPKMcgXc8MbnE6RsTd4F9KgiLtk=
|
||||
github.com/jcmturner/gokrb5/v8 v8.2.0/go.mod h1:T1hnNppQsBtxW0tCHMHTkAt8n/sABdzZgZdoFrZaZNM=
|
||||
github.com/jcmturner/rpc/v2 v2.0.2 h1:gMB4IwRXYsWw4Bc6o/az2HJgFUA1ffSh90i26ZJ6Xl0=
|
||||
github.com/jcmturner/rpc/v2 v2.0.2/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20200117160349-530e935923ad h1:Jh8cai0fqIK+f6nG0UgPW5wFk8wmiMhM3AyciDBdtQg=
|
||||
golang.org/x/crypto v0.0.0-20200117160349-530e935923ad/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa h1:F+8P+gmewFQYRk6JoLQLwjBCTu3mcIURZfNkVweuRKA=
|
||||
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
29
third_party/highgo-pq/auth/kerberos/krb.go
vendored
Normal file
29
third_party/highgo-pq/auth/kerberos/krb.go
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
package kerberos
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
/*
|
||||
* Find the A record associated with a hostname
|
||||
* In general, hostnames supplied to the driver should be
|
||||
* canonicalized because the KDC usually only has one
|
||||
* principal and not one per potential alias of a host.
|
||||
*/
|
||||
func canonicalizeHostname(host string) (string, error) {
|
||||
canon := host
|
||||
|
||||
name, err := net.LookupCNAME(host)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
name = strings.TrimSuffix(name, ".")
|
||||
|
||||
if name != "" {
|
||||
canon = name
|
||||
}
|
||||
|
||||
return canon, nil
|
||||
}
|
||||
128
third_party/highgo-pq/auth/kerberos/krb_unix.go
vendored
Normal file
128
third_party/highgo-pq/auth/kerberos/krb_unix.go
vendored
Normal file
@@ -0,0 +1,128 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package kerberos
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"strings"
|
||||
|
||||
"github.com/jcmturner/gokrb5/v8/client"
|
||||
"github.com/jcmturner/gokrb5/v8/config"
|
||||
"github.com/jcmturner/gokrb5/v8/credentials"
|
||||
"github.com/jcmturner/gokrb5/v8/spnego"
|
||||
)
|
||||
|
||||
/*
|
||||
* UNIX Kerberos support, using jcmturner's pure-go
|
||||
* implementation
|
||||
*/
|
||||
|
||||
// GSS implements the pq.GSS interface.
|
||||
type GSS struct {
|
||||
cli *client.Client
|
||||
}
|
||||
|
||||
// NewGSS creates a new GSS provider.
|
||||
func NewGSS() (*GSS, error) {
|
||||
g := &GSS{}
|
||||
err := g.init()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func (g *GSS) init() error {
|
||||
cfgPath, ok := os.LookupEnv("KRB5_CONFIG")
|
||||
if !ok {
|
||||
cfgPath = "/etc/krb5.conf"
|
||||
}
|
||||
|
||||
cfg, err := config.Load(cfgPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u, err := user.Current()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ccpath := "/tmp/krb5cc_" + u.Uid
|
||||
|
||||
ccname := os.Getenv("KRB5CCNAME")
|
||||
if strings.HasPrefix(ccname, "FILE:") {
|
||||
ccpath = strings.SplitN(ccname, ":", 2)[1]
|
||||
}
|
||||
|
||||
ccache, err := credentials.LoadCCache(ccpath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cl, err := client.NewFromCCache(ccache, cfg, client.DisablePAFXFAST(true))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cl.Login()
|
||||
|
||||
g.cli = cl
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetInitToken implements the GSS interface.
|
||||
func (g *GSS) GetInitToken(host string, service string) ([]byte, error) {
|
||||
|
||||
// Resolve the hostname down to an 'A' record, if required (usually, it is)
|
||||
if g.cli.Config.LibDefaults.DNSCanonicalizeHostname {
|
||||
var err error
|
||||
host, err = canonicalizeHostname(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
spn := service + "/" + host
|
||||
|
||||
return g.GetInitTokenFromSpn(spn)
|
||||
}
|
||||
|
||||
// GetInitTokenFromSpn implements the GSS interface.
|
||||
func (g *GSS) GetInitTokenFromSpn(spn string) ([]byte, error) {
|
||||
s := spnego.SPNEGOClient(g.cli, spn)
|
||||
|
||||
st, err := s.InitSecContext()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kerberos error (InitSecContext): %s", err.Error())
|
||||
}
|
||||
|
||||
b, err := st.Marshal()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kerberos error (Marshaling token): %s", err.Error())
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// Continue implements the GSS interface.
|
||||
func (g *GSS) Continue(inToken []byte) (done bool, outToken []byte, err error) {
|
||||
t := &spnego.SPNEGOToken{}
|
||||
err = t.Unmarshal(inToken)
|
||||
if err != nil {
|
||||
return true, nil, fmt.Errorf("kerberos error (Unmarshaling token): %s", err.Error())
|
||||
}
|
||||
|
||||
state := t.NegTokenResp.State()
|
||||
if state != spnego.NegStateAcceptCompleted {
|
||||
return true, nil, fmt.Errorf("kerberos: expected state 'Completed' - got %d", state)
|
||||
}
|
||||
|
||||
return true, nil, nil
|
||||
}
|
||||
67
third_party/highgo-pq/auth/kerberos/krb_windows.go
vendored
Normal file
67
third_party/highgo-pq/auth/kerberos/krb_windows.go
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package kerberos
|
||||
|
||||
import (
|
||||
"github.com/alexbrainman/sspi"
|
||||
"github.com/alexbrainman/sspi/negotiate"
|
||||
)
|
||||
|
||||
// GSS implements the pq.GSS interface.
|
||||
type GSS struct {
|
||||
creds *sspi.Credentials
|
||||
ctx *negotiate.ClientContext
|
||||
}
|
||||
|
||||
// NewGSS creates a new GSS provider.
|
||||
func NewGSS() (*GSS, error) {
|
||||
g := &GSS{}
|
||||
err := g.init()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func (g *GSS) init() error {
|
||||
creds, err := negotiate.AcquireCurrentUserCredentials()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
g.creds = creds
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetInitToken implements the GSS interface.
|
||||
func (g *GSS) GetInitToken(host string, service string) ([]byte, error) {
|
||||
|
||||
host, err := canonicalizeHostname(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
spn := service + "/" + host
|
||||
|
||||
return g.GetInitTokenFromSpn(spn)
|
||||
}
|
||||
|
||||
// GetInitTokenFromSpn implements the GSS interface.
|
||||
func (g *GSS) GetInitTokenFromSpn(spn string) ([]byte, error) {
|
||||
ctx, token, err := negotiate.NewClientContext(g.creds, spn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
g.ctx = ctx
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Continue implements the GSS interface.
|
||||
func (g *GSS) Continue(inToken []byte) (done bool, outToken []byte, err error) {
|
||||
return g.ctx.Update(inToken)
|
||||
}
|
||||
434
third_party/highgo-pq/bench_test.go
vendored
Normal file
434
third_party/highgo-pq/bench_test.go
vendored
Normal file
@@ -0,0 +1,434 @@
|
||||
package pq
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq/oid"
|
||||
)
|
||||
|
||||
var (
|
||||
selectStringQuery = "SELECT '" + strings.Repeat("0123456789", 10) + "'"
|
||||
selectSeriesQuery = "SELECT generate_series(1, 100)"
|
||||
)
|
||||
|
||||
func BenchmarkSelectString(b *testing.B) {
|
||||
var result string
|
||||
benchQuery(b, selectStringQuery, &result)
|
||||
}
|
||||
|
||||
func BenchmarkSelectSeries(b *testing.B) {
|
||||
var result int
|
||||
benchQuery(b, selectSeriesQuery, &result)
|
||||
}
|
||||
|
||||
func benchQuery(b *testing.B, query string, result interface{}) {
|
||||
b.StopTimer()
|
||||
db := openTestConn(b)
|
||||
defer db.Close()
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchQueryLoop(b, db, query, result)
|
||||
}
|
||||
}
|
||||
|
||||
func benchQueryLoop(b *testing.B, db *sql.DB, query string, result interface{}) {
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
err = rows.Scan(result)
|
||||
if err != nil {
|
||||
b.Fatal("failed to scan", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reading from circularConn yields content[:prefixLen] once, followed by
|
||||
// content[prefixLen:] over and over again. It never returns EOF.
|
||||
type circularConn struct {
|
||||
content string
|
||||
prefixLen int
|
||||
pos int
|
||||
net.Conn // for all other net.Conn methods that will never be called
|
||||
}
|
||||
|
||||
func (r *circularConn) Read(b []byte) (n int, err error) {
|
||||
n = copy(b, r.content[r.pos:])
|
||||
r.pos += n
|
||||
if r.pos >= len(r.content) {
|
||||
r.pos = r.prefixLen
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *circularConn) Write(b []byte) (n int, err error) { return len(b), nil }
|
||||
|
||||
func (r *circularConn) Close() error { return nil }
|
||||
|
||||
func fakeConn(content string, prefixLen int) *conn {
|
||||
c := &circularConn{content: content, prefixLen: prefixLen}
|
||||
return &conn{buf: bufio.NewReader(c), c: c}
|
||||
}
|
||||
|
||||
// This benchmark is meant to be the same as BenchmarkSelectString, but takes
|
||||
// out some of the factors this package can't control. The numbers are less noisy,
|
||||
// but also the costs of network communication aren't accurately represented.
|
||||
func BenchmarkMockSelectString(b *testing.B) {
|
||||
b.StopTimer()
|
||||
// taken from a recorded run of BenchmarkSelectString
|
||||
// See: http://www.postgresql.org/docs/current/static/protocol-message-formats.html
|
||||
const response = "1\x00\x00\x00\x04" +
|
||||
"t\x00\x00\x00\x06\x00\x00" +
|
||||
"T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" +
|
||||
"Z\x00\x00\x00\x05I" +
|
||||
"2\x00\x00\x00\x04" +
|
||||
"D\x00\x00\x00n\x00\x01\x00\x00\x00d0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" +
|
||||
"C\x00\x00\x00\rSELECT 1\x00" +
|
||||
"Z\x00\x00\x00\x05I" +
|
||||
"3\x00\x00\x00\x04" +
|
||||
"Z\x00\x00\x00\x05I"
|
||||
c := fakeConn(response, 0)
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchMockQuery(b, c, selectStringQuery)
|
||||
}
|
||||
}
|
||||
|
||||
var seriesRowData = func() string {
|
||||
var buf bytes.Buffer
|
||||
for i := 1; i <= 100; i++ {
|
||||
digits := byte(2)
|
||||
if i >= 100 {
|
||||
digits = 3
|
||||
} else if i < 10 {
|
||||
digits = 1
|
||||
}
|
||||
buf.WriteString("D\x00\x00\x00")
|
||||
buf.WriteByte(10 + digits)
|
||||
buf.WriteString("\x00\x01\x00\x00\x00")
|
||||
buf.WriteByte(digits)
|
||||
buf.WriteString(strconv.Itoa(i))
|
||||
}
|
||||
return buf.String()
|
||||
}()
|
||||
|
||||
func BenchmarkMockSelectSeries(b *testing.B) {
|
||||
b.StopTimer()
|
||||
var response = "1\x00\x00\x00\x04" +
|
||||
"t\x00\x00\x00\x06\x00\x00" +
|
||||
"T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" +
|
||||
"Z\x00\x00\x00\x05I" +
|
||||
"2\x00\x00\x00\x04" +
|
||||
seriesRowData +
|
||||
"C\x00\x00\x00\x0fSELECT 100\x00" +
|
||||
"Z\x00\x00\x00\x05I" +
|
||||
"3\x00\x00\x00\x04" +
|
||||
"Z\x00\x00\x00\x05I"
|
||||
c := fakeConn(response, 0)
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchMockQuery(b, c, selectSeriesQuery)
|
||||
}
|
||||
}
|
||||
|
||||
func benchMockQuery(b *testing.B, c *conn, query string) {
|
||||
stmt, err := c.Prepare(query)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
rows, err := stmt.(driver.StmtQueryContext).QueryContext(context.Background(), nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var dest [1]driver.Value
|
||||
for {
|
||||
if err := rows.Next(dest[:]); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPreparedSelectString(b *testing.B) {
|
||||
var result string
|
||||
benchPreparedQuery(b, selectStringQuery, &result)
|
||||
}
|
||||
|
||||
func BenchmarkPreparedSelectSeries(b *testing.B) {
|
||||
var result int
|
||||
benchPreparedQuery(b, selectSeriesQuery, &result)
|
||||
}
|
||||
|
||||
func benchPreparedQuery(b *testing.B, query string, result interface{}) {
|
||||
b.StopTimer()
|
||||
db := openTestConn(b)
|
||||
defer db.Close()
|
||||
stmt, err := db.Prepare(query)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchPreparedQueryLoop(b, db, stmt, result)
|
||||
}
|
||||
}
|
||||
|
||||
func benchPreparedQueryLoop(b *testing.B, db *sql.DB, stmt *sql.Stmt, result interface{}) {
|
||||
rows, err := stmt.Query()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if !rows.Next() {
|
||||
rows.Close()
|
||||
b.Fatal("no rows")
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&result)
|
||||
if err != nil {
|
||||
b.Fatal("failed to scan")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// See the comment for BenchmarkMockSelectString.
|
||||
func BenchmarkMockPreparedSelectString(b *testing.B) {
|
||||
b.StopTimer()
|
||||
const parseResponse = "1\x00\x00\x00\x04" +
|
||||
"t\x00\x00\x00\x06\x00\x00" +
|
||||
"T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" +
|
||||
"Z\x00\x00\x00\x05I"
|
||||
const responses = parseResponse +
|
||||
"2\x00\x00\x00\x04" +
|
||||
"D\x00\x00\x00n\x00\x01\x00\x00\x00d0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" +
|
||||
"C\x00\x00\x00\rSELECT 1\x00" +
|
||||
"Z\x00\x00\x00\x05I"
|
||||
c := fakeConn(responses, len(parseResponse))
|
||||
|
||||
stmt, err := c.Prepare(selectStringQuery)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchPreparedMockQuery(b, c, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMockPreparedSelectSeries(b *testing.B) {
|
||||
b.StopTimer()
|
||||
const parseResponse = "1\x00\x00\x00\x04" +
|
||||
"t\x00\x00\x00\x06\x00\x00" +
|
||||
"T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" +
|
||||
"Z\x00\x00\x00\x05I"
|
||||
var responses = parseResponse +
|
||||
"2\x00\x00\x00\x04" +
|
||||
seriesRowData +
|
||||
"C\x00\x00\x00\x0fSELECT 100\x00" +
|
||||
"Z\x00\x00\x00\x05I"
|
||||
c := fakeConn(responses, len(parseResponse))
|
||||
|
||||
stmt, err := c.Prepare(selectSeriesQuery)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchPreparedMockQuery(b, c, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
func benchPreparedMockQuery(b *testing.B, c *conn, stmt driver.Stmt) {
|
||||
rows, err := stmt.(driver.StmtQueryContext).QueryContext(context.Background(), nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var dest [1]driver.Value
|
||||
for {
|
||||
if err := rows.Next(dest[:]); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncodeInt64(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
encode(¶meterStatus{}, int64(1234), oid.T_int8)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncodeFloat64(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
encode(¶meterStatus{}, 3.14159, oid.T_float8)
|
||||
}
|
||||
}
|
||||
|
||||
var testByteString = []byte("abcdefghijklmnopqrstuvwxyz")
|
||||
|
||||
func BenchmarkEncodeByteaHex(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
encode(¶meterStatus{serverVersion: 90000}, testByteString, oid.T_bytea)
|
||||
}
|
||||
}
|
||||
func BenchmarkEncodeByteaEscape(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
encode(¶meterStatus{serverVersion: 84000}, testByteString, oid.T_bytea)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncodeBool(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
encode(¶meterStatus{}, true, oid.T_bool)
|
||||
}
|
||||
}
|
||||
|
||||
var testTimestamptz = time.Date(2001, time.January, 1, 0, 0, 0, 0, time.Local)
|
||||
|
||||
func BenchmarkEncodeTimestamptz(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
encode(¶meterStatus{}, testTimestamptz, oid.T_timestamptz)
|
||||
}
|
||||
}
|
||||
|
||||
var testIntBytes = []byte("1234")
|
||||
|
||||
func BenchmarkDecodeInt64(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
decode(¶meterStatus{}, testIntBytes, oid.T_int8, formatText)
|
||||
}
|
||||
}
|
||||
|
||||
var testFloatBytes = []byte("3.14159")
|
||||
|
||||
func BenchmarkDecodeFloat64(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
decode(¶meterStatus{}, testFloatBytes, oid.T_float8, formatText)
|
||||
}
|
||||
}
|
||||
|
||||
var testBoolBytes = []byte{'t'}
|
||||
|
||||
func BenchmarkDecodeBool(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
decode(¶meterStatus{}, testBoolBytes, oid.T_bool, formatText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeBool(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
rows, err := db.Query("select true")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rows.Close()
|
||||
}
|
||||
|
||||
var testTimestamptzBytes = []byte("2013-09-17 22:15:32.360754-07")
|
||||
|
||||
func BenchmarkDecodeTimestamptz(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
decode(¶meterStatus{}, testTimestamptzBytes, oid.T_timestamptz, formatText)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDecodeTimestamptzMultiThread(b *testing.B) {
|
||||
oldProcs := runtime.GOMAXPROCS(0)
|
||||
defer runtime.GOMAXPROCS(oldProcs)
|
||||
runtime.GOMAXPROCS(runtime.NumCPU())
|
||||
globalLocationCache = newLocationCache()
|
||||
|
||||
f := func(wg *sync.WaitGroup, loops int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < loops; i++ {
|
||||
decode(¶meterStatus{}, testTimestamptzBytes, oid.T_timestamptz, formatText)
|
||||
}
|
||||
}
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
b.ResetTimer()
|
||||
for j := 0; j < 10; j++ {
|
||||
wg.Add(1)
|
||||
go f(wg, b.N/10)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func BenchmarkLocationCache(b *testing.B) {
|
||||
globalLocationCache = newLocationCache()
|
||||
for i := 0; i < b.N; i++ {
|
||||
globalLocationCache.getLocation(rand.Intn(10000))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLocationCacheMultiThread(b *testing.B) {
|
||||
oldProcs := runtime.GOMAXPROCS(0)
|
||||
defer runtime.GOMAXPROCS(oldProcs)
|
||||
runtime.GOMAXPROCS(runtime.NumCPU())
|
||||
globalLocationCache = newLocationCache()
|
||||
|
||||
f := func(wg *sync.WaitGroup, loops int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < loops; i++ {
|
||||
globalLocationCache.getLocation(rand.Intn(10000))
|
||||
}
|
||||
}
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
b.ResetTimer()
|
||||
for j := 0; j < 10; j++ {
|
||||
wg.Add(1)
|
||||
go f(wg, b.N/10)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Stress test the performance of parsing results from the wire.
|
||||
func BenchmarkResultParsing(b *testing.B) {
|
||||
b.StopTimer()
|
||||
|
||||
db := openTestConn(b)
|
||||
defer db.Close()
|
||||
_, err := db.Exec("BEGIN")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.StartTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
res, err := db.Query("SELECT generate_series(1, 50000)")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
res.Close()
|
||||
}
|
||||
}
|
||||
91
third_party/highgo-pq/buf.go
vendored
Normal file
91
third_party/highgo-pq/buf.go
vendored
Normal file
@@ -0,0 +1,91 @@
|
||||
package pq
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/lib/pq/oid"
|
||||
)
|
||||
|
||||
type readBuf []byte
|
||||
|
||||
func (b *readBuf) int32() (n int) {
|
||||
n = int(int32(binary.BigEndian.Uint32(*b)))
|
||||
*b = (*b)[4:]
|
||||
return
|
||||
}
|
||||
|
||||
func (b *readBuf) oid() (n oid.Oid) {
|
||||
n = oid.Oid(binary.BigEndian.Uint32(*b))
|
||||
*b = (*b)[4:]
|
||||
return
|
||||
}
|
||||
|
||||
// N.B: this is actually an unsigned 16-bit integer, unlike int32
|
||||
func (b *readBuf) int16() (n int) {
|
||||
n = int(binary.BigEndian.Uint16(*b))
|
||||
*b = (*b)[2:]
|
||||
return
|
||||
}
|
||||
|
||||
func (b *readBuf) string() string {
|
||||
i := bytes.IndexByte(*b, 0)
|
||||
if i < 0 {
|
||||
errorf("invalid message format; expected string terminator")
|
||||
}
|
||||
s := (*b)[:i]
|
||||
*b = (*b)[i+1:]
|
||||
return string(s)
|
||||
}
|
||||
|
||||
func (b *readBuf) next(n int) (v []byte) {
|
||||
v = (*b)[:n]
|
||||
*b = (*b)[n:]
|
||||
return
|
||||
}
|
||||
|
||||
func (b *readBuf) byte() byte {
|
||||
return b.next(1)[0]
|
||||
}
|
||||
|
||||
type writeBuf struct {
|
||||
buf []byte
|
||||
pos int
|
||||
}
|
||||
|
||||
func (b *writeBuf) int32(n int) {
|
||||
x := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(x, uint32(n))
|
||||
b.buf = append(b.buf, x...)
|
||||
}
|
||||
|
||||
func (b *writeBuf) int16(n int) {
|
||||
x := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(x, uint16(n))
|
||||
b.buf = append(b.buf, x...)
|
||||
}
|
||||
|
||||
func (b *writeBuf) string(s string) {
|
||||
b.buf = append(append(b.buf, s...), '\000')
|
||||
}
|
||||
|
||||
func (b *writeBuf) byte(c byte) {
|
||||
b.buf = append(b.buf, c)
|
||||
}
|
||||
|
||||
func (b *writeBuf) bytes(v []byte) {
|
||||
b.buf = append(b.buf, v...)
|
||||
}
|
||||
|
||||
func (b *writeBuf) wrap() []byte {
|
||||
p := b.buf[b.pos:]
|
||||
binary.BigEndian.PutUint32(p, uint32(len(p)))
|
||||
return b.buf
|
||||
}
|
||||
|
||||
func (b *writeBuf) next(c byte) {
|
||||
p := b.buf[b.pos:]
|
||||
binary.BigEndian.PutUint32(p, uint32(len(p)))
|
||||
b.pos = len(b.buf) + 1
|
||||
b.buf = append(b.buf, c, 0, 0, 0, 0)
|
||||
}
|
||||
16
third_party/highgo-pq/buf_test.go
vendored
Normal file
16
third_party/highgo-pq/buf_test.go
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
package pq
|
||||
|
||||
import "testing"
|
||||
|
||||
func Benchmark_writeBuf_string(b *testing.B) {
|
||||
var buf writeBuf
|
||||
const s = "foo"
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
buf.string(s)
|
||||
buf.buf = buf.buf[:0]
|
||||
}
|
||||
}
|
||||
37
third_party/highgo-pq/certs/Makefile
vendored
Normal file
37
third_party/highgo-pq/certs/Makefile
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
.PHONY: all root-ssl server-ssl client-ssl
|
||||
|
||||
# Rebuilds self-signed root/server/client certs/keys in a consistent way
|
||||
all: root-ssl server-ssl client-ssl
|
||||
rm -f .srl
|
||||
|
||||
root-ssl:
|
||||
openssl req -new -sha256 -nodes -newkey rsa:2048 \
|
||||
-config ./certs/root.cnf \
|
||||
-keyout /tmp/root.key \
|
||||
-out /tmp/root.csr
|
||||
openssl x509 -req -days 3653 -sha256 \
|
||||
-in /tmp/root.csr \
|
||||
-extfile /etc/ssl/openssl.cnf -extensions v3_ca \
|
||||
-signkey /tmp/root.key \
|
||||
-out ./certs/root.crt
|
||||
|
||||
server-ssl:
|
||||
openssl req -new -sha256 -nodes -newkey rsa:2048 \
|
||||
-config ./certs/server.cnf \
|
||||
-keyout ./certs/server.key \
|
||||
-out /tmp/server.csr
|
||||
openssl x509 -req -days 3653 -sha256 \
|
||||
-extfile ./certs/server.cnf -extensions req_ext \
|
||||
-CA ./certs/root.crt -CAkey /tmp/root.key -CAcreateserial \
|
||||
-in /tmp/server.csr \
|
||||
-out ./certs/server.crt
|
||||
|
||||
client-ssl:
|
||||
openssl req -new -sha256 -nodes -newkey rsa:2048 \
|
||||
-config ./certs/postgresql.cnf \
|
||||
-keyout ./certs/postgresql.key \
|
||||
-out /tmp/postgresql.csr
|
||||
openssl x509 -req -days 3653 -sha256 \
|
||||
-CA ./certs/root.crt -CAkey /tmp/root.key -CAcreateserial \
|
||||
-in /tmp/postgresql.csr \
|
||||
-out ./certs/postgresql.crt
|
||||
3
third_party/highgo-pq/certs/README
vendored
Normal file
3
third_party/highgo-pq/certs/README
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
This directory contains certificates and private keys for testing some
|
||||
SSL-related functionality in Travis. Do NOT use these certificates for
|
||||
anything other than testing.
|
||||
19
third_party/highgo-pq/certs/bogus_root.crt
vendored
Normal file
19
third_party/highgo-pq/certs/bogus_root.crt
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDBjCCAe6gAwIBAgIQSnDYp/Naet9HOZljF5PuwDANBgkqhkiG9w0BAQsFADAr
|
||||
MRIwEAYDVQQKEwlDb2Nrcm9hY2gxFTATBgNVBAMTDENvY2tyb2FjaCBDQTAeFw0x
|
||||
NjAyMDcxNjQ0MzdaFw0xNzAyMDYxNjQ0MzdaMCsxEjAQBgNVBAoTCUNvY2tyb2Fj
|
||||
aDEVMBMGA1UEAxMMQ29ja3JvYWNoIENBMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A
|
||||
MIIBCgKCAQEAxdln3/UdgP7ayA/G1kT7upjLe4ERwQjYQ25q0e1+vgsB5jhiirxJ
|
||||
e0+WkhhYu/mwoSAXzvlsbZ2PWFyfdanZeD/Lh6SvIeWXVVaPcWVWL1TEcoN2jr5+
|
||||
E85MMHmbbmaT2he8s6br2tM/UZxyTQ2XRprIzApbDssyw1c0Yufcpu3C6267FLEl
|
||||
IfcWrzDhnluFhthhtGXv3ToD8IuMScMC5qlKBXtKmD1B5x14ngO/ecNJ+OlEi0HU
|
||||
mavK4KWgI2rDXRZ2EnCpyTZdkc3kkRnzKcg653oOjMDRZdrhfIrha+Jq38ACsUmZ
|
||||
Su7Sp5jkIHOCO8Zg+l6GKVSq37dKMapD8wIDAQABoyYwJDAOBgNVHQ8BAf8EBAMC
|
||||
AuQwEgYDVR0TAQH/BAgwBgEB/wIBATANBgkqhkiG9w0BAQsFAAOCAQEAwZ2Tu0Yu
|
||||
rrSVdMdoPEjT1IZd+5OhM/SLzL0ddtvTithRweLHsw2lDQYlXFqr24i3UGZJQ1sp
|
||||
cqSrNwswgLUQT3vWyTjmM51HEb2vMYWKmjZ+sBQYAUP1CadrN/+OTfNGnlF1+B4w
|
||||
IXOzh7EvQmJJnNybLe4a/aRvj1NE2n8Z898B76SVU9WbfKKz8VwLzuIPDqkKcZda
|
||||
lMy5yzthyztV9YjcWs2zVOUGZvGdAhDrvZuUq6mSmxrBEvR2LBOggmVf3tGRT+Ls
|
||||
lW7c9Lrva5zLHuqmoPP07A+vuI9a0D1X44jwGDuPWJ5RnTOQ63Uez12mKNjqleHw
|
||||
DnkwNanuO8dhAA==
|
||||
-----END CERTIFICATE-----
|
||||
10
third_party/highgo-pq/certs/postgresql.cnf
vendored
Normal file
10
third_party/highgo-pq/certs/postgresql.cnf
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
[req]
|
||||
distinguished_name = req_distinguished_name
|
||||
prompt = no
|
||||
|
||||
[req_distinguished_name]
|
||||
C = US
|
||||
ST = Nevada
|
||||
L = Las Vegas
|
||||
O = github.com/lib/pq
|
||||
CN = pqgosslcert
|
||||
20
third_party/highgo-pq/certs/postgresql.crt
vendored
Normal file
20
third_party/highgo-pq/certs/postgresql.crt
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDPjCCAiYCCQD4nsC6zsmIqjANBgkqhkiG9w0BAQsFADBeMQswCQYDVQQGEwJV
|
||||
UzEPMA0GA1UECAwGTmV2YWRhMRIwEAYDVQQHDAlMYXMgVmVnYXMxGjAYBgNVBAoM
|
||||
EWdpdGh1Yi5jb20vbGliL3BxMQ4wDAYDVQQDDAVwcSBDQTAeFw0yMTA5MDIwMTU1
|
||||
MDJaFw0zMTA5MDMwMTU1MDJaMGQxCzAJBgNVBAYTAlVTMQ8wDQYDVQQIDAZOZXZh
|
||||
ZGExEjAQBgNVBAcMCUxhcyBWZWdhczEaMBgGA1UECgwRZ2l0aHViLmNvbS9saWIv
|
||||
cHExFDASBgNVBAMMC3BxZ29zc2xjZXJ0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A
|
||||
MIIBCgKCAQEAx0ucPVUNCrVmbyithwWrmmZ1dGudBwhSyDB6af4z5Cr+S6dx2SRU
|
||||
UGUw3Lv+z+tUqQ7hJj0oNddIQeYKl/Tt6JPpZsQfERP/cUGedtyt7HnCKobBL+0B
|
||||
NvHnDIUiIL4LgfiZK4DWJkGmm7nTHo/7qKAw60vCMLUW98DC0Xhlk9MHYG+e9Zai
|
||||
3G0vY2X6DUYcSmzBI3JakFEgMZTQg3ofUQMz8TYeK3/DYadLXkl08d18LL3Dnefx
|
||||
0xRuBPNTa2tLfVnFkfFi6Z9xVB/WhG6+X4OLnO85v5xUOGTV+g154iR7FOkrrl5F
|
||||
lEUBj+yaIoTRi+MyZ/oYqWwQUDYS3+Te9wIDAQABMA0GCSqGSIb3DQEBCwUAA4IB
|
||||
AQCCJpwUWCx7xfXv3vH3LQcffZycyRHYPgTCbiQw3x9aBb77jUAh5O6lEj/W0nx2
|
||||
SCTEsCsRSAiFwfUb+g/AFCW84dELRWmf38eoqACebLymqnvxyZA+O87yu07XyFZR
|
||||
TnmbDMzZgsyWWGwS3JoGFk+ibWY4AImYQnSJO8Pi0kZ37ngbAyJ3RtDhhEQJWw/Q
|
||||
D04p3uky/ea7Gyz0QTx5o40n4gq7nEzF1OS6IHozM840J5aZrxRiXEa56fsmJHmI
|
||||
IGyI07SGlWJ15r1wc8lB+8ilnAqH1QQlYzTIW0Q4NZE7n3uQg1EVuueGiGO2ex2/
|
||||
he9lDiJfOQuPuLbOxzctP9v9
|
||||
-----END CERTIFICATE-----
|
||||
28
third_party/highgo-pq/certs/postgresql.key
vendored
Normal file
28
third_party/highgo-pq/certs/postgresql.key
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDHS5w9VQ0KtWZv
|
||||
KK2HBauaZnV0a50HCFLIMHpp/jPkKv5Lp3HZJFRQZTDcu/7P61SpDuEmPSg110hB
|
||||
5gqX9O3ok+lmxB8RE/9xQZ523K3secIqhsEv7QE28ecMhSIgvguB+JkrgNYmQaab
|
||||
udMej/uooDDrS8IwtRb3wMLReGWT0wdgb571lqLcbS9jZfoNRhxKbMEjclqQUSAx
|
||||
lNCDeh9RAzPxNh4rf8Nhp0teSXTx3XwsvcOd5/HTFG4E81Nra0t9WcWR8WLpn3FU
|
||||
H9aEbr5fg4uc7zm/nFQ4ZNX6DXniJHsU6SuuXkWURQGP7JoihNGL4zJn+hipbBBQ
|
||||
NhLf5N73AgMBAAECggEAHLNY1sRO0oH5NHzpMI6yfdPPimqM/JxIP6grmOQQ2QUQ
|
||||
BhkhHiJLOiC4frFcKtk7IfWQmw8noUlVkJfuYp/VOy9B55jK2IzGtqq6hWeWbH3E
|
||||
Zpdtbtd021LO8VCi75Au3BLPDCLLtEq0Ea0bKEWX+lrHcLtCRf1uR1OtOrlZ94Wl
|
||||
DUhm7YJC4cS1bi6Kdf03R+fw2oFi7/QdywcT4ow032jGWOly/Jl7bSHZK7xLtM/i
|
||||
9HfMwmusD/iuz7mtLU7VCpnlKZm6MfS5D427ybW8MruuiZEtQJ6QtRIrHBHk93aK
|
||||
Op0tjJ6tMav1UsJzgVz9+uWILE9l0AjAa4AvbfNzEQKBgQD8mma9SLQPtBb6cXuT
|
||||
CQgjE4vyph8mRnm/pTz3QLIpMiLy2+aKJD/u4cduzLw1vjuH1tlb7NQ9c891jAJh
|
||||
JhwDwqKAXfFicfRs/PYWngx/XtGhbbpgm1yA6XuYL1D06gzmjzXgHvZMOFcts+GF
|
||||
y0JEuV7v6eYrpQJRQYCwY6xTgwKBgQDJ+bHAlgOaC94DZEXZMiUznCCjBjAstiXG
|
||||
BEN7Cnfn6vgvPm/b6BkKn4VrsCmbZQKT7QJDSOhYwXCC2ZlrKiF8GEUHX4mi8347
|
||||
8B+DsuokTLNmN61QAZbb1c3XQVnr15xH8ijm7yYs4tCBmVLKBmpw1T4IZXXlVE5k
|
||||
gmee+AwIfQKBgGr+P0wnclVAc4cq8CusZKzux5VEtebxbPo21CbqWUxHtzPk3rZe
|
||||
elIFggK1Z3bgF7kG0NQ18QQCfLoOTqe1i6IwG8KBiA+pst1DHD0iPqroj6RvpMTs
|
||||
qXbU7ovcZs8GH+a8fBZtJufL6WkrSvfvyybu2X6HNP4Bi4S9WPPdlA1fAoGAE5m/
|
||||
vkjQoKp2KS4Z+TH8mj2UjT2Uf0JN+CGByvcBG+iZnTwZ7uVfSMCiWgkGgKYU0fY2
|
||||
OgFhSvu6x3gGg3fbOAfC6yxCVyX6IibzZ/x87HjlEA5nK1R8J2lgSHt3FoQeDn1Z
|
||||
qs+ajNCWG32doy1sNvb6xiXSgybjVK2zEKJRyKECgYBJTk2IABebjvInNb6tagcI
|
||||
nD4d2LgBmZJZsTruHXrpO0s3XCQcFKks4JKH1CVjd34f7LkxzEOGbE7wKBBd652s
|
||||
ob6gFKnbqTniTo3NRUycB6ymo4LSaBvKgeY5hYbVxrYheRLPGY+gPVYb3VMKu9N9
|
||||
76rcaFqJOz7OeywRG5bHUg==
|
||||
-----END PRIVATE KEY-----
|
||||
10
third_party/highgo-pq/certs/root.cnf
vendored
Normal file
10
third_party/highgo-pq/certs/root.cnf
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
[req]
|
||||
distinguished_name = req_distinguished_name
|
||||
prompt = no
|
||||
|
||||
[req_distinguished_name]
|
||||
C = US
|
||||
ST = Nevada
|
||||
L = Las Vegas
|
||||
O = github.com/lib/pq
|
||||
CN = pq CA
|
||||
24
third_party/highgo-pq/certs/root.crt
vendored
Normal file
24
third_party/highgo-pq/certs/root.crt
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIEBjCCAu6gAwIBAgIJAPizR+OD14YnMA0GCSqGSIb3DQEBCwUAMF4xCzAJBgNV
|
||||
BAYTAlVTMQ8wDQYDVQQIDAZOZXZhZGExEjAQBgNVBAcMCUxhcyBWZWdhczEaMBgG
|
||||
A1UECgwRZ2l0aHViLmNvbS9saWIvcHExDjAMBgNVBAMMBXBxIENBMB4XDTIxMDkw
|
||||
MjAxNTUwMloXDTMxMDkwMzAxNTUwMlowXjELMAkGA1UEBhMCVVMxDzANBgNVBAgM
|
||||
Bk5ldmFkYTESMBAGA1UEBwwJTGFzIFZlZ2FzMRowGAYDVQQKDBFnaXRodWIuY29t
|
||||
L2xpYi9wcTEOMAwGA1UEAwwFcHEgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw
|
||||
ggEKAoIBAQDb9d6sjdU6GdibGrXRMOHREH3MRUS8T4TFqGgPEGVDP/V5bAZlBSGP
|
||||
AN0o9DTyVLcbQpBt8zMTw9KeIzIIe5NIVkSmA16lw/YckGhOM+kZIkiDuE6qt5Ia
|
||||
OQCRMdXkZ8ejG/JUu+rHU8FJZL8DE+jyYherzdjkeVAQ7JfzxAwW2Dl7T/47g337
|
||||
Pwmf17AEb8ibSqmXyUN7R5NhJQs+hvaYdNagzdx91E1H+qlyBvmiNeasUQljLvZ+
|
||||
Y8wAuU79neA+d09O4PBiYwV17rSP6SZCeGE3oLZviL/0KM9Xig88oB+2FmvQ6Zxa
|
||||
L7SoBlqS+5pBZwpH7eee/wCIKAnJtMAJAgMBAAGjgcYwgcMwDwYDVR0TAQH/BAUw
|
||||
AwEB/zAdBgNVHQ4EFgQUfIXEczahbcM2cFrwclJF7GbdajkwgZAGA1UdIwSBiDCB
|
||||
hYAUfIXEczahbcM2cFrwclJF7GbdajmhYqRgMF4xCzAJBgNVBAYTAlVTMQ8wDQYD
|
||||
VQQIDAZOZXZhZGExEjAQBgNVBAcMCUxhcyBWZWdhczEaMBgGA1UECgwRZ2l0aHVi
|
||||
LmNvbS9saWIvcHExDjAMBgNVBAMMBXBxIENBggkA+LNH44PXhicwDQYJKoZIhvcN
|
||||
AQELBQADggEBABFyGgSz2mHVJqYgX1Y+7P+MfKt83cV2uYDGYvXrLG2OGiCilVul
|
||||
oTBG+8omIMSHOsQZvWMpA5H0tnnlQHrKpKpUyKkSL+Wv5GL0UtBmHX7mVRiaK2l4
|
||||
q2BjRaQUitp/FH4NSdXtVrMME5T1JBBZHsQkNL3cNRzRKwY/Vj5UGEDxDS7lILUC
|
||||
e01L4oaK0iKQn4beALU+TvKoAHdPvoxpPpnhkF5ss9HmdcvRktJrKZemDJZswZ7/
|
||||
+omx8ZPIYYUH5VJJYYE88S7guAt+ZaKIUlel/t6xPbo2ZySFSg9u1uB99n+jTo3L
|
||||
1rAxFnN3FCX2jBqgP29xMVmisaN5k04UmyI=
|
||||
-----END CERTIFICATE-----
|
||||
29
third_party/highgo-pq/certs/server.cnf
vendored
Normal file
29
third_party/highgo-pq/certs/server.cnf
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
[ req ]
|
||||
default_bits = 2048
|
||||
distinguished_name = subject
|
||||
req_extensions = req_ext
|
||||
x509_extensions = x509_ext
|
||||
string_mask = utf8only
|
||||
prompt = no
|
||||
|
||||
[ subject ]
|
||||
C = US
|
||||
ST = Nevada
|
||||
L = Las Vegas
|
||||
O = github.com/lib/pq
|
||||
|
||||
[ x509_ext ]
|
||||
subjectKeyIdentifier = hash
|
||||
authorityKeyIdentifier = keyid,issuer
|
||||
|
||||
basicConstraints = CA:FALSE
|
||||
keyUsage = digitalSignature, keyEncipherment
|
||||
subjectAltName = DNS:postgres
|
||||
nsComment = "OpenSSL Generated Certificate"
|
||||
|
||||
[ req_ext ]
|
||||
subjectKeyIdentifier = hash
|
||||
basicConstraints = CA:FALSE
|
||||
keyUsage = digitalSignature, keyEncipherment
|
||||
subjectAltName = DNS:postgres
|
||||
nsComment = "OpenSSL Generated Certificate"
|
||||
22
third_party/highgo-pq/certs/server.crt
vendored
Normal file
22
third_party/highgo-pq/certs/server.crt
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDqzCCApOgAwIBAgIJAPiewLrOyYipMA0GCSqGSIb3DQEBCwUAMF4xCzAJBgNV
|
||||
BAYTAlVTMQ8wDQYDVQQIDAZOZXZhZGExEjAQBgNVBAcMCUxhcyBWZWdhczEaMBgG
|
||||
A1UECgwRZ2l0aHViLmNvbS9saWIvcHExDjAMBgNVBAMMBXBxIENBMB4XDTIxMDkw
|
||||
MjAxNTUwMloXDTMxMDkwMzAxNTUwMlowTjELMAkGA1UEBhMCVVMxDzANBgNVBAgM
|
||||
Bk5ldmFkYTESMBAGA1UEBwwJTGFzIFZlZ2FzMRowGAYDVQQKDBFnaXRodWIuY29t
|
||||
L2xpYi9wcTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKf6H4UzmANN
|
||||
QiQJe92Mf3ETMYmpZKNNO9DPEHyNLIkag+XwMrBTdcCK0mLvsNCYpXuBN6703KCd
|
||||
WAFOeMmj7gOsWtvjt5Xm6bRHLgegekXzcG/jDwq/wyzeDzr/YkITuIlG44Lf9lhY
|
||||
FLwiHlHOWHnwrZaEh6aU//02aQkzyX5INeXl/3TZm2G2eIH6AOxOKOU27MUsyVSQ
|
||||
5DE+SDKGcRP4bElueeQWvxAXNMZYb7sVSDdfHI3zr32K4k/tC8x0fZJ5XN/dvl4t
|
||||
4N4MrYlmDO5XOrb/gQH1H4iu6+5EMDfZYab4fkThnNFdfFqu4/8Scv7KZ8mWqpKM
|
||||
fGAjEPctQi0CAwEAAaN8MHowHQYDVR0OBBYEFENExPbmDyFB2AJUdbMvVyhlNPD5
|
||||
MAkGA1UdEwQCMAAwCwYDVR0PBAQDAgWgMBMGA1UdEQQMMAqCCHBvc3RncmVzMCwG
|
||||
CWCGSAGG+EIBDQQfFh1PcGVuU1NMIEdlbmVyYXRlZCBDZXJ0aWZpY2F0ZTANBgkq
|
||||
hkiG9w0BAQsFAAOCAQEAMRVbV8RiEsmp9HAtnVCZmRXMIbgPGrqjeSwk586s4K8v
|
||||
BSqNCqxv6s5GfCRmDYiqSqeuCVDtUJS1HsTmbxVV7Ke71WMo+xHR1ICGKOa8WGCb
|
||||
TGsuicG5QZXWaxeMOg4s0qpKmKko0d1aErdVsanU5dkrVS7D6729Ffnzu4lwApk6
|
||||
invAB67p8u7sojwqRq5ce0vRaG+YFylTrWomF9kauEb8gKbQ9Xc7QfX+h+UH/mq9
|
||||
Nvdj8LOHp6/82bZdnsYUOtV4lS1IA/qzeXpqBphxqfWabD1yLtkyJyImZKq8uIPp
|
||||
0CG4jhObPdWcCkXD6bg3QK3mhwlC79OtFgxWmldCRQ==
|
||||
-----END CERTIFICATE-----
|
||||
28
third_party/highgo-pq/certs/server.key
vendored
Normal file
28
third_party/highgo-pq/certs/server.key
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCn+h+FM5gDTUIk
|
||||
CXvdjH9xEzGJqWSjTTvQzxB8jSyJGoPl8DKwU3XAitJi77DQmKV7gTeu9NygnVgB
|
||||
TnjJo+4DrFrb47eV5um0Ry4HoHpF83Bv4w8Kv8Ms3g86/2JCE7iJRuOC3/ZYWBS8
|
||||
Ih5Rzlh58K2WhIemlP/9NmkJM8l+SDXl5f902ZthtniB+gDsTijlNuzFLMlUkOQx
|
||||
PkgyhnET+GxJbnnkFr8QFzTGWG+7FUg3XxyN8699iuJP7QvMdH2SeVzf3b5eLeDe
|
||||
DK2JZgzuVzq2/4EB9R+IruvuRDA32WGm+H5E4ZzRXXxaruP/EnL+ymfJlqqSjHxg
|
||||
IxD3LUItAgMBAAECggEAOE2naQ9tIZYw2EFxikZApVcooJrtx6ropMnzHbx4NBB2
|
||||
K4mChAXFj184u77ZxmGT/jzGvFcI6LE0wWNbK0NOUV7hKZk/fPhkV3AQZrAMrAu4
|
||||
IVi7PwAd3JkmA8F8XuebUDA5rDGDsgL8GD9baFJA58abeLs9eMGyuF4XgOUh4bip
|
||||
hgHa76O2rcDWNY5HZqqRslw75FzlYkB0PCts/UJxSswj70kTTihyOhDlrm2TnyxI
|
||||
ne54UbGRrpfs9wiheSGLjDG81qZToBHQDwoAnjjZhu1VCaBISuGbgZrxyyRyqdnn
|
||||
xPW+KczMv04XyvF7v6Pz+bUEppalLXGiXnH5UtWvZQKBgQDTPCdMpNE/hwlq4nAw
|
||||
Kf42zIBWfbnMLVWYoeDiAOhtl9XAUAXn76xe6Rvo0qeAo67yejdbJfRq3HvGyw+q
|
||||
4PS8r9gXYmLYIPQxSoLL5+rFoBCN3qFippfjLB1j32mp7+15KjRj8FF2r6xIN8fu
|
||||
XatSRsaqmvCWYLDRv/rbHnxwkwKBgQDLkyfFLF7BtwtPWKdqrwOM7ip1UKh+oDBS
|
||||
vkCQ08aEFRBU7T3jChsx5GbaW6zmsSBwBwcrHclpSkz7n3aq19DDWObJR2p80Fma
|
||||
rsXeIcvtEpkvT3pVX268P5d+XGs1kxgFunqTysG9yChW+xzcs5MdKBzuMPPn7rL8
|
||||
MKAzdar6PwKBgEypkzW8x3h/4Moa3k6MnwdyVs2NGaZheaRIc95yJ+jGZzxBjrMr
|
||||
h+p2PbvU4BfO0AqOkpKRBtDVrlJqlggVVp04UHvEKE16QEW3Xhr0037f5cInX3j3
|
||||
Lz6yXwRFLAsR2aTUzWjL6jTh8uvO2s/GzQuyRh3a16Ar/WBShY+K0+zjAoGATnLT
|
||||
xZjWnyHRmu8X/PWakamJ9RFzDPDgDlLAgM8LVgTj+UY/LgnL9wsEU6s2UuP5ExKy
|
||||
QXxGDGwUhHar/SQTj+Pnc7Mwpw6HKSOmnnY5po8fNusSwml3O9XppEkrC0c236Y/
|
||||
7EobJO5IFVTJh4cv7vFxTJzSsRL8KFD4uzvh+nMCgYEAqY8NBYtIgNJA2B6C6hHF
|
||||
+bG7v46434ZHFfGTmMQwzE4taVg7YRnzYESAlvK4bAP5ZXR90n7GRGFhrXzoMZ38
|
||||
r0bw/q9rV+ReGda7/Bjf7ciCKiq0RODcHtf4IaskjPXCoQRGJtgCPLhWPfld6g9v
|
||||
/HTvO96xv9e3eG/PKSPog94=
|
||||
-----END PRIVATE KEY-----
|
||||
2395
third_party/highgo-pq/conn.go
vendored
Normal file
2395
third_party/highgo-pq/conn.go
vendored
Normal file
File diff suppressed because it is too large
Load Diff
8
third_party/highgo-pq/conn_go115.go
vendored
Normal file
8
third_party/highgo-pq/conn_go115.go
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build go1.15
|
||||
// +build go1.15
|
||||
|
||||
package pq
|
||||
|
||||
import "database/sql/driver"
|
||||
|
||||
var _ driver.Validator = &conn{}
|
||||
261
third_party/highgo-pq/conn_go18.go
vendored
Normal file
261
third_party/highgo-pq/conn_go18.go
vendored
Normal file
@@ -0,0 +1,261 @@
|
||||
package pq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
watchCancelDialContextTimeout = time.Second * 10
|
||||
)
|
||||
|
||||
// Implement the "QueryerContext" interface
|
||||
func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
||||
list := make([]driver.Value, len(args))
|
||||
namedValueMap := map[string]int{}
|
||||
for i, nv := range args {
|
||||
list[i] = nv.Value
|
||||
if nv.Name != "" {
|
||||
namedValueMap[nv.Name] = i + 1
|
||||
}
|
||||
}
|
||||
|
||||
if cn.isProcedure(list) {
|
||||
return cn.QueryProcedure(query, list)
|
||||
}
|
||||
|
||||
finish := cn.watchCancel(ctx)
|
||||
r, err := cn.query(query, list)
|
||||
if err != nil {
|
||||
if finish != nil {
|
||||
finish()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
r.finish = finish
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// Implement the "ExecerContext" interface
|
||||
func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
list := make([]driver.Value, len(args))
|
||||
namedValueMap := map[string]int{}
|
||||
for i, nv := range args {
|
||||
list[i] = nv.Value
|
||||
if nv.Name != "" {
|
||||
namedValueMap[nv.Name] = i + 1
|
||||
}
|
||||
}
|
||||
|
||||
if finish := cn.watchCancel(ctx); finish != nil {
|
||||
defer finish()
|
||||
}
|
||||
|
||||
return cn.Exec(query, list)
|
||||
}
|
||||
|
||||
// Implement the "ConnPrepareContext" interface
|
||||
func (cn *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
||||
if finish := cn.watchCancel(ctx); finish != nil {
|
||||
defer finish()
|
||||
}
|
||||
|
||||
return cn.Prepare(query)
|
||||
}
|
||||
|
||||
// Implement the "ConnBeginTx" interface
|
||||
func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||
var mode string
|
||||
|
||||
switch sql.IsolationLevel(opts.Isolation) {
|
||||
case sql.LevelDefault:
|
||||
// Don't touch mode: use the server's default
|
||||
case sql.LevelReadUncommitted:
|
||||
mode = " ISOLATION LEVEL READ UNCOMMITTED"
|
||||
case sql.LevelReadCommitted:
|
||||
mode = " ISOLATION LEVEL READ COMMITTED"
|
||||
case sql.LevelRepeatableRead:
|
||||
mode = " ISOLATION LEVEL REPEATABLE READ"
|
||||
case sql.LevelSerializable:
|
||||
mode = " ISOLATION LEVEL SERIALIZABLE"
|
||||
default:
|
||||
return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation)
|
||||
}
|
||||
|
||||
if opts.ReadOnly {
|
||||
mode += " READ ONLY"
|
||||
} else {
|
||||
mode += " READ WRITE"
|
||||
}
|
||||
|
||||
tx, err := cn.begin(mode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cn.txnFinish = cn.watchCancel(ctx)
|
||||
return tx, nil
|
||||
}
|
||||
|
||||
func (cn *conn) Ping(ctx context.Context) error {
|
||||
if finish := cn.watchCancel(ctx); finish != nil {
|
||||
defer finish()
|
||||
}
|
||||
rows, err := cn.simpleQuery(";")
|
||||
if err != nil {
|
||||
return driver.ErrBadConn // https://golang.org/pkg/database/sql/driver/#Pinger
|
||||
}
|
||||
rows.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cn *conn) watchCancel(ctx context.Context) func() {
|
||||
if done := ctx.Done(); done != nil {
|
||||
finished := make(chan struct{}, 1)
|
||||
go func() {
|
||||
select {
|
||||
case <-done:
|
||||
select {
|
||||
case finished <- struct{}{}:
|
||||
default:
|
||||
// We raced with the finish func, let the next query handle this with the
|
||||
// context.
|
||||
return
|
||||
}
|
||||
|
||||
// Set the connection state to bad so it does not get reused.
|
||||
cn.err.set(ctx.Err())
|
||||
|
||||
// At this point the function level context is canceled,
|
||||
// so it must not be used for the additional network
|
||||
// request to cancel the query.
|
||||
// Create a new context to pass into the dial.
|
||||
ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout)
|
||||
defer cancel()
|
||||
|
||||
_ = cn.cancel(ctxCancel)
|
||||
case <-finished:
|
||||
}
|
||||
}()
|
||||
return func() {
|
||||
select {
|
||||
case <-finished:
|
||||
cn.err.set(ctx.Err())
|
||||
cn.Close()
|
||||
case finished <- struct{}{}:
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cn *conn) cancel(ctx context.Context) error {
|
||||
// Create a new values map (copy). This makes sure the connection created
|
||||
// in this method cannot write to the same underlying data, which could
|
||||
// cause a concurrent map write panic. This is necessary because cancel
|
||||
// is called from a goroutine in watchCancel.
|
||||
o := make(values)
|
||||
for k, v := range cn.opts {
|
||||
o[k] = v
|
||||
}
|
||||
|
||||
c, err := dial(ctx, cn.dialer, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
{
|
||||
can := conn{
|
||||
c: c,
|
||||
}
|
||||
err = can.ssl(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w := can.writeBuf(0)
|
||||
w.int32(80877102) // cancel request code
|
||||
w.int32(cn.processID)
|
||||
w.int32(cn.secretKey)
|
||||
|
||||
if err := can.sendStartupPacket(w); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Read until EOF to ensure that the server received the cancel.
|
||||
{
|
||||
_, err := io.Copy(ioutil.Discard, c)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Implement the "StmtQueryContext" interface
|
||||
func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
list := make([]driver.Value, len(args))
|
||||
for i, nv := range args {
|
||||
list[i] = nv.Value
|
||||
}
|
||||
finish := st.watchCancel(ctx)
|
||||
r, err := st.query(list)
|
||||
if err != nil {
|
||||
if finish != nil {
|
||||
finish()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
r.finish = finish
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// Implement the "StmtExecContext" interface
|
||||
func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
list := make([]driver.Value, len(args))
|
||||
for i, nv := range args {
|
||||
list[i] = nv.Value
|
||||
}
|
||||
|
||||
if finish := st.watchCancel(ctx); finish != nil {
|
||||
defer finish()
|
||||
}
|
||||
|
||||
return st.Exec(list)
|
||||
}
|
||||
|
||||
// watchCancel is implemented on stmt in order to not mark the parent conn as bad
|
||||
func (st *stmt) watchCancel(ctx context.Context) func() {
|
||||
if done := ctx.Done(); done != nil {
|
||||
finished := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-done:
|
||||
// At this point the function level context is canceled,
|
||||
// so it must not be used for the additional network
|
||||
// request to cancel the query.
|
||||
// Create a new context to pass into the dial.
|
||||
ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout)
|
||||
defer cancel()
|
||||
|
||||
_ = st.cancel(ctxCancel)
|
||||
finished <- struct{}{}
|
||||
case <-finished:
|
||||
}
|
||||
}()
|
||||
return func() {
|
||||
select {
|
||||
case <-finished:
|
||||
case finished <- struct{}{}:
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *stmt) cancel(ctx context.Context) error {
|
||||
return st.cn.cancel(ctx)
|
||||
}
|
||||
45
third_party/highgo-pq/conn_go19.go
vendored
Normal file
45
third_party/highgo-pq/conn_go19.go
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
//go:build go1.9
|
||||
// +build go1.9
|
||||
|
||||
package pq
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
var _ driver.NamedValueChecker = (*conn)(nil)
|
||||
|
||||
func (c *conn) CheckNamedValue(nv *driver.NamedValue) error {
|
||||
if _, ok := nv.Value.(driver.Valuer); ok {
|
||||
// Ignore Valuer, for backward compatibility with pq.Array().
|
||||
return driver.ErrSkip
|
||||
}
|
||||
|
||||
// Ignoring []byte / []uint8.
|
||||
if _, ok := nv.Value.([]uint8); ok {
|
||||
return driver.ErrSkip
|
||||
}
|
||||
|
||||
v := reflect.ValueOf(nv.Value)
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
if v.Kind() == reflect.Slice {
|
||||
var err error
|
||||
nv.Value, err = Array(v.Interface()).Value()
|
||||
return err
|
||||
}
|
||||
if v.Kind() == reflect.Struct {
|
||||
var err error
|
||||
switch nv.Value.(type) {
|
||||
case sql.Out:
|
||||
return err
|
||||
default:
|
||||
return driver.ErrSkip
|
||||
}
|
||||
}
|
||||
|
||||
return driver.ErrSkip
|
||||
}
|
||||
83
third_party/highgo-pq/conn_go19_test.go
vendored
Normal file
83
third_party/highgo-pq/conn_go19_test.go
vendored
Normal file
@@ -0,0 +1,83 @@
|
||||
//go:build go1.9
|
||||
// +build go1.9
|
||||
|
||||
package pq
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestArrayArg(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
for _, tc := range []struct {
|
||||
pgType string
|
||||
in, out interface{}
|
||||
}{
|
||||
{
|
||||
pgType: "int[]",
|
||||
in: []int{245, 231},
|
||||
out: []int64{245, 231},
|
||||
},
|
||||
{
|
||||
pgType: "int[]",
|
||||
in: &[]int{245, 231},
|
||||
out: []int64{245, 231},
|
||||
},
|
||||
{
|
||||
pgType: "int[]",
|
||||
in: []int64{245, 231},
|
||||
},
|
||||
{
|
||||
pgType: "int[]",
|
||||
in: &[]int64{245, 231},
|
||||
out: []int64{245, 231},
|
||||
},
|
||||
{
|
||||
pgType: "varchar[]",
|
||||
in: []string{"hello", "world"},
|
||||
},
|
||||
{
|
||||
pgType: "varchar[]",
|
||||
in: &[]string{"hello", "world"},
|
||||
out: []string{"hello", "world"},
|
||||
},
|
||||
} {
|
||||
if tc.out == nil {
|
||||
tc.out = tc.in
|
||||
}
|
||||
t.Run(fmt.Sprintf("%#v", tc.in), func(t *testing.T) {
|
||||
r, err := db.Query(fmt.Sprintf("SELECT $1::%s", tc.pgType), tc.in)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
if !r.Next() {
|
||||
if r.Err() != nil {
|
||||
t.Fatal(r.Err())
|
||||
}
|
||||
t.Fatal("expected row")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r.Next() {
|
||||
t.Fatal("unexpected row")
|
||||
}
|
||||
}()
|
||||
|
||||
got := reflect.New(reflect.TypeOf(tc.out))
|
||||
if err := r.Scan(Array(got.Interface())); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tc.out, got.Elem().Interface()) {
|
||||
t.Errorf("got %v, want %v", got, tc.out)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
1973
third_party/highgo-pq/conn_test.go
vendored
Normal file
1973
third_party/highgo-pq/conn_test.go
vendored
Normal file
File diff suppressed because it is too large
Load Diff
120
third_party/highgo-pq/connector.go
vendored
Normal file
120
third_party/highgo-pq/connector.go
vendored
Normal file
@@ -0,0 +1,120 @@
|
||||
package pq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Connector represents a fixed configuration for the pq driver with a given
|
||||
// name. Connector satisfies the database/sql/driver Connector interface and
|
||||
// can be used to create any number of DB Conn's via the database/sql OpenDB
|
||||
// function.
|
||||
//
|
||||
// See https://golang.org/pkg/database/sql/driver/#Connector.
|
||||
// See https://golang.org/pkg/database/sql/#OpenDB.
|
||||
type Connector struct {
|
||||
opts values
|
||||
dialer Dialer
|
||||
}
|
||||
|
||||
// Connect returns a connection to the database using the fixed configuration
|
||||
// of this Connector. Context is not used.
|
||||
func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||
return c.open(ctx)
|
||||
}
|
||||
|
||||
// Dialer allows change the dialer used to open connections.
|
||||
func (c *Connector) Dialer(dialer Dialer) {
|
||||
c.dialer = dialer
|
||||
}
|
||||
|
||||
// Driver returns the underlying driver of this Connector.
|
||||
func (c *Connector) Driver() driver.Driver {
|
||||
return &Driver{}
|
||||
}
|
||||
|
||||
// NewConnector returns a connector for the pq driver in a fixed configuration
|
||||
// with the given dsn. The returned connector can be used to create any number
|
||||
// of equivalent Conn's. The returned connector is intended to be used with
|
||||
// database/sql.OpenDB.
|
||||
//
|
||||
// See https://golang.org/pkg/database/sql/driver/#Connector.
|
||||
// See https://golang.org/pkg/database/sql/#OpenDB.
|
||||
func NewConnector(dsn string) (*Connector, error) {
|
||||
var err error
|
||||
o := make(values)
|
||||
|
||||
// A number of defaults are applied here, in this order:
|
||||
//
|
||||
// * Very low precedence defaults applied in every situation
|
||||
// * Environment variables
|
||||
// * Explicitly passed connection information
|
||||
o["host"] = "localhost"
|
||||
o["port"] = "5432"
|
||||
// N.B.: Extra float digits should be set to 3, but that breaks
|
||||
// Postgres 8.4 and older, where the max is 2.
|
||||
o["extra_float_digits"] = "2"
|
||||
for k, v := range parseEnviron(os.Environ()) {
|
||||
o[k] = v
|
||||
}
|
||||
|
||||
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
|
||||
dsn, err = ParseURL(dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := parseOpts(dsn, o); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use the "fallback" application name if necessary
|
||||
if fallback, ok := o["fallback_application_name"]; ok {
|
||||
if _, ok := o["application_name"]; !ok {
|
||||
o["application_name"] = fallback
|
||||
}
|
||||
}
|
||||
|
||||
// We can't work with any client_encoding other than UTF-8 currently.
|
||||
// However, we have historically allowed the user to set it to UTF-8
|
||||
// explicitly, and there's no reason to break such programs, so allow that.
|
||||
// Note that the "options" setting could also set client_encoding, but
|
||||
// parsing its value is not worth it. Instead, we always explicitly send
|
||||
// client_encoding as a separate run-time parameter, which should override
|
||||
// anything set in options.
|
||||
if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) {
|
||||
return nil, errors.New("client_encoding must be absent or 'UTF8'")
|
||||
}
|
||||
o["client_encoding"] = "UTF8"
|
||||
// DateStyle needs a similar treatment.
|
||||
if datestyle, ok := o["datestyle"]; ok {
|
||||
if datestyle != "ISO, MDY" {
|
||||
return nil, fmt.Errorf("setting datestyle must be absent or %v; got %v", "ISO, MDY", datestyle)
|
||||
}
|
||||
} else {
|
||||
o["datestyle"] = "ISO, MDY"
|
||||
}
|
||||
|
||||
// If a user is not provided by any other means, the last
|
||||
// resort is to use the current operating system provided user
|
||||
// name.
|
||||
if _, ok := o["user"]; !ok {
|
||||
u, err := userCurrent()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o["user"] = u
|
||||
}
|
||||
|
||||
// SSL is not necessary or supported over UNIX domain sockets
|
||||
if network, _ := network(o); network == "unix" {
|
||||
o["sslmode"] = "disable"
|
||||
}
|
||||
|
||||
return &Connector{opts: o, dialer: defaultDialer{}}, nil
|
||||
}
|
||||
30
third_party/highgo-pq/connector_example_test.go
vendored
Normal file
30
third_party/highgo-pq/connector_example_test.go
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
//go:build go1.10
|
||||
// +build go1.10
|
||||
|
||||
package pq_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
func ExampleNewConnector() {
|
||||
name := ""
|
||||
connector, err := pq.NewConnector(name)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
db := sql.OpenDB(connector)
|
||||
defer db.Close()
|
||||
|
||||
// Use the DB
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
txn.Rollback()
|
||||
}
|
||||
87
third_party/highgo-pq/connector_test.go
vendored
Normal file
87
third_party/highgo-pq/connector_test.go
vendored
Normal file
@@ -0,0 +1,87 @@
|
||||
//go:build go1.10
|
||||
// +build go1.10
|
||||
|
||||
package pq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewConnector_WorksWithOpenDB(t *testing.T) {
|
||||
name := ""
|
||||
c, err := NewConnector(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
db := sql.OpenDB(c)
|
||||
defer db.Close()
|
||||
// database/sql might not call our Open at all unless we do something with
|
||||
// the connection
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
txn.Rollback()
|
||||
}
|
||||
|
||||
func TestNewConnector_Connect(t *testing.T) {
|
||||
name := ""
|
||||
c, err := NewConnector(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
db, err := c.Connect(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
// database/sql might not call our Open at all unless we do something with
|
||||
// the connection
|
||||
txn, err := db.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
txn.Rollback()
|
||||
}
|
||||
|
||||
func TestNewConnector_Driver(t *testing.T) {
|
||||
name := ""
|
||||
c, err := NewConnector(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
db, err := c.Driver().Open(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
// database/sql might not call our Open at all unless we do something with
|
||||
// the connection
|
||||
txn, err := db.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
txn.Rollback()
|
||||
}
|
||||
|
||||
func TestNewConnector_Environ(t *testing.T) {
|
||||
name := ""
|
||||
os.Setenv("PGPASSFILE", "/tmp/.pgpass")
|
||||
defer os.Unsetenv("PGPASSFILE")
|
||||
c, err := NewConnector(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for key, expected := range map[string]string{
|
||||
"passfile": "/tmp/.pgpass",
|
||||
} {
|
||||
if got := c.opts[key]; got != expected {
|
||||
t.Fatalf("Getting values from environment variables, for %v expected %s got %s", key, expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
348
third_party/highgo-pq/copy.go
vendored
Normal file
348
third_party/highgo-pq/copy.go
vendored
Normal file
@@ -0,0 +1,348 @@
|
||||
package pq
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
errCopyInClosed = errors.New("pq: copyin statement has already been closed")
|
||||
errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY")
|
||||
errCopyToNotSupported = errors.New("pq: COPY TO is not supported")
|
||||
errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction")
|
||||
errCopyInProgress = errors.New("pq: COPY in progress")
|
||||
)
|
||||
|
||||
// CopyIn creates a COPY FROM statement which can be prepared with
|
||||
// Tx.Prepare(). The target table should be visible in search_path.
|
||||
func CopyIn(table string, columns ...string) string {
|
||||
buffer := bytes.NewBufferString("COPY ")
|
||||
BufferQuoteIdentifier(table, buffer)
|
||||
buffer.WriteString(" (")
|
||||
makeStmt(buffer, columns...)
|
||||
return buffer.String()
|
||||
}
|
||||
|
||||
// MakeStmt makes the stmt string for CopyIn and CopyInSchema.
|
||||
func makeStmt(buffer *bytes.Buffer, columns ...string) {
|
||||
//s := bytes.NewBufferString()
|
||||
for i, col := range columns {
|
||||
if i != 0 {
|
||||
buffer.WriteString(", ")
|
||||
}
|
||||
BufferQuoteIdentifier(col, buffer)
|
||||
}
|
||||
buffer.WriteString(") FROM STDIN")
|
||||
}
|
||||
|
||||
// CopyInSchema creates a COPY FROM statement which can be prepared with
|
||||
// Tx.Prepare().
|
||||
func CopyInSchema(schema, table string, columns ...string) string {
|
||||
buffer := bytes.NewBufferString("COPY ")
|
||||
BufferQuoteIdentifier(schema, buffer)
|
||||
buffer.WriteRune('.')
|
||||
BufferQuoteIdentifier(table, buffer)
|
||||
buffer.WriteString(" (")
|
||||
makeStmt(buffer, columns...)
|
||||
return buffer.String()
|
||||
}
|
||||
|
||||
type copyin struct {
|
||||
cn *conn
|
||||
buffer []byte
|
||||
rowData chan []byte
|
||||
done chan bool
|
||||
|
||||
closed bool
|
||||
|
||||
mu struct {
|
||||
sync.Mutex
|
||||
err error
|
||||
driver.Result
|
||||
}
|
||||
}
|
||||
|
||||
const ciBufferSize = 64 * 1024
|
||||
|
||||
// flush buffer before the buffer is filled up and needs reallocation
|
||||
const ciBufferFlushSize = 63 * 1024
|
||||
|
||||
func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) {
|
||||
if !cn.isInTransaction() {
|
||||
return nil, errCopyNotSupportedOutsideTxn
|
||||
}
|
||||
|
||||
ci := ©in{
|
||||
cn: cn,
|
||||
buffer: make([]byte, 0, ciBufferSize),
|
||||
rowData: make(chan []byte),
|
||||
done: make(chan bool, 1),
|
||||
}
|
||||
// add CopyData identifier + 4 bytes for message length
|
||||
ci.buffer = append(ci.buffer, 'd', 0, 0, 0, 0)
|
||||
|
||||
b := cn.writeBuf('Q')
|
||||
b.string(q)
|
||||
cn.send(b)
|
||||
|
||||
awaitCopyInResponse:
|
||||
for {
|
||||
t, r := cn.recv1()
|
||||
switch t {
|
||||
case 'G':
|
||||
if r.byte() != 0 {
|
||||
err = errBinaryCopyNotSupported
|
||||
break awaitCopyInResponse
|
||||
}
|
||||
go ci.resploop()
|
||||
return ci, nil
|
||||
case 'H':
|
||||
err = errCopyToNotSupported
|
||||
break awaitCopyInResponse
|
||||
case 'E':
|
||||
err = parseError(r)
|
||||
case 'Z':
|
||||
if err == nil {
|
||||
ci.setBad(driver.ErrBadConn)
|
||||
errorf("unexpected ReadyForQuery in response to COPY")
|
||||
}
|
||||
cn.processReadyForQuery(r)
|
||||
return nil, err
|
||||
default:
|
||||
ci.setBad(driver.ErrBadConn)
|
||||
errorf("unknown response for copy query: %q", t)
|
||||
}
|
||||
}
|
||||
|
||||
// something went wrong, abort COPY before we return
|
||||
b = cn.writeBuf('f')
|
||||
b.string(err.Error())
|
||||
cn.send(b)
|
||||
|
||||
for {
|
||||
t, r := cn.recv1()
|
||||
switch t {
|
||||
case 'c', 'C', 'E':
|
||||
case 'Z':
|
||||
// correctly aborted, we're done
|
||||
cn.processReadyForQuery(r)
|
||||
return nil, err
|
||||
default:
|
||||
ci.setBad(driver.ErrBadConn)
|
||||
errorf("unknown response for CopyFail: %q", t)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ci *copyin) flush(buf []byte) {
|
||||
// set message length (without message identifier)
|
||||
binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1))
|
||||
|
||||
_, err := ci.cn.c.Write(buf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (ci *copyin) resploop() {
|
||||
for {
|
||||
var r readBuf
|
||||
t, err := ci.cn.recvMessage(&r)
|
||||
if err != nil {
|
||||
ci.setBad(driver.ErrBadConn)
|
||||
ci.setError(err)
|
||||
ci.done <- true
|
||||
return
|
||||
}
|
||||
switch t {
|
||||
case 'C':
|
||||
// complete
|
||||
res, _ := ci.cn.parseComplete(r.string())
|
||||
ci.setResult(res)
|
||||
case 'N':
|
||||
if n := ci.cn.noticeHandler; n != nil {
|
||||
n(parseError(&r))
|
||||
}
|
||||
case 'Z':
|
||||
ci.cn.processReadyForQuery(&r)
|
||||
ci.done <- true
|
||||
return
|
||||
case 'E':
|
||||
err := parseError(&r)
|
||||
ci.setError(err)
|
||||
default:
|
||||
ci.setBad(driver.ErrBadConn)
|
||||
ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t))
|
||||
ci.done <- true
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ci *copyin) setBad(err error) {
|
||||
ci.cn.err.set(err)
|
||||
}
|
||||
|
||||
func (ci *copyin) getBad() error {
|
||||
return ci.cn.err.get()
|
||||
}
|
||||
|
||||
func (ci *copyin) err() error {
|
||||
ci.mu.Lock()
|
||||
err := ci.mu.err
|
||||
ci.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// setError() sets ci.err if one has not been set already. Caller must not be
|
||||
// holding ci.Mutex.
|
||||
func (ci *copyin) setError(err error) {
|
||||
ci.mu.Lock()
|
||||
if ci.mu.err == nil {
|
||||
ci.mu.err = err
|
||||
}
|
||||
ci.mu.Unlock()
|
||||
}
|
||||
|
||||
func (ci *copyin) setResult(result driver.Result) {
|
||||
ci.mu.Lock()
|
||||
ci.mu.Result = result
|
||||
ci.mu.Unlock()
|
||||
}
|
||||
|
||||
func (ci *copyin) getResult() driver.Result {
|
||||
ci.mu.Lock()
|
||||
result := ci.mu.Result
|
||||
ci.mu.Unlock()
|
||||
if result == nil {
|
||||
return driver.RowsAffected(0)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (ci *copyin) NumInput() int {
|
||||
return -1
|
||||
}
|
||||
|
||||
func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
|
||||
return nil, ErrNotSupported
|
||||
}
|
||||
|
||||
// Exec inserts values into the COPY stream. The insert is asynchronous
|
||||
// and Exec can return errors from previous Exec calls to the same
|
||||
// COPY stmt.
|
||||
//
|
||||
// You need to call Exec(nil) to sync the COPY stream and to get any
|
||||
// errors from pending data, since Stmt.Close() doesn't return errors
|
||||
// to the user.
|
||||
func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {
|
||||
if ci.closed {
|
||||
return nil, errCopyInClosed
|
||||
}
|
||||
|
||||
if err := ci.getBad(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer ci.cn.errRecover(&err)
|
||||
|
||||
if err := ci.err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(v) == 0 {
|
||||
if err := ci.Close(); err != nil {
|
||||
return driver.RowsAffected(0), err
|
||||
}
|
||||
|
||||
return ci.getResult(), nil
|
||||
}
|
||||
|
||||
numValues := len(v)
|
||||
for i, value := range v {
|
||||
ci.buffer = appendEncodedText(&ci.cn.parameterStatus, ci.buffer, value)
|
||||
if i < numValues-1 {
|
||||
ci.buffer = append(ci.buffer, '\t')
|
||||
}
|
||||
}
|
||||
|
||||
ci.buffer = append(ci.buffer, '\n')
|
||||
|
||||
if len(ci.buffer) > ciBufferFlushSize {
|
||||
ci.flush(ci.buffer)
|
||||
// reset buffer, keep bytes for message identifier and length
|
||||
ci.buffer = ci.buffer[:5]
|
||||
}
|
||||
|
||||
return driver.RowsAffected(0), nil
|
||||
}
|
||||
|
||||
// CopyData inserts a raw string into the COPY stream. The insert is
|
||||
// asynchronous and CopyData can return errors from previous CopyData calls to
|
||||
// the same COPY stmt.
|
||||
//
|
||||
// You need to call Exec(nil) to sync the COPY stream and to get any
|
||||
// errors from pending data, since Stmt.Close() doesn't return errors
|
||||
// to the user.
|
||||
func (ci *copyin) CopyData(ctx context.Context, line string) (r driver.Result, err error) {
|
||||
if ci.closed {
|
||||
return nil, errCopyInClosed
|
||||
}
|
||||
|
||||
if finish := ci.cn.watchCancel(ctx); finish != nil {
|
||||
defer finish()
|
||||
}
|
||||
|
||||
if err := ci.getBad(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer ci.cn.errRecover(&err)
|
||||
|
||||
if err := ci.err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ci.buffer = append(ci.buffer, []byte(line)...)
|
||||
ci.buffer = append(ci.buffer, '\n')
|
||||
|
||||
if len(ci.buffer) > ciBufferFlushSize {
|
||||
ci.flush(ci.buffer)
|
||||
// reset buffer, keep bytes for message identifier and length
|
||||
ci.buffer = ci.buffer[:5]
|
||||
}
|
||||
|
||||
return driver.RowsAffected(0), nil
|
||||
}
|
||||
|
||||
func (ci *copyin) Close() (err error) {
|
||||
if ci.closed { // Don't do anything, we're already closed
|
||||
return nil
|
||||
}
|
||||
ci.closed = true
|
||||
|
||||
if err := ci.getBad(); err != nil {
|
||||
return err
|
||||
}
|
||||
defer ci.cn.errRecover(&err)
|
||||
|
||||
if len(ci.buffer) > 0 {
|
||||
ci.flush(ci.buffer)
|
||||
}
|
||||
// Avoid touching the scratch buffer as resploop could be using it.
|
||||
err = ci.cn.sendSimpleMessage('c')
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
<-ci.done
|
||||
ci.cn.inCopy = false
|
||||
|
||||
if err := ci.err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
510
third_party/highgo-pq/copy_test.go
vendored
Normal file
510
third_party/highgo-pq/copy_test.go
vendored
Normal file
File diff suppressed because one or more lines are too long
268
third_party/highgo-pq/doc.go
vendored
Normal file
268
third_party/highgo-pq/doc.go
vendored
Normal file
@@ -0,0 +1,268 @@
|
||||
/*
|
||||
Package pq is a pure Go Postgres driver for the database/sql package.
|
||||
|
||||
In most cases clients will use the database/sql package instead of
|
||||
using this package directly. For example:
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
func main() {
|
||||
connStr := "user=pqgotest dbname=pqgotest sslmode=verify-full"
|
||||
db, err := sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
age := 21
|
||||
rows, err := db.Query("SELECT name FROM users WHERE age = $1", age)
|
||||
…
|
||||
}
|
||||
|
||||
You can also connect to a database using a URL. For example:
|
||||
|
||||
connStr := "postgres://pqgotest:password@localhost/pqgotest?sslmode=verify-full"
|
||||
db, err := sql.Open("postgres", connStr)
|
||||
|
||||
|
||||
Connection String Parameters
|
||||
|
||||
|
||||
Similarly to libpq, when establishing a connection using pq you are expected to
|
||||
supply a connection string containing zero or more parameters.
|
||||
A subset of the connection parameters supported by libpq are also supported by pq.
|
||||
Additionally, pq also lets you specify run-time parameters (such as search_path or work_mem)
|
||||
directly in the connection string. This is different from libpq, which does not allow
|
||||
run-time parameters in the connection string, instead requiring you to supply
|
||||
them in the options parameter.
|
||||
|
||||
For compatibility with libpq, the following special connection parameters are
|
||||
supported:
|
||||
|
||||
* dbname - The name of the database to connect to
|
||||
* user - The user to sign in as
|
||||
* password - The user's password
|
||||
* host - The host to connect to. Values that start with / are for unix
|
||||
domain sockets. (default is localhost)
|
||||
* port - The port to bind to. (default is 5432)
|
||||
* sslmode - Whether or not to use SSL (default is require, this is not
|
||||
the default for libpq)
|
||||
* fallback_application_name - An application_name to fall back to if one isn't provided.
|
||||
* connect_timeout - Maximum wait for connection, in seconds. Zero or
|
||||
not specified means wait indefinitely.
|
||||
* sslcert - Cert file location. The file must contain PEM encoded data.
|
||||
* sslkey - Key file location. The file must contain PEM encoded data.
|
||||
* sslrootcert - The location of the root certificate file. The file
|
||||
must contain PEM encoded data.
|
||||
|
||||
Valid values for sslmode are:
|
||||
|
||||
* disable - No SSL
|
||||
* require - Always SSL (skip verification)
|
||||
* verify-ca - Always SSL (verify that the certificate presented by the
|
||||
server was signed by a trusted CA)
|
||||
* verify-full - Always SSL (verify that the certification presented by
|
||||
the server was signed by a trusted CA and the server host name
|
||||
matches the one in the certificate)
|
||||
|
||||
See http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING
|
||||
for more information about connection string parameters.
|
||||
|
||||
Use single quotes for values that contain whitespace:
|
||||
|
||||
"user=pqgotest password='with spaces'"
|
||||
|
||||
A backslash will escape the next character in values:
|
||||
|
||||
"user=space\ man password='it\'s valid'"
|
||||
|
||||
Note that the connection parameter client_encoding (which sets the
|
||||
text encoding for the connection) may be set but must be "UTF8",
|
||||
matching with the same rules as Postgres. It is an error to provide
|
||||
any other value.
|
||||
|
||||
In addition to the parameters listed above, any run-time parameter that can be
|
||||
set at backend start time can be set in the connection string. For more
|
||||
information, see
|
||||
http://www.postgresql.org/docs/current/static/runtime-config.html.
|
||||
|
||||
Most environment variables as specified at http://www.postgresql.org/docs/current/static/libpq-envars.html
|
||||
supported by libpq are also supported by pq. If any of the environment
|
||||
variables not supported by pq are set, pq will panic during connection
|
||||
establishment. Environment variables have a lower precedence than explicitly
|
||||
provided connection parameters.
|
||||
|
||||
The pgpass mechanism as described in http://www.postgresql.org/docs/current/static/libpq-pgpass.html
|
||||
is supported, but on Windows PGPASSFILE must be specified explicitly.
|
||||
|
||||
|
||||
Queries
|
||||
|
||||
|
||||
database/sql does not dictate any specific format for parameter
|
||||
markers in query strings, and pq uses the Postgres-native ordinal markers,
|
||||
as shown above. The same marker can be reused for the same parameter:
|
||||
|
||||
rows, err := db.Query(`SELECT name FROM users WHERE favorite_fruit = $1
|
||||
OR age BETWEEN $2 AND $2 + 3`, "orange", 64)
|
||||
|
||||
pq does not support the LastInsertId() method of the Result type in database/sql.
|
||||
To return the identifier of an INSERT (or UPDATE or DELETE), use the Postgres
|
||||
RETURNING clause with a standard Query or QueryRow call:
|
||||
|
||||
var userid int
|
||||
err := db.QueryRow(`INSERT INTO users(name, favorite_fruit, age)
|
||||
VALUES('beatrice', 'starfruit', 93) RETURNING id`).Scan(&userid)
|
||||
|
||||
For more details on RETURNING, see the Postgres documentation:
|
||||
|
||||
http://www.postgresql.org/docs/current/static/sql-insert.html
|
||||
http://www.postgresql.org/docs/current/static/sql-update.html
|
||||
http://www.postgresql.org/docs/current/static/sql-delete.html
|
||||
|
||||
For additional instructions on querying see the documentation for the database/sql package.
|
||||
|
||||
|
||||
Data Types
|
||||
|
||||
|
||||
Parameters pass through driver.DefaultParameterConverter before they are handled
|
||||
by this package. When the binary_parameters connection option is enabled,
|
||||
[]byte values are sent directly to the backend as data in binary format.
|
||||
|
||||
This package returns the following types for values from the PostgreSQL backend:
|
||||
|
||||
- integer types smallint, integer, and bigint are returned as int64
|
||||
- floating-point types real and double precision are returned as float64
|
||||
- character types char, varchar, and text are returned as string
|
||||
- temporal types date, time, timetz, timestamp, and timestamptz are
|
||||
returned as time.Time
|
||||
- the boolean type is returned as bool
|
||||
- the bytea type is returned as []byte
|
||||
|
||||
All other types are returned directly from the backend as []byte values in text format.
|
||||
|
||||
|
||||
Errors
|
||||
|
||||
|
||||
pq may return errors of type *pq.Error which can be interrogated for error details:
|
||||
|
||||
if err, ok := err.(*pq.Error); ok {
|
||||
fmt.Println("pq error:", err.Code.Name())
|
||||
}
|
||||
|
||||
See the pq.Error type for details.
|
||||
|
||||
|
||||
Bulk imports
|
||||
|
||||
You can perform bulk imports by preparing a statement returned by pq.CopyIn (or
|
||||
pq.CopyInSchema) in an explicit transaction (sql.Tx). The returned statement
|
||||
handle can then be repeatedly "executed" to copy data into the target table.
|
||||
After all data has been processed you should call Exec() once with no arguments
|
||||
to flush all buffered data. Any call to Exec() might return an error which
|
||||
should be handled appropriately, but because of the internal buffering an error
|
||||
returned by Exec() might not be related to the data passed in the call that
|
||||
failed.
|
||||
|
||||
CopyIn uses COPY FROM internally. It is not possible to COPY outside of an
|
||||
explicit transaction in pq.
|
||||
|
||||
Usage example:
|
||||
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, err := txn.Prepare(pq.CopyIn("users", "name", "age"))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
_, err = stmt.Exec(user.Name, int64(user.Age))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = stmt.Exec()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = txn.Commit()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
|
||||
Notifications
|
||||
|
||||
|
||||
PostgreSQL supports a simple publish/subscribe model over database
|
||||
connections. See http://www.postgresql.org/docs/current/static/sql-notify.html
|
||||
for more information about the general mechanism.
|
||||
|
||||
To start listening for notifications, you first have to open a new connection
|
||||
to the database by calling NewListener. This connection can not be used for
|
||||
anything other than LISTEN / NOTIFY. Calling Listen will open a "notification
|
||||
channel"; once a notification channel is open, a notification generated on that
|
||||
channel will effect a send on the Listener.Notify channel. A notification
|
||||
channel will remain open until Unlisten is called, though connection loss might
|
||||
result in some notifications being lost. To solve this problem, Listener sends
|
||||
a nil pointer over the Notify channel any time the connection is re-established
|
||||
following a connection loss. The application can get information about the
|
||||
state of the underlying connection by setting an event callback in the call to
|
||||
NewListener.
|
||||
|
||||
A single Listener can safely be used from concurrent goroutines, which means
|
||||
that there is often no need to create more than one Listener in your
|
||||
application. However, a Listener is always connected to a single database, so
|
||||
you will need to create a new Listener instance for every database you want to
|
||||
receive notifications in.
|
||||
|
||||
The channel name in both Listen and Unlisten is case sensitive, and can contain
|
||||
any characters legal in an identifier (see
|
||||
http://www.postgresql.org/docs/current/static/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
|
||||
for more information). Note that the channel name will be truncated to 63
|
||||
bytes by the PostgreSQL server.
|
||||
|
||||
You can find a complete, working example of Listener usage at
|
||||
https://godoc.org/github.com/lib/pq/example/listen.
|
||||
|
||||
|
||||
Kerberos Support
|
||||
|
||||
|
||||
If you need support for Kerberos authentication, add the following to your main
|
||||
package:
|
||||
|
||||
import "github.com/lib/pq/auth/kerberos"
|
||||
|
||||
func init() {
|
||||
pq.RegisterGSSProvider(func() (pq.Gss, error) { return kerberos.NewGSS() })
|
||||
}
|
||||
|
||||
This package is in a separate module so that users who don't need Kerberos
|
||||
don't have to download unnecessary dependencies.
|
||||
|
||||
When imported, additional connection string parameters are supported:
|
||||
|
||||
* krbsrvname - GSS (Kerberos) service name when constructing the
|
||||
SPN (default is `postgres`). This will be combined with the host
|
||||
to form the full SPN: `krbsrvname/host`.
|
||||
* krbspn - GSS (Kerberos) SPN. This takes priority over
|
||||
`krbsrvname` if present.
|
||||
*/
|
||||
package pq
|
||||
632
third_party/highgo-pq/encode.go
vendored
Normal file
632
third_party/highgo-pq/encode.go
vendored
Normal file
@@ -0,0 +1,632 @@
|
||||
package pq
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq/oid"
|
||||
)
|
||||
|
||||
var time2400Regex = regexp.MustCompile(`^(24:00(?::00(?:\.0+)?)?)(?:[Z+-].*)?$`)
|
||||
|
||||
func binaryEncode(parameterStatus *parameterStatus, x interface{}) []byte {
|
||||
switch v := x.(type) {
|
||||
case []byte:
|
||||
return v
|
||||
default:
|
||||
return encode(parameterStatus, x, oid.T_unknown)
|
||||
}
|
||||
}
|
||||
|
||||
func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) []byte {
|
||||
switch v := x.(type) {
|
||||
case int64:
|
||||
return strconv.AppendInt(nil, v, 10)
|
||||
case float64:
|
||||
return strconv.AppendFloat(nil, v, 'f', -1, 64)
|
||||
case []byte:
|
||||
if pgtypOid == oid.T_bytea {
|
||||
return encodeBytea(parameterStatus.serverVersion, v)
|
||||
}
|
||||
|
||||
return v
|
||||
case string:
|
||||
if pgtypOid == oid.T_bytea {
|
||||
return encodeBytea(parameterStatus.serverVersion, []byte(v))
|
||||
}
|
||||
|
||||
return []byte(v)
|
||||
case bool:
|
||||
return strconv.AppendBool(nil, v)
|
||||
case time.Time:
|
||||
return formatTs(v)
|
||||
|
||||
default:
|
||||
errorf("encode: unknown type for %T", v)
|
||||
}
|
||||
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid, f format) interface{} {
|
||||
switch f {
|
||||
case formatBinary:
|
||||
return binaryDecode(parameterStatus, s, typ)
|
||||
case formatText:
|
||||
return textDecode(parameterStatus, s, typ)
|
||||
default:
|
||||
panic("not reached")
|
||||
}
|
||||
}
|
||||
|
||||
func binaryDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} {
|
||||
switch typ {
|
||||
case oid.T_bytea:
|
||||
return s
|
||||
case oid.T_int8:
|
||||
return int64(binary.BigEndian.Uint64(s))
|
||||
case oid.T_int4:
|
||||
return int64(int32(binary.BigEndian.Uint32(s)))
|
||||
case oid.T_int2:
|
||||
return int64(int16(binary.BigEndian.Uint16(s)))
|
||||
case oid.T_uuid:
|
||||
b, err := decodeUUIDBinary(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b
|
||||
|
||||
default:
|
||||
errorf("don't know how to decode binary parameter of type %d", uint32(typ))
|
||||
}
|
||||
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} {
|
||||
switch typ {
|
||||
case oid.T_char, oid.T_varchar, oid.T_text, oid.T_refcursor, oid.T__refcursor:
|
||||
return string(s)
|
||||
case oid.T_bytea:
|
||||
b, err := parseBytea(s)
|
||||
if err != nil {
|
||||
errorf("%s", err)
|
||||
}
|
||||
return b
|
||||
case oid.T_timestamptz:
|
||||
return parseTs(parameterStatus.currentLocation, string(s))
|
||||
case oid.T_timestamp, oid.T_date:
|
||||
return parseTs(nil, string(s))
|
||||
case oid.T_time:
|
||||
return mustParse("15:04:05", typ, s)
|
||||
case oid.T_timetz:
|
||||
return mustParse("15:04:05-07", typ, s)
|
||||
case oid.T_bool:
|
||||
return s[0] == 't'
|
||||
case oid.T_int8, oid.T_int4, oid.T_int2:
|
||||
i, err := strconv.ParseInt(string(s), 10, 64)
|
||||
if err != nil {
|
||||
errorf("%s", err)
|
||||
}
|
||||
return i
|
||||
case oid.T_float4, oid.T_float8, oid.T_numeric:
|
||||
// We always use 64 bit parsing, regardless of whether the input text is for
|
||||
// a float4 or float8, because clients expect float64s for all float datatypes
|
||||
// and returning a 32-bit parsed float64 produces lossy results.
|
||||
f, err := strconv.ParseFloat(string(s), 64)
|
||||
if err != nil {
|
||||
errorf("%s", err)
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// appendEncodedText encodes item in text format as required by COPY
|
||||
// and appends to buf
|
||||
func appendEncodedText(parameterStatus *parameterStatus, buf []byte, x interface{}) []byte {
|
||||
switch v := x.(type) {
|
||||
case int64:
|
||||
return strconv.AppendInt(buf, v, 10)
|
||||
case float64:
|
||||
return strconv.AppendFloat(buf, v, 'f', -1, 64)
|
||||
case []byte:
|
||||
encodedBytea := encodeBytea(parameterStatus.serverVersion, v)
|
||||
return appendEscapedText(buf, string(encodedBytea))
|
||||
case string:
|
||||
return appendEscapedText(buf, v)
|
||||
case bool:
|
||||
return strconv.AppendBool(buf, v)
|
||||
case time.Time:
|
||||
return append(buf, formatTs(v)...)
|
||||
case nil:
|
||||
return append(buf, "\\N"...)
|
||||
default:
|
||||
errorf("encode: unknown type for %T", v)
|
||||
}
|
||||
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func appendEscapedText(buf []byte, text string) []byte {
|
||||
escapeNeeded := false
|
||||
startPos := 0
|
||||
var c byte
|
||||
|
||||
// check if we need to escape
|
||||
for i := 0; i < len(text); i++ {
|
||||
c = text[i]
|
||||
if c == '\\' || c == '\n' || c == '\r' || c == '\t' {
|
||||
escapeNeeded = true
|
||||
startPos = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if !escapeNeeded {
|
||||
return append(buf, text...)
|
||||
}
|
||||
|
||||
// copy till first char to escape, iterate the rest
|
||||
result := append(buf, text[:startPos]...)
|
||||
for i := startPos; i < len(text); i++ {
|
||||
c = text[i]
|
||||
switch c {
|
||||
case '\\':
|
||||
result = append(result, '\\', '\\')
|
||||
case '\n':
|
||||
result = append(result, '\\', 'n')
|
||||
case '\r':
|
||||
result = append(result, '\\', 'r')
|
||||
case '\t':
|
||||
result = append(result, '\\', 't')
|
||||
default:
|
||||
result = append(result, c)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func mustParse(f string, typ oid.Oid, s []byte) time.Time {
|
||||
str := string(s)
|
||||
|
||||
// Check for a minute and second offset in the timezone.
|
||||
if typ == oid.T_timestamptz || typ == oid.T_timetz {
|
||||
for i := 3; i <= 6; i += 3 {
|
||||
if str[len(str)-i] == ':' {
|
||||
f += ":00"
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Special case for 24:00 time.
|
||||
// Unfortunately, golang does not parse 24:00 as a proper time.
|
||||
// In this case, we want to try "round to the next day", to differentiate.
|
||||
// As such, we find if the 24:00 time matches at the beginning; if so,
|
||||
// we default it back to 00:00 but add a day later.
|
||||
var is2400Time bool
|
||||
switch typ {
|
||||
case oid.T_timetz, oid.T_time:
|
||||
if matches := time2400Regex.FindStringSubmatch(str); matches != nil {
|
||||
// Concatenate timezone information at the back.
|
||||
str = "00:00:00" + str[len(matches[1]):]
|
||||
is2400Time = true
|
||||
}
|
||||
}
|
||||
t, err := time.Parse(f, str)
|
||||
if err != nil {
|
||||
errorf("decode: %s", err)
|
||||
}
|
||||
if is2400Time {
|
||||
t = t.Add(24 * time.Hour)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
var errInvalidTimestamp = errors.New("invalid timestamp")
|
||||
|
||||
type timestampParser struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (p *timestampParser) expect(str string, char byte, pos int) {
|
||||
if p.err != nil {
|
||||
return
|
||||
}
|
||||
if pos+1 > len(str) {
|
||||
p.err = errInvalidTimestamp
|
||||
return
|
||||
}
|
||||
if c := str[pos]; c != char && p.err == nil {
|
||||
p.err = fmt.Errorf("expected '%v' at position %v; got '%v'", char, pos, c)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *timestampParser) mustAtoi(str string, begin int, end int) int {
|
||||
if p.err != nil {
|
||||
return 0
|
||||
}
|
||||
if begin < 0 || end < 0 || begin > end || end > len(str) {
|
||||
p.err = errInvalidTimestamp
|
||||
return 0
|
||||
}
|
||||
result, err := strconv.Atoi(str[begin:end])
|
||||
if err != nil {
|
||||
if p.err == nil {
|
||||
p.err = fmt.Errorf("expected number; got '%v'", str)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// The location cache caches the time zones typically used by the client.
|
||||
type locationCache struct {
|
||||
cache map[int]*time.Location
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
// All connections share the same list of timezones. Benchmarking shows that
|
||||
// about 5% speed could be gained by putting the cache in the connection and
|
||||
// losing the mutex, at the cost of a small amount of memory and a somewhat
|
||||
// significant increase in code complexity.
|
||||
var globalLocationCache = newLocationCache()
|
||||
|
||||
func newLocationCache() *locationCache {
|
||||
return &locationCache{cache: make(map[int]*time.Location)}
|
||||
}
|
||||
|
||||
// Returns the cached timezone for the specified offset, creating and caching
|
||||
// it if necessary.
|
||||
func (c *locationCache) getLocation(offset int) *time.Location {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
location, ok := c.cache[offset]
|
||||
if !ok {
|
||||
location = time.FixedZone("", offset)
|
||||
c.cache[offset] = location
|
||||
}
|
||||
|
||||
return location
|
||||
}
|
||||
|
||||
var infinityTsEnabled = false
|
||||
var infinityTsNegative time.Time
|
||||
var infinityTsPositive time.Time
|
||||
|
||||
const (
|
||||
infinityTsEnabledAlready = "pq: infinity timestamp enabled already"
|
||||
infinityTsNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive"
|
||||
)
|
||||
|
||||
// EnableInfinityTs controls the handling of Postgres' "-infinity" and
|
||||
// "infinity" "timestamp"s.
|
||||
//
|
||||
// If EnableInfinityTs is not called, "-infinity" and "infinity" will return
|
||||
// []byte("-infinity") and []byte("infinity") respectively, and potentially
|
||||
// cause error "sql: Scan error on column index 0: unsupported driver -> Scan
|
||||
// pair: []uint8 -> *time.Time", when scanning into a time.Time value.
|
||||
//
|
||||
// Once EnableInfinityTs has been called, all connections created using this
|
||||
// driver will decode Postgres' "-infinity" and "infinity" for "timestamp",
|
||||
// "timestamp with time zone" and "date" types to the predefined minimum and
|
||||
// maximum times, respectively. When encoding time.Time values, any time which
|
||||
// equals or precedes the predefined minimum time will be encoded to
|
||||
// "-infinity". Any values at or past the maximum time will similarly be
|
||||
// encoded to "infinity".
|
||||
//
|
||||
// If EnableInfinityTs is called with negative >= positive, it will panic.
|
||||
// Calling EnableInfinityTs after a connection has been established results in
|
||||
// undefined behavior. If EnableInfinityTs is called more than once, it will
|
||||
// panic.
|
||||
func EnableInfinityTs(negative time.Time, positive time.Time) {
|
||||
if infinityTsEnabled {
|
||||
panic(infinityTsEnabledAlready)
|
||||
}
|
||||
if !negative.Before(positive) {
|
||||
panic(infinityTsNegativeMustBeSmaller)
|
||||
}
|
||||
infinityTsEnabled = true
|
||||
infinityTsNegative = negative
|
||||
infinityTsPositive = positive
|
||||
}
|
||||
|
||||
/*
|
||||
* Testing might want to toggle infinityTsEnabled
|
||||
*/
|
||||
func disableInfinityTs() {
|
||||
infinityTsEnabled = false
|
||||
}
|
||||
|
||||
// This is a time function specific to the Postgres default DateStyle
|
||||
// setting ("ISO, MDY"), the only one we currently support. This
|
||||
// accounts for the discrepancies between the parsing available with
|
||||
// time.Parse and the Postgres date formatting quirks.
|
||||
func parseTs(currentLocation *time.Location, str string) interface{} {
|
||||
switch str {
|
||||
case "-infinity":
|
||||
if infinityTsEnabled {
|
||||
return infinityTsNegative
|
||||
}
|
||||
return []byte(str)
|
||||
case "infinity":
|
||||
if infinityTsEnabled {
|
||||
return infinityTsPositive
|
||||
}
|
||||
return []byte(str)
|
||||
}
|
||||
t, err := ParseTimestamp(currentLocation, str)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// ParseTimestamp parses Postgres' text format. It returns a time.Time in
|
||||
// currentLocation iff that time's offset agrees with the offset sent from the
|
||||
// Postgres server. Otherwise, ParseTimestamp returns a time.Time with the
|
||||
// fixed offset offset provided by the Postgres server.
|
||||
func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, error) {
|
||||
p := timestampParser{}
|
||||
|
||||
monSep := strings.IndexRune(str, '-')
|
||||
// this is Gregorian year, not ISO Year
|
||||
// In Gregorian system, the year 1 BC is followed by AD 1
|
||||
year := p.mustAtoi(str, 0, monSep)
|
||||
daySep := monSep + 3
|
||||
month := p.mustAtoi(str, monSep+1, daySep)
|
||||
p.expect(str, '-', daySep)
|
||||
timeSep := daySep + 3
|
||||
day := p.mustAtoi(str, daySep+1, timeSep)
|
||||
|
||||
minLen := monSep + len("01-01") + 1
|
||||
|
||||
isBC := strings.HasSuffix(str, " BC")
|
||||
if isBC {
|
||||
minLen += 3
|
||||
}
|
||||
|
||||
var hour, minute, second int
|
||||
if len(str) > minLen {
|
||||
p.expect(str, ' ', timeSep)
|
||||
minSep := timeSep + 3
|
||||
p.expect(str, ':', minSep)
|
||||
hour = p.mustAtoi(str, timeSep+1, minSep)
|
||||
secSep := minSep + 3
|
||||
p.expect(str, ':', secSep)
|
||||
minute = p.mustAtoi(str, minSep+1, secSep)
|
||||
secEnd := secSep + 3
|
||||
second = p.mustAtoi(str, secSep+1, secEnd)
|
||||
}
|
||||
remainderIdx := monSep + len("01-01 00:00:00") + 1
|
||||
// Three optional (but ordered) sections follow: the
|
||||
// fractional seconds, the time zone offset, and the BC
|
||||
// designation. We set them up here and adjust the other
|
||||
// offsets if the preceding sections exist.
|
||||
|
||||
nanoSec := 0
|
||||
tzOff := 0
|
||||
|
||||
if remainderIdx < len(str) && str[remainderIdx] == '.' {
|
||||
fracStart := remainderIdx + 1
|
||||
fracOff := strings.IndexAny(str[fracStart:], "-+Z ")
|
||||
if fracOff < 0 {
|
||||
fracOff = len(str) - fracStart
|
||||
}
|
||||
fracSec := p.mustAtoi(str, fracStart, fracStart+fracOff)
|
||||
nanoSec = fracSec * (1000000000 / int(math.Pow(10, float64(fracOff))))
|
||||
|
||||
remainderIdx += fracOff + 1
|
||||
}
|
||||
if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart] == '-' || str[tzStart] == '+') {
|
||||
// time zone separator is always '-' or '+' or 'Z' (UTC is +00)
|
||||
var tzSign int
|
||||
switch c := str[tzStart]; c {
|
||||
case '-':
|
||||
tzSign = -1
|
||||
case '+':
|
||||
tzSign = +1
|
||||
default:
|
||||
return time.Time{}, fmt.Errorf("expected '-' or '+' at position %v; got %v", tzStart, c)
|
||||
}
|
||||
tzHours := p.mustAtoi(str, tzStart+1, tzStart+3)
|
||||
remainderIdx += 3
|
||||
var tzMin, tzSec int
|
||||
if remainderIdx < len(str) && str[remainderIdx] == ':' {
|
||||
tzMin = p.mustAtoi(str, remainderIdx+1, remainderIdx+3)
|
||||
remainderIdx += 3
|
||||
}
|
||||
if remainderIdx < len(str) && str[remainderIdx] == ':' {
|
||||
tzSec = p.mustAtoi(str, remainderIdx+1, remainderIdx+3)
|
||||
remainderIdx += 3
|
||||
}
|
||||
tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec)
|
||||
} else if tzStart < len(str) && str[tzStart] == 'Z' {
|
||||
// time zone Z separator indicates UTC is +00
|
||||
remainderIdx += 1
|
||||
}
|
||||
|
||||
var isoYear int
|
||||
|
||||
if isBC {
|
||||
isoYear = 1 - year
|
||||
remainderIdx += 3
|
||||
} else {
|
||||
isoYear = year
|
||||
}
|
||||
if remainderIdx < len(str) {
|
||||
return time.Time{}, fmt.Errorf("expected end of input, got %v", str[remainderIdx:])
|
||||
}
|
||||
t := time.Date(isoYear, time.Month(month), day,
|
||||
hour, minute, second, nanoSec,
|
||||
globalLocationCache.getLocation(tzOff))
|
||||
|
||||
if currentLocation != nil {
|
||||
// Set the location of the returned Time based on the session's
|
||||
// TimeZone value, but only if the local time zone database agrees with
|
||||
// the remote database on the offset.
|
||||
lt := t.In(currentLocation)
|
||||
_, newOff := lt.Zone()
|
||||
if newOff == tzOff {
|
||||
t = lt
|
||||
}
|
||||
}
|
||||
|
||||
return t, p.err
|
||||
}
|
||||
|
||||
// formatTs formats t into a format postgres understands.
|
||||
func formatTs(t time.Time) []byte {
|
||||
if infinityTsEnabled {
|
||||
// t <= -infinity : ! (t > -infinity)
|
||||
if !t.After(infinityTsNegative) {
|
||||
return []byte("-infinity")
|
||||
}
|
||||
// t >= infinity : ! (!t < infinity)
|
||||
if !t.Before(infinityTsPositive) {
|
||||
return []byte("infinity")
|
||||
}
|
||||
}
|
||||
return FormatTimestamp(t)
|
||||
}
|
||||
|
||||
// FormatTimestamp formats t into Postgres' text format for timestamps.
|
||||
func FormatTimestamp(t time.Time) []byte {
|
||||
// Need to send dates before 0001 A.D. with " BC" suffix, instead of the
|
||||
// minus sign preferred by Go.
|
||||
// Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on
|
||||
bc := false
|
||||
if t.Year() <= 0 {
|
||||
// flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11"
|
||||
t = t.AddDate((-t.Year())*2+1, 0, 0)
|
||||
bc = true
|
||||
}
|
||||
b := []byte(t.Format("2006-01-02 15:04:05.999999999Z07:00"))
|
||||
|
||||
_, offset := t.Zone()
|
||||
offset %= 60
|
||||
if offset != 0 {
|
||||
// RFC3339Nano already printed the minus sign
|
||||
if offset < 0 {
|
||||
offset = -offset
|
||||
}
|
||||
|
||||
b = append(b, ':')
|
||||
if offset < 10 {
|
||||
b = append(b, '0')
|
||||
}
|
||||
b = strconv.AppendInt(b, int64(offset), 10)
|
||||
}
|
||||
|
||||
if bc {
|
||||
b = append(b, " BC"...)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Parse a bytea value received from the server. Both "hex" and the legacy
|
||||
// "escape" format are supported.
|
||||
func parseBytea(s []byte) (result []byte, err error) {
|
||||
if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) {
|
||||
// bytea_output = hex
|
||||
s = s[2:] // trim off leading "\\x"
|
||||
result = make([]byte, hex.DecodedLen(len(s)))
|
||||
_, err := hex.Decode(result, s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// bytea_output = escape
|
||||
for len(s) > 0 {
|
||||
if s[0] == '\\' {
|
||||
// escaped '\\'
|
||||
if len(s) >= 2 && s[1] == '\\' {
|
||||
result = append(result, '\\')
|
||||
s = s[2:]
|
||||
continue
|
||||
}
|
||||
|
||||
// '\\' followed by an octal number
|
||||
if len(s) < 4 {
|
||||
return nil, fmt.Errorf("invalid bytea sequence %v", s)
|
||||
}
|
||||
r, err := strconv.ParseUint(string(s[1:4]), 8, 8)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse bytea value: %s", err.Error())
|
||||
}
|
||||
result = append(result, byte(r))
|
||||
s = s[4:]
|
||||
} else {
|
||||
// We hit an unescaped, raw byte. Try to read in as many as
|
||||
// possible in one go.
|
||||
i := bytes.IndexByte(s, '\\')
|
||||
if i == -1 {
|
||||
result = append(result, s...)
|
||||
break
|
||||
}
|
||||
result = append(result, s[:i]...)
|
||||
s = s[i:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func encodeBytea(serverVersion int, v []byte) (result []byte) {
|
||||
if serverVersion >= 90000 {
|
||||
// Use the hex format if we know that the server supports it
|
||||
result = make([]byte, 2+hex.EncodedLen(len(v)))
|
||||
result[0] = '\\'
|
||||
result[1] = 'x'
|
||||
hex.Encode(result[2:], v)
|
||||
} else {
|
||||
// .. or resort to "escape"
|
||||
for _, b := range v {
|
||||
if b == '\\' {
|
||||
result = append(result, '\\', '\\')
|
||||
} else if b < 0x20 || b > 0x7e {
|
||||
result = append(result, []byte(fmt.Sprintf("\\%03o", b))...)
|
||||
} else {
|
||||
result = append(result, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// NullTime represents a time.Time that may be null. NullTime implements the
|
||||
// sql.Scanner interface so it can be used as a scan destination, similar to
|
||||
// sql.NullString.
|
||||
type NullTime struct {
|
||||
Time time.Time
|
||||
Valid bool // Valid is true if Time is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (nt *NullTime) Scan(value interface{}) error {
|
||||
nt.Time, nt.Valid = value.(time.Time)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (nt NullTime) Value() (driver.Value, error) {
|
||||
if !nt.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return nt.Time, nil
|
||||
}
|
||||
886
third_party/highgo-pq/encode_test.go
vendored
Normal file
886
third_party/highgo-pq/encode_test.go
vendored
Normal file
@@ -0,0 +1,886 @@
|
||||
package pq
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq/oid"
|
||||
)
|
||||
|
||||
func TestScanTimestamp(t *testing.T) {
|
||||
var nt NullTime
|
||||
tn := time.Now()
|
||||
nt.Scan(tn)
|
||||
if !nt.Valid {
|
||||
t.Errorf("Expected Valid=false")
|
||||
}
|
||||
if nt.Time != tn {
|
||||
t.Errorf("Time value mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanNilTimestamp(t *testing.T) {
|
||||
var nt NullTime
|
||||
nt.Scan(nil)
|
||||
if nt.Valid {
|
||||
t.Errorf("Expected Valid=false")
|
||||
}
|
||||
}
|
||||
|
||||
var timeTests = []struct {
|
||||
str string
|
||||
timeval time.Time
|
||||
}{
|
||||
{"22001-02-03", time.Date(22001, time.February, 3, 0, 0, 0, 0, time.FixedZone("", 0))},
|
||||
{"2001-02-03", time.Date(2001, time.February, 3, 0, 0, 0, 0, time.FixedZone("", 0))},
|
||||
{"0001-12-31 BC", time.Date(0, time.December, 31, 0, 0, 0, 0, time.FixedZone("", 0))},
|
||||
{"2001-02-03 BC", time.Date(-2000, time.February, 3, 0, 0, 0, 0, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06", time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.000001", time.Date(2001, time.February, 3, 4, 5, 6, 1000, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.00001", time.Date(2001, time.February, 3, 4, 5, 6, 10000, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.0001", time.Date(2001, time.February, 3, 4, 5, 6, 100000, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.001", time.Date(2001, time.February, 3, 4, 5, 6, 1000000, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.01", time.Date(2001, time.February, 3, 4, 5, 6, 10000000, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.1", time.Date(2001, time.February, 3, 4, 5, 6, 100000000, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.12", time.Date(2001, time.February, 3, 4, 5, 6, 120000000, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.123", time.Date(2001, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.1234", time.Date(2001, time.February, 3, 4, 5, 6, 123400000, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.12345", time.Date(2001, time.February, 3, 4, 5, 6, 123450000, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.123456", time.Date(2001, time.February, 3, 4, 5, 6, 123456000, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.123-07", time.Date(2001, time.February, 3, 4, 5, 6, 123000000,
|
||||
time.FixedZone("", -7*60*60))},
|
||||
{"2001-02-03 04:05:06-07", time.Date(2001, time.February, 3, 4, 5, 6, 0,
|
||||
time.FixedZone("", -7*60*60))},
|
||||
{"2001-02-03 04:05:06-07:42", time.Date(2001, time.February, 3, 4, 5, 6, 0,
|
||||
time.FixedZone("", -(7*60*60+42*60)))},
|
||||
{"2001-02-03 04:05:06-07:30:09", time.Date(2001, time.February, 3, 4, 5, 6, 0,
|
||||
time.FixedZone("", -(7*60*60+30*60+9)))},
|
||||
{"2001-02-03 04:05:06+07:30:09", time.Date(2001, time.February, 3, 4, 5, 6, 0,
|
||||
time.FixedZone("", +(7*60*60+30*60+9)))},
|
||||
{"2001-02-03 04:05:06+07", time.Date(2001, time.February, 3, 4, 5, 6, 0,
|
||||
time.FixedZone("", 7*60*60))},
|
||||
{"0011-02-03 04:05:06 BC", time.Date(-10, time.February, 3, 4, 5, 6, 0, time.FixedZone("", 0))},
|
||||
{"0011-02-03 04:05:06.123 BC", time.Date(-10, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))},
|
||||
{"0011-02-03 04:05:06.123-07 BC", time.Date(-10, time.February, 3, 4, 5, 6, 123000000,
|
||||
time.FixedZone("", -7*60*60))},
|
||||
{"0001-02-03 04:05:06.123", time.Date(1, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))},
|
||||
{"0001-02-03 04:05:06.123 BC", time.Date(1, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0)).AddDate(-1, 0, 0)},
|
||||
{"0001-02-03 04:05:06.123 BC", time.Date(0, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))},
|
||||
{"0002-02-03 04:05:06.123 BC", time.Date(0, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0)).AddDate(-1, 0, 0)},
|
||||
{"0002-02-03 04:05:06.123 BC", time.Date(-1, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))},
|
||||
{"12345-02-03 04:05:06.1", time.Date(12345, time.February, 3, 4, 5, 6, 100000000, time.FixedZone("", 0))},
|
||||
{"123456-02-03 04:05:06.1", time.Date(123456, time.February, 3, 4, 5, 6, 100000000, time.FixedZone("", 0))},
|
||||
}
|
||||
|
||||
// Test that parsing the string results in the expected value.
|
||||
func TestParseTs(t *testing.T) {
|
||||
for i, tt := range timeTests {
|
||||
val, err := ParseTimestamp(nil, tt.str)
|
||||
if err != nil {
|
||||
t.Errorf("%d: got error: %v", i, err)
|
||||
} else if val.String() != tt.timeval.String() {
|
||||
t.Errorf("%d: expected to parse %q into %q; got %q",
|
||||
i, tt.str, tt.timeval, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var timeErrorTests = []string{
|
||||
"BC",
|
||||
" BC",
|
||||
"2001",
|
||||
"2001-2-03",
|
||||
"2001-02-3",
|
||||
"2001-02-03 ",
|
||||
"2001-02-03 B",
|
||||
"2001-02-03 04",
|
||||
"2001-02-03 04:",
|
||||
"2001-02-03 04:05",
|
||||
"2001-02-03 04:05 B",
|
||||
"2001-02-03 04:05 BC",
|
||||
"2001-02-03 04:05:",
|
||||
"2001-02-03 04:05:6",
|
||||
"2001-02-03 04:05:06 B",
|
||||
"2001-02-03 04:05:06BC",
|
||||
"2001-02-03 04:05:06.123 B",
|
||||
}
|
||||
|
||||
// Test that parsing the string results in an error.
|
||||
func TestParseTsErrors(t *testing.T) {
|
||||
for i, tt := range timeErrorTests {
|
||||
_, err := ParseTimestamp(nil, tt)
|
||||
if err == nil {
|
||||
t.Errorf("%d: expected an error from parsing: %v", i, tt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now test that sending the value into the database and parsing it back
|
||||
// returns the same time.Time value.
|
||||
func TestEncodeAndParseTs(t *testing.T) {
|
||||
db, err := openTestConnConninfo("timezone='Etc/UTC'")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
for i, tt := range timeTests {
|
||||
var dbstr string
|
||||
err = db.QueryRow("SELECT ($1::timestamptz)::text", tt.timeval).Scan(&dbstr)
|
||||
if err != nil {
|
||||
t.Errorf("%d: could not send value %q to the database: %s", i, tt.timeval, err)
|
||||
continue
|
||||
}
|
||||
|
||||
val, err := ParseTimestamp(nil, dbstr)
|
||||
if err != nil {
|
||||
t.Errorf("%d: could not parse value %q: %s", i, dbstr, err)
|
||||
continue
|
||||
}
|
||||
val = val.In(tt.timeval.Location())
|
||||
if val.String() != tt.timeval.String() {
|
||||
t.Errorf("%d: expected to parse %q into %q; got %q", i, dbstr, tt.timeval, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var formatTimeTests = []struct {
|
||||
time time.Time
|
||||
expected string
|
||||
}{
|
||||
{time.Time{}, "0001-01-01 00:00:00Z"},
|
||||
{time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "2001-02-03 04:05:06.123456789Z"},
|
||||
{time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "2001-02-03 04:05:06.123456789+02:00"},
|
||||
{time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "2001-02-03 04:05:06.123456789-06:00"},
|
||||
{time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "2001-02-03 04:05:06-07:30:09"},
|
||||
|
||||
{time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03 04:05:06.123456789Z"},
|
||||
{time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03 04:05:06.123456789+02:00"},
|
||||
{time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03 04:05:06.123456789-06:00"},
|
||||
|
||||
{time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03 04:05:06.123456789Z BC"},
|
||||
{time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03 04:05:06.123456789+02:00 BC"},
|
||||
{time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03 04:05:06.123456789-06:00 BC"},
|
||||
|
||||
{time.Date(1, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03 04:05:06-07:30:09"},
|
||||
{time.Date(0, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03 04:05:06-07:30:09 BC"},
|
||||
}
|
||||
|
||||
func TestFormatTs(t *testing.T) {
|
||||
for i, tt := range formatTimeTests {
|
||||
val := string(formatTs(tt.time))
|
||||
if val != tt.expected {
|
||||
t.Errorf("%d: incorrect time format %q, want %q", i, val, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatTsBackend(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
var str string
|
||||
err := db.QueryRow("SELECT '2001-02-03T04:05:06.007-08:09:10'::time::text").Scan(&str)
|
||||
if err == nil {
|
||||
t.Fatalf("PostgreSQL is accepting an ISO timestamp input for time")
|
||||
}
|
||||
|
||||
for i, tt := range formatTimeTests {
|
||||
for _, typ := range []string{"date", "time", "timetz", "timestamp", "timestamptz"} {
|
||||
err = db.QueryRow("SELECT $1::"+typ+"::text", tt.time).Scan(&str)
|
||||
if err != nil {
|
||||
t.Errorf("%d: incorrect time format for %v on the backend: %v", i, typ, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeWithoutTimezone(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
for _, tc := range []struct {
|
||||
refTime string
|
||||
expectedTime time.Time
|
||||
}{
|
||||
{"11:59:59", time.Date(0, 1, 1, 11, 59, 59, 0, time.UTC)},
|
||||
{"24:00", time.Date(0, 1, 2, 0, 0, 0, 0, time.UTC)},
|
||||
{"24:00:00", time.Date(0, 1, 2, 0, 0, 0, 0, time.UTC)},
|
||||
{"24:00:00.0", time.Date(0, 1, 2, 0, 0, 0, 0, time.UTC)},
|
||||
{"24:00:00.000000", time.Date(0, 1, 2, 0, 0, 0, 0, time.UTC)},
|
||||
} {
|
||||
t.Run(
|
||||
fmt.Sprintf("%s => %s", tc.refTime, tc.expectedTime.Format(time.RFC3339)),
|
||||
func(t *testing.T) {
|
||||
var gotTime time.Time
|
||||
row := tx.QueryRow("select $1::time", tc.refTime)
|
||||
err = row.Scan(&gotTime)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !tc.expectedTime.Equal(gotTime) {
|
||||
t.Errorf("timestamps not equal: %s != %s", tc.expectedTime, gotTime)
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeWithTimezone(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
for _, tc := range []struct {
|
||||
refTime string
|
||||
expectedTime time.Time
|
||||
}{
|
||||
{"11:59:59+00:00", time.Date(0, 1, 1, 11, 59, 59, 0, time.UTC)},
|
||||
{"11:59:59+04:00", time.Date(0, 1, 1, 11, 59, 59, 0, time.FixedZone("+04", 4*60*60))},
|
||||
{"11:59:59+04:01:02", time.Date(0, 1, 1, 11, 59, 59, 0, time.FixedZone("+04:01:02", 4*60*60+1*60+2))},
|
||||
{"11:59:59-04:01:02", time.Date(0, 1, 1, 11, 59, 59, 0, time.FixedZone("-04:01:02", -(4*60*60+1*60+2)))},
|
||||
{"24:00+00", time.Date(0, 1, 2, 0, 0, 0, 0, time.UTC)},
|
||||
{"24:00Z", time.Date(0, 1, 2, 0, 0, 0, 0, time.UTC)},
|
||||
{"24:00-04:00", time.Date(0, 1, 2, 0, 0, 0, 0, time.FixedZone("-04", -4*60*60))},
|
||||
{"24:00:00+00", time.Date(0, 1, 2, 0, 0, 0, 0, time.UTC)},
|
||||
{"24:00:00.0+00", time.Date(0, 1, 2, 0, 0, 0, 0, time.UTC)},
|
||||
{"24:00:00.000000+00", time.Date(0, 1, 2, 0, 0, 0, 0, time.UTC)},
|
||||
} {
|
||||
t.Run(
|
||||
fmt.Sprintf("%s => %s", tc.refTime, tc.expectedTime.Format(time.RFC3339)),
|
||||
func(t *testing.T) {
|
||||
var gotTime time.Time
|
||||
row := tx.QueryRow("select $1::timetz", tc.refTime)
|
||||
err = row.Scan(&gotTime)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !tc.expectedTime.Equal(gotTime) {
|
||||
t.Errorf("timestamps not equal: %s != %s", tc.expectedTime, gotTime)
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimestampWithTimeZone(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// try several different locations, all included in Go's zoneinfo.zip
|
||||
for _, locName := range []string{
|
||||
"UTC",
|
||||
"America/Chicago",
|
||||
"America/New_York",
|
||||
"Australia/Darwin",
|
||||
"Australia/Perth",
|
||||
} {
|
||||
loc, err := time.LoadLocation(locName)
|
||||
if err != nil {
|
||||
t.Logf("Could not load time zone %s - skipping", locName)
|
||||
continue
|
||||
}
|
||||
|
||||
// Postgres timestamps have a resolution of 1 microsecond, so don't
|
||||
// use the full range of the Nanosecond argument
|
||||
refTime := time.Date(2012, 11, 6, 10, 23, 42, 123456000, loc)
|
||||
|
||||
for _, pgTimeZone := range []string{"US/Eastern", "Australia/Darwin"} {
|
||||
// Switch Postgres's timezone to test different output timestamp formats
|
||||
_, err = tx.Exec(fmt.Sprintf("set time zone '%s'", pgTimeZone))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var gotTime time.Time
|
||||
row := tx.QueryRow("select $1::timestamp with time zone", refTime)
|
||||
err = row.Scan(&gotTime)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !refTime.Equal(gotTime) {
|
||||
t.Errorf("timestamps not equal: %s != %s", refTime, gotTime)
|
||||
}
|
||||
|
||||
// check that the time zone is set correctly based on TimeZone
|
||||
pgLoc, err := time.LoadLocation(pgTimeZone)
|
||||
if err != nil {
|
||||
t.Logf("Could not load time zone %s - skipping", pgLoc)
|
||||
continue
|
||||
}
|
||||
translated := refTime.In(pgLoc)
|
||||
if translated.String() != gotTime.String() {
|
||||
t.Errorf("timestamps not equal: %s != %s", translated, gotTime)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimestampWithOutTimezone(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
test := func(ts, pgts string) {
|
||||
r, err := db.Query("SELECT $1::timestamp", pgts)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not run query: %v", err)
|
||||
}
|
||||
|
||||
if !r.Next() {
|
||||
t.Fatal("Expected at least one row")
|
||||
}
|
||||
|
||||
var result time.Time
|
||||
err = r.Scan(&result)
|
||||
if err != nil {
|
||||
t.Fatalf("Did not expect error scanning row: %v", err)
|
||||
}
|
||||
|
||||
expected, err := time.Parse(time.RFC3339, ts)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not parse test time literal: %v", err)
|
||||
}
|
||||
|
||||
if !result.Equal(expected) {
|
||||
t.Fatalf("Expected time to match %v: got mismatch %v",
|
||||
expected, result)
|
||||
}
|
||||
|
||||
if r.Next() {
|
||||
t.Fatal("Expected only one row")
|
||||
}
|
||||
}
|
||||
|
||||
test("2000-01-01T00:00:00Z", "2000-01-01T00:00:00")
|
||||
|
||||
// Test higher precision time
|
||||
test("2013-01-04T20:14:58.80033Z", "2013-01-04 20:14:58.80033")
|
||||
}
|
||||
|
||||
func TestInfinityTimestamp(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
var err error
|
||||
var resultT time.Time
|
||||
|
||||
expectedErrorStrRegexp := regexp.MustCompile(
|
||||
`^sql: Scan error on column index 0(, name "timestamp(tz)?"|): unsupported`)
|
||||
|
||||
type testCases []struct {
|
||||
Query string
|
||||
Param string
|
||||
ExpectedErrorStrRegexp *regexp.Regexp
|
||||
ExpectedVal interface{}
|
||||
}
|
||||
tc := testCases{
|
||||
{"SELECT $1::timestamp", "-infinity", expectedErrorStrRegexp, "-infinity"},
|
||||
{"SELECT $1::timestamptz", "-infinity", expectedErrorStrRegexp, "-infinity"},
|
||||
{"SELECT $1::timestamp", "infinity", expectedErrorStrRegexp, "infinity"},
|
||||
{"SELECT $1::timestamptz", "infinity", expectedErrorStrRegexp, "infinity"},
|
||||
}
|
||||
// try to assert []byte to time.Time
|
||||
for _, q := range tc {
|
||||
err = db.QueryRow(q.Query, q.Param).Scan(&resultT)
|
||||
if err == nil || !q.ExpectedErrorStrRegexp.MatchString(err.Error()) {
|
||||
t.Errorf("Scanning -/+infinity, expected error to match regexp %q, got %q",
|
||||
q.ExpectedErrorStrRegexp, err)
|
||||
}
|
||||
}
|
||||
// yield []byte
|
||||
for _, q := range tc {
|
||||
var resultI interface{}
|
||||
err = db.QueryRow(q.Query, q.Param).Scan(&resultI)
|
||||
if err != nil {
|
||||
t.Errorf("Scanning -/+infinity, expected no error, got %q", err)
|
||||
}
|
||||
result, ok := resultI.([]byte)
|
||||
if !ok {
|
||||
t.Errorf("Scanning -/+infinity, expected []byte, got %#v", resultI)
|
||||
}
|
||||
if string(result) != q.ExpectedVal {
|
||||
t.Errorf("Scanning -/+infinity, expected %q, got %q", q.ExpectedVal, result)
|
||||
}
|
||||
}
|
||||
|
||||
y1500 := time.Date(1500, time.January, 1, 0, 0, 0, 0, time.UTC)
|
||||
y2500 := time.Date(2500, time.January, 1, 0, 0, 0, 0, time.UTC)
|
||||
EnableInfinityTs(y1500, y2500)
|
||||
|
||||
err = db.QueryRow("SELECT $1::timestamp", "infinity").Scan(&resultT)
|
||||
if err != nil {
|
||||
t.Errorf("Scanning infinity, expected no error, got %q", err)
|
||||
}
|
||||
if !resultT.Equal(y2500) {
|
||||
t.Errorf("Scanning infinity, expected %q, got %q", y2500, resultT)
|
||||
}
|
||||
|
||||
err = db.QueryRow("SELECT $1::timestamptz", "infinity").Scan(&resultT)
|
||||
if err != nil {
|
||||
t.Errorf("Scanning infinity, expected no error, got %q", err)
|
||||
}
|
||||
if !resultT.Equal(y2500) {
|
||||
t.Errorf("Scanning Infinity, expected time %q, got %q", y2500, resultT.String())
|
||||
}
|
||||
|
||||
err = db.QueryRow("SELECT $1::timestamp", "-infinity").Scan(&resultT)
|
||||
if err != nil {
|
||||
t.Errorf("Scanning -infinity, expected no error, got %q", err)
|
||||
}
|
||||
if !resultT.Equal(y1500) {
|
||||
t.Errorf("Scanning -infinity, expected time %q, got %q", y1500, resultT.String())
|
||||
}
|
||||
|
||||
err = db.QueryRow("SELECT $1::timestamptz", "-infinity").Scan(&resultT)
|
||||
if err != nil {
|
||||
t.Errorf("Scanning -infinity, expected no error, got %q", err)
|
||||
}
|
||||
if !resultT.Equal(y1500) {
|
||||
t.Errorf("Scanning -infinity, expected time %q, got %q", y1500, resultT.String())
|
||||
}
|
||||
|
||||
ym1500 := time.Date(-1500, time.January, 1, 0, 0, 0, 0, time.UTC)
|
||||
y11500 := time.Date(11500, time.January, 1, 0, 0, 0, 0, time.UTC)
|
||||
var s string
|
||||
err = db.QueryRow("SELECT $1::timestamp::text", ym1500).Scan(&s)
|
||||
if err != nil {
|
||||
t.Errorf("Encoding -infinity, expected no error, got %q", err)
|
||||
}
|
||||
if s != "-infinity" {
|
||||
t.Errorf("Encoding -infinity, expected %q, got %q", "-infinity", s)
|
||||
}
|
||||
err = db.QueryRow("SELECT $1::timestamptz::text", ym1500).Scan(&s)
|
||||
if err != nil {
|
||||
t.Errorf("Encoding -infinity, expected no error, got %q", err)
|
||||
}
|
||||
if s != "-infinity" {
|
||||
t.Errorf("Encoding -infinity, expected %q, got %q", "-infinity", s)
|
||||
}
|
||||
|
||||
err = db.QueryRow("SELECT $1::timestamp::text", y11500).Scan(&s)
|
||||
if err != nil {
|
||||
t.Errorf("Encoding infinity, expected no error, got %q", err)
|
||||
}
|
||||
if s != "infinity" {
|
||||
t.Errorf("Encoding infinity, expected %q, got %q", "infinity", s)
|
||||
}
|
||||
err = db.QueryRow("SELECT $1::timestamptz::text", y11500).Scan(&s)
|
||||
if err != nil {
|
||||
t.Errorf("Encoding infinity, expected no error, got %q", err)
|
||||
}
|
||||
if s != "infinity" {
|
||||
t.Errorf("Encoding infinity, expected %q, got %q", "infinity", s)
|
||||
}
|
||||
|
||||
disableInfinityTs()
|
||||
|
||||
var panicErrorString string
|
||||
func() {
|
||||
defer func() {
|
||||
panicErrorString, _ = recover().(string)
|
||||
}()
|
||||
EnableInfinityTs(y2500, y1500)
|
||||
}()
|
||||
if panicErrorString != infinityTsNegativeMustBeSmaller {
|
||||
t.Errorf("Expected error, %q, got %q", infinityTsNegativeMustBeSmaller, panicErrorString)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringWithNul(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
hello0world := string("hello\x00world")
|
||||
_, err := db.Query("SELECT $1::text", &hello0world)
|
||||
if err == nil {
|
||||
t.Fatal("Postgres accepts a string with nul in it; " +
|
||||
"injection attacks may be plausible")
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteSliceToText(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
b := []byte("hello world")
|
||||
row := db.QueryRow("SELECT $1::text", b)
|
||||
|
||||
var result []byte
|
||||
err := row.Scan(&result)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(result) != string(b) {
|
||||
t.Fatalf("expected %v but got %v", b, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringToBytea(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
b := "hello world"
|
||||
row := db.QueryRow("SELECT $1::bytea", b)
|
||||
|
||||
var result []byte
|
||||
err := row.Scan(&result)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(result, []byte(b)) {
|
||||
t.Fatalf("expected %v but got %v", b, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextByteSliceToUUID(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
b := []byte("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")
|
||||
row := db.QueryRow("SELECT $1::uuid", b)
|
||||
|
||||
var result string
|
||||
err := row.Scan(&result)
|
||||
if forceBinaryParameters() {
|
||||
pqErr := err.(*Error)
|
||||
if pqErr == nil {
|
||||
t.Errorf("Expected to get error")
|
||||
} else if pqErr.Code != "22P03" {
|
||||
t.Fatalf("Expected to get invalid binary encoding error (22P03), got %s", pqErr.Code)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if result != string(b) {
|
||||
t.Fatalf("expected %v but got %v", b, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBinaryByteSlicetoUUID(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
b := []byte{'\xa0', '\xee', '\xbc', '\x99',
|
||||
'\x9c', '\x0b',
|
||||
'\x4e', '\xf8',
|
||||
'\xbb', '\x00', '\x6b',
|
||||
'\xb9', '\xbd', '\x38', '\x0a', '\x11'}
|
||||
row := db.QueryRow("SELECT $1::uuid", b)
|
||||
|
||||
var result string
|
||||
err := row.Scan(&result)
|
||||
if forceBinaryParameters() {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if result != string("a0eebc99-9c0b-4ef8-bb00-6bb9bd380a11") {
|
||||
t.Fatalf("expected %v but got %v", b, result)
|
||||
}
|
||||
} else {
|
||||
pqErr := err.(*Error)
|
||||
if pqErr == nil {
|
||||
t.Errorf("Expected to get error")
|
||||
} else if pqErr.Code != "22021" {
|
||||
t.Fatalf("Expected to get invalid byte sequence for encoding error (22021), got %s", pqErr.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringToUUID(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
s := "a0eebc99-9c0b-4ef8-bb00-6bb9bd380a11"
|
||||
row := db.QueryRow("SELECT $1::uuid", s)
|
||||
|
||||
var result string
|
||||
err := row.Scan(&result)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if result != s {
|
||||
t.Fatalf("expected %v but got %v", s, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextByteSliceToInt(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
expected := 12345678
|
||||
b := []byte(fmt.Sprintf("%d", expected))
|
||||
row := db.QueryRow("SELECT $1::int", b)
|
||||
|
||||
var result int
|
||||
err := row.Scan(&result)
|
||||
if forceBinaryParameters() {
|
||||
pqErr := err.(*Error)
|
||||
if pqErr == nil {
|
||||
t.Errorf("Expected to get error")
|
||||
} else if pqErr.Code != "22P03" {
|
||||
t.Fatalf("Expected to get invalid binary encoding error (22P03), got %s", pqErr.Code)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result != expected {
|
||||
t.Fatalf("expected %v but got %v", expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBinaryByteSliceToInt(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
expected := 12345678
|
||||
b := []byte{'\x00', '\xbc', '\x61', '\x4e'}
|
||||
row := db.QueryRow("SELECT $1::int", b)
|
||||
|
||||
var result int
|
||||
err := row.Scan(&result)
|
||||
if forceBinaryParameters() {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result != expected {
|
||||
t.Fatalf("expected %v but got %v", expected, result)
|
||||
}
|
||||
} else {
|
||||
pqErr := err.(*Error)
|
||||
if pqErr == nil {
|
||||
t.Errorf("Expected to get error")
|
||||
} else if pqErr.Code != "22021" {
|
||||
t.Fatalf("Expected to get invalid byte sequence for encoding error (22021), got %s", pqErr.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextDecodeIntoString(t *testing.T) {
|
||||
input := []byte("hello world")
|
||||
want := string(input)
|
||||
for _, typ := range []oid.Oid{oid.T_char, oid.T_varchar, oid.T_text} {
|
||||
got := decode(¶meterStatus{}, input, typ, formatText)
|
||||
if got != want {
|
||||
t.Errorf("invalid string decoding output for %T(%+v), got %v but expected %v", typ, typ, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteaOutputFormatEncoding(t *testing.T) {
|
||||
input := []byte("\\x\x00\x01\x02\xFF\xFEabcdefg0123")
|
||||
want := []byte("\\x5c78000102fffe6162636465666730313233")
|
||||
got := encode(¶meterStatus{serverVersion: 90000}, input, oid.T_bytea)
|
||||
if !bytes.Equal(want, got) {
|
||||
t.Errorf("invalid hex bytea output, got %v but expected %v", got, want)
|
||||
}
|
||||
|
||||
want = []byte("\\\\x\\000\\001\\002\\377\\376abcdefg0123")
|
||||
got = encode(¶meterStatus{serverVersion: 84000}, input, oid.T_bytea)
|
||||
if !bytes.Equal(want, got) {
|
||||
t.Errorf("invalid escape bytea output, got %v but expected %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteaOutputFormats(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
if getServerVersion(t, db) < 90000 {
|
||||
// skip
|
||||
return
|
||||
}
|
||||
|
||||
testByteaOutputFormat := func(f string, usePrepared bool) {
|
||||
expectedData := []byte("\x5c\x78\x00\xff\x61\x62\x63\x01\x08")
|
||||
sqlQuery := "SELECT decode('5c7800ff6162630108', 'hex')"
|
||||
|
||||
var data []byte
|
||||
|
||||
// use a txn to avoid relying on getting the same connection
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer txn.Rollback()
|
||||
|
||||
_, err = txn.Exec("SET LOCAL bytea_output TO " + f)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var rows *sql.Rows
|
||||
var stmt *sql.Stmt
|
||||
if usePrepared {
|
||||
stmt, err = txn.Prepare(sqlQuery)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rows, err = stmt.Query()
|
||||
} else {
|
||||
// use Query; QueryRow would hide the actual error
|
||||
rows, err = txn.Query(sqlQuery)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !rows.Next() {
|
||||
if rows.Err() != nil {
|
||||
t.Fatal(rows.Err())
|
||||
}
|
||||
t.Fatal("shouldn't happen")
|
||||
}
|
||||
err = rows.Scan(&data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = rows.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if stmt != nil {
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if !bytes.Equal(data, expectedData) {
|
||||
t.Errorf("unexpected bytea value %v for format %s; expected %v", data, f, expectedData)
|
||||
}
|
||||
}
|
||||
|
||||
testByteaOutputFormat("hex", false)
|
||||
testByteaOutputFormat("escape", false)
|
||||
testByteaOutputFormat("hex", true)
|
||||
testByteaOutputFormat("escape", true)
|
||||
}
|
||||
|
||||
func TestAppendEncodedText(t *testing.T) {
|
||||
var buf []byte
|
||||
|
||||
buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, int64(10))
|
||||
buf = append(buf, '\t')
|
||||
buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, 42.0000000001)
|
||||
buf = append(buf, '\t')
|
||||
buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, "hello\tworld")
|
||||
buf = append(buf, '\t')
|
||||
buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, []byte{0, 128, 255})
|
||||
|
||||
if string(buf) != "10\t42.0000000001\thello\\tworld\t\\\\x0080ff" {
|
||||
t.Fatal(string(buf))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendEscapedText(t *testing.T) {
|
||||
if esc := appendEscapedText(nil, "hallo\tescape"); string(esc) != "hallo\\tescape" {
|
||||
t.Fatal(string(esc))
|
||||
}
|
||||
if esc := appendEscapedText(nil, "hallo\\tescape\n"); string(esc) != "hallo\\\\tescape\\n" {
|
||||
t.Fatal(string(esc))
|
||||
}
|
||||
if esc := appendEscapedText(nil, "\n\r\t\f"); string(esc) != "\\n\\r\\t\f" {
|
||||
t.Fatal(string(esc))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendEscapedTextExistingBuffer(t *testing.T) {
|
||||
buf := []byte("123\t")
|
||||
if esc := appendEscapedText(buf, "hallo\tescape"); string(esc) != "123\thallo\\tescape" {
|
||||
t.Fatal(string(esc))
|
||||
}
|
||||
buf = []byte("123\t")
|
||||
if esc := appendEscapedText(buf, "hallo\\tescape\n"); string(esc) != "123\thallo\\\\tescape\\n" {
|
||||
t.Fatal(string(esc))
|
||||
}
|
||||
buf = []byte("123\t")
|
||||
if esc := appendEscapedText(buf, "\n\r\t\f"); string(esc) != "123\t\\n\\r\\t\f" {
|
||||
t.Fatal(string(esc))
|
||||
}
|
||||
}
|
||||
|
||||
var formatAndParseTimestamp = []struct {
|
||||
time time.Time
|
||||
expected string
|
||||
}{
|
||||
{time.Time{}, "0001-01-01 00:00:00Z"},
|
||||
{time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "2001-02-03 04:05:06.123456789Z"},
|
||||
{time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "2001-02-03 04:05:06.123456789+02:00"},
|
||||
{time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "2001-02-03 04:05:06.123456789-06:00"},
|
||||
{time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "2001-02-03 04:05:06-07:30:09"},
|
||||
|
||||
{time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03 04:05:06.123456789Z"},
|
||||
{time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03 04:05:06.123456789+02:00"},
|
||||
{time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03 04:05:06.123456789-06:00"},
|
||||
|
||||
{time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03 04:05:06.123456789Z BC"},
|
||||
{time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03 04:05:06.123456789+02:00 BC"},
|
||||
{time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03 04:05:06.123456789-06:00 BC"},
|
||||
|
||||
{time.Date(1, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03 04:05:06-07:30:09"},
|
||||
{time.Date(0, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03 04:05:06-07:30:09 BC"},
|
||||
}
|
||||
|
||||
func TestFormatAndParseTimestamp(t *testing.T) {
|
||||
for _, val := range formatAndParseTimestamp {
|
||||
formattedTime := FormatTimestamp(val.time)
|
||||
parsedTime, err := ParseTimestamp(nil, string(formattedTime))
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("invalid parsing, err: %v", err.Error())
|
||||
}
|
||||
|
||||
if val.time.UTC() != parsedTime.UTC() {
|
||||
t.Errorf("invalid parsing from formatted timestamp, got %v; expected %v", parsedTime.String(), val.time.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAppendEscapedText(b *testing.B) {
|
||||
longString := ""
|
||||
for i := 0; i < 100; i++ {
|
||||
longString += "123456789\n"
|
||||
}
|
||||
for i := 0; i < b.N; i++ {
|
||||
appendEscapedText(nil, longString)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAppendEscapedTextNoEscape(b *testing.B) {
|
||||
longString := ""
|
||||
for i := 0; i < 100; i++ {
|
||||
longString += "1234567890"
|
||||
}
|
||||
for i := 0; i < b.N; i++ {
|
||||
appendEscapedText(nil, longString)
|
||||
}
|
||||
}
|
||||
523
third_party/highgo-pq/error.go
vendored
Normal file
523
third_party/highgo-pq/error.go
vendored
Normal file
@@ -0,0 +1,523 @@
|
||||
package pq
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// Error severities
|
||||
const (
|
||||
Efatal = "FATAL"
|
||||
Epanic = "PANIC"
|
||||
Ewarning = "WARNING"
|
||||
Enotice = "NOTICE"
|
||||
Edebug = "DEBUG"
|
||||
Einfo = "INFO"
|
||||
Elog = "LOG"
|
||||
)
|
||||
|
||||
// Error represents an error communicating with the server.
|
||||
//
|
||||
// See http://www.postgresql.org/docs/current/static/protocol-error-fields.html for details of the fields
|
||||
type Error struct {
|
||||
Severity string
|
||||
Code ErrorCode
|
||||
Message string
|
||||
Detail string
|
||||
Hint string
|
||||
Position string
|
||||
InternalPosition string
|
||||
InternalQuery string
|
||||
Where string
|
||||
Schema string
|
||||
Table string
|
||||
Column string
|
||||
DataTypeName string
|
||||
Constraint string
|
||||
File string
|
||||
Line string
|
||||
Routine string
|
||||
}
|
||||
|
||||
// ErrorCode is a five-character error code.
|
||||
type ErrorCode string
|
||||
|
||||
// Name returns a more human friendly rendering of the error code, namely the
|
||||
// "condition name".
|
||||
//
|
||||
// See http://www.postgresql.org/docs/9.3/static/errcodes-appendix.html for
|
||||
// details.
|
||||
func (ec ErrorCode) Name() string {
|
||||
return errorCodeNames[ec]
|
||||
}
|
||||
|
||||
// ErrorClass is only the class part of an error code.
|
||||
type ErrorClass string
|
||||
|
||||
// Name returns the condition name of an error class. It is equivalent to the
|
||||
// condition name of the "standard" error code (i.e. the one having the last
|
||||
// three characters "000").
|
||||
func (ec ErrorClass) Name() string {
|
||||
return errorCodeNames[ErrorCode(ec+"000")]
|
||||
}
|
||||
|
||||
// Class returns the error class, e.g. "28".
|
||||
//
|
||||
// See http://www.postgresql.org/docs/9.3/static/errcodes-appendix.html for
|
||||
// details.
|
||||
func (ec ErrorCode) Class() ErrorClass {
|
||||
return ErrorClass(ec[0:2])
|
||||
}
|
||||
|
||||
// errorCodeNames is a mapping between the five-character error codes and the
|
||||
// human readable "condition names". It is derived from the list at
|
||||
// http://www.postgresql.org/docs/9.3/static/errcodes-appendix.html
|
||||
var errorCodeNames = map[ErrorCode]string{
|
||||
// Class 00 - Successful Completion
|
||||
"00000": "successful_completion",
|
||||
// Class 01 - Warning
|
||||
"01000": "warning",
|
||||
"0100C": "dynamic_result_sets_returned",
|
||||
"01008": "implicit_zero_bit_padding",
|
||||
"01003": "null_value_eliminated_in_set_function",
|
||||
"01007": "privilege_not_granted",
|
||||
"01006": "privilege_not_revoked",
|
||||
"01004": "string_data_right_truncation",
|
||||
"01P01": "deprecated_feature",
|
||||
// Class 02 - No Data (this is also a warning class per the SQL standard)
|
||||
"02000": "no_data",
|
||||
"02001": "no_additional_dynamic_result_sets_returned",
|
||||
// Class 03 - SQL Statement Not Yet Complete
|
||||
"03000": "sql_statement_not_yet_complete",
|
||||
// Class 08 - Connection Exception
|
||||
"08000": "connection_exception",
|
||||
"08003": "connection_does_not_exist",
|
||||
"08006": "connection_failure",
|
||||
"08001": "sqlclient_unable_to_establish_sqlconnection",
|
||||
"08004": "sqlserver_rejected_establishment_of_sqlconnection",
|
||||
"08007": "transaction_resolution_unknown",
|
||||
"08P01": "protocol_violation",
|
||||
// Class 09 - Triggered Action Exception
|
||||
"09000": "triggered_action_exception",
|
||||
// Class 0A - Feature Not Supported
|
||||
"0A000": "feature_not_supported",
|
||||
// Class 0B - Invalid Transaction Initiation
|
||||
"0B000": "invalid_transaction_initiation",
|
||||
// Class 0F - Locator Exception
|
||||
"0F000": "locator_exception",
|
||||
"0F001": "invalid_locator_specification",
|
||||
// Class 0L - Invalid Grantor
|
||||
"0L000": "invalid_grantor",
|
||||
"0LP01": "invalid_grant_operation",
|
||||
// Class 0P - Invalid Role Specification
|
||||
"0P000": "invalid_role_specification",
|
||||
// Class 0Z - Diagnostics Exception
|
||||
"0Z000": "diagnostics_exception",
|
||||
"0Z002": "stacked_diagnostics_accessed_without_active_handler",
|
||||
// Class 20 - Case Not Found
|
||||
"20000": "case_not_found",
|
||||
// Class 21 - Cardinality Violation
|
||||
"21000": "cardinality_violation",
|
||||
// Class 22 - Data Exception
|
||||
"22000": "data_exception",
|
||||
"2202E": "array_subscript_error",
|
||||
"22021": "character_not_in_repertoire",
|
||||
"22008": "datetime_field_overflow",
|
||||
"22012": "division_by_zero",
|
||||
"22005": "error_in_assignment",
|
||||
"2200B": "escape_character_conflict",
|
||||
"22022": "indicator_overflow",
|
||||
"22015": "interval_field_overflow",
|
||||
"2201E": "invalid_argument_for_logarithm",
|
||||
"22014": "invalid_argument_for_ntile_function",
|
||||
"22016": "invalid_argument_for_nth_value_function",
|
||||
"2201F": "invalid_argument_for_power_function",
|
||||
"2201G": "invalid_argument_for_width_bucket_function",
|
||||
"22018": "invalid_character_value_for_cast",
|
||||
"22007": "invalid_datetime_format",
|
||||
"22019": "invalid_escape_character",
|
||||
"2200D": "invalid_escape_octet",
|
||||
"22025": "invalid_escape_sequence",
|
||||
"22P06": "nonstandard_use_of_escape_character",
|
||||
"22010": "invalid_indicator_parameter_value",
|
||||
"22023": "invalid_parameter_value",
|
||||
"2201B": "invalid_regular_expression",
|
||||
"2201W": "invalid_row_count_in_limit_clause",
|
||||
"2201X": "invalid_row_count_in_result_offset_clause",
|
||||
"22009": "invalid_time_zone_displacement_value",
|
||||
"2200C": "invalid_use_of_escape_character",
|
||||
"2200G": "most_specific_type_mismatch",
|
||||
"22004": "null_value_not_allowed",
|
||||
"22002": "null_value_no_indicator_parameter",
|
||||
"22003": "numeric_value_out_of_range",
|
||||
"2200H": "sequence_generator_limit_exceeded",
|
||||
"22026": "string_data_length_mismatch",
|
||||
"22001": "string_data_right_truncation",
|
||||
"22011": "substring_error",
|
||||
"22027": "trim_error",
|
||||
"22024": "unterminated_c_string",
|
||||
"2200F": "zero_length_character_string",
|
||||
"22P01": "floating_point_exception",
|
||||
"22P02": "invalid_text_representation",
|
||||
"22P03": "invalid_binary_representation",
|
||||
"22P04": "bad_copy_file_format",
|
||||
"22P05": "untranslatable_character",
|
||||
"2200L": "not_an_xml_document",
|
||||
"2200M": "invalid_xml_document",
|
||||
"2200N": "invalid_xml_content",
|
||||
"2200S": "invalid_xml_comment",
|
||||
"2200T": "invalid_xml_processing_instruction",
|
||||
// Class 23 - Integrity Constraint Violation
|
||||
"23000": "integrity_constraint_violation",
|
||||
"23001": "restrict_violation",
|
||||
"23502": "not_null_violation",
|
||||
"23503": "foreign_key_violation",
|
||||
"23505": "unique_violation",
|
||||
"23514": "check_violation",
|
||||
"23P01": "exclusion_violation",
|
||||
// Class 24 - Invalid Cursor State
|
||||
"24000": "invalid_cursor_state",
|
||||
// Class 25 - Invalid Transaction State
|
||||
"25000": "invalid_transaction_state",
|
||||
"25001": "active_sql_transaction",
|
||||
"25002": "branch_transaction_already_active",
|
||||
"25008": "held_cursor_requires_same_isolation_level",
|
||||
"25003": "inappropriate_access_mode_for_branch_transaction",
|
||||
"25004": "inappropriate_isolation_level_for_branch_transaction",
|
||||
"25005": "no_active_sql_transaction_for_branch_transaction",
|
||||
"25006": "read_only_sql_transaction",
|
||||
"25007": "schema_and_data_statement_mixing_not_supported",
|
||||
"25P01": "no_active_sql_transaction",
|
||||
"25P02": "in_failed_sql_transaction",
|
||||
// Class 26 - Invalid SQL Statement Name
|
||||
"26000": "invalid_sql_statement_name",
|
||||
// Class 27 - Triggered Data Change Violation
|
||||
"27000": "triggered_data_change_violation",
|
||||
// Class 28 - Invalid Authorization Specification
|
||||
"28000": "invalid_authorization_specification",
|
||||
"28P01": "invalid_password",
|
||||
// Class 2B - Dependent Privilege Descriptors Still Exist
|
||||
"2B000": "dependent_privilege_descriptors_still_exist",
|
||||
"2BP01": "dependent_objects_still_exist",
|
||||
// Class 2D - Invalid Transaction Termination
|
||||
"2D000": "invalid_transaction_termination",
|
||||
// Class 2F - SQL Routine Exception
|
||||
"2F000": "sql_routine_exception",
|
||||
"2F005": "function_executed_no_return_statement",
|
||||
"2F002": "modifying_sql_data_not_permitted",
|
||||
"2F003": "prohibited_sql_statement_attempted",
|
||||
"2F004": "reading_sql_data_not_permitted",
|
||||
// Class 34 - Invalid Cursor Name
|
||||
"34000": "invalid_cursor_name",
|
||||
// Class 38 - External Routine Exception
|
||||
"38000": "external_routine_exception",
|
||||
"38001": "containing_sql_not_permitted",
|
||||
"38002": "modifying_sql_data_not_permitted",
|
||||
"38003": "prohibited_sql_statement_attempted",
|
||||
"38004": "reading_sql_data_not_permitted",
|
||||
// Class 39 - External Routine Invocation Exception
|
||||
"39000": "external_routine_invocation_exception",
|
||||
"39001": "invalid_sqlstate_returned",
|
||||
"39004": "null_value_not_allowed",
|
||||
"39P01": "trigger_protocol_violated",
|
||||
"39P02": "srf_protocol_violated",
|
||||
// Class 3B - Savepoint Exception
|
||||
"3B000": "savepoint_exception",
|
||||
"3B001": "invalid_savepoint_specification",
|
||||
// Class 3D - Invalid Catalog Name
|
||||
"3D000": "invalid_catalog_name",
|
||||
// Class 3F - Invalid Schema Name
|
||||
"3F000": "invalid_schema_name",
|
||||
// Class 40 - Transaction Rollback
|
||||
"40000": "transaction_rollback",
|
||||
"40002": "transaction_integrity_constraint_violation",
|
||||
"40001": "serialization_failure",
|
||||
"40003": "statement_completion_unknown",
|
||||
"40P01": "deadlock_detected",
|
||||
// Class 42 - Syntax Error or Access Rule Violation
|
||||
"42000": "syntax_error_or_access_rule_violation",
|
||||
"42601": "syntax_error",
|
||||
"42501": "insufficient_privilege",
|
||||
"42846": "cannot_coerce",
|
||||
"42803": "grouping_error",
|
||||
"42P20": "windowing_error",
|
||||
"42P19": "invalid_recursion",
|
||||
"42830": "invalid_foreign_key",
|
||||
"42602": "invalid_name",
|
||||
"42622": "name_too_long",
|
||||
"42939": "reserved_name",
|
||||
"42804": "datatype_mismatch",
|
||||
"42P18": "indeterminate_datatype",
|
||||
"42P21": "collation_mismatch",
|
||||
"42P22": "indeterminate_collation",
|
||||
"42809": "wrong_object_type",
|
||||
"42703": "undefined_column",
|
||||
"42883": "undefined_function",
|
||||
"42P01": "undefined_table",
|
||||
"42P02": "undefined_parameter",
|
||||
"42704": "undefined_object",
|
||||
"42701": "duplicate_column",
|
||||
"42P03": "duplicate_cursor",
|
||||
"42P04": "duplicate_database",
|
||||
"42723": "duplicate_function",
|
||||
"42P05": "duplicate_prepared_statement",
|
||||
"42P06": "duplicate_schema",
|
||||
"42P07": "duplicate_table",
|
||||
"42712": "duplicate_alias",
|
||||
"42710": "duplicate_object",
|
||||
"42702": "ambiguous_column",
|
||||
"42725": "ambiguous_function",
|
||||
"42P08": "ambiguous_parameter",
|
||||
"42P09": "ambiguous_alias",
|
||||
"42P10": "invalid_column_reference",
|
||||
"42611": "invalid_column_definition",
|
||||
"42P11": "invalid_cursor_definition",
|
||||
"42P12": "invalid_database_definition",
|
||||
"42P13": "invalid_function_definition",
|
||||
"42P14": "invalid_prepared_statement_definition",
|
||||
"42P15": "invalid_schema_definition",
|
||||
"42P16": "invalid_table_definition",
|
||||
"42P17": "invalid_object_definition",
|
||||
// Class 44 - WITH CHECK OPTION Violation
|
||||
"44000": "with_check_option_violation",
|
||||
// Class 53 - Insufficient Resources
|
||||
"53000": "insufficient_resources",
|
||||
"53100": "disk_full",
|
||||
"53200": "out_of_memory",
|
||||
"53300": "too_many_connections",
|
||||
"53400": "configuration_limit_exceeded",
|
||||
// Class 54 - Program Limit Exceeded
|
||||
"54000": "program_limit_exceeded",
|
||||
"54001": "statement_too_complex",
|
||||
"54011": "too_many_columns",
|
||||
"54023": "too_many_arguments",
|
||||
// Class 55 - Object Not In Prerequisite State
|
||||
"55000": "object_not_in_prerequisite_state",
|
||||
"55006": "object_in_use",
|
||||
"55P02": "cant_change_runtime_param",
|
||||
"55P03": "lock_not_available",
|
||||
// Class 57 - Operator Intervention
|
||||
"57000": "operator_intervention",
|
||||
"57014": "query_canceled",
|
||||
"57P01": "admin_shutdown",
|
||||
"57P02": "crash_shutdown",
|
||||
"57P03": "cannot_connect_now",
|
||||
"57P04": "database_dropped",
|
||||
// Class 58 - System Error (errors external to PostgreSQL itself)
|
||||
"58000": "system_error",
|
||||
"58030": "io_error",
|
||||
"58P01": "undefined_file",
|
||||
"58P02": "duplicate_file",
|
||||
// Class F0 - Configuration File Error
|
||||
"F0000": "config_file_error",
|
||||
"F0001": "lock_file_exists",
|
||||
// Class HV - Foreign Data Wrapper Error (SQL/MED)
|
||||
"HV000": "fdw_error",
|
||||
"HV005": "fdw_column_name_not_found",
|
||||
"HV002": "fdw_dynamic_parameter_value_needed",
|
||||
"HV010": "fdw_function_sequence_error",
|
||||
"HV021": "fdw_inconsistent_descriptor_information",
|
||||
"HV024": "fdw_invalid_attribute_value",
|
||||
"HV007": "fdw_invalid_column_name",
|
||||
"HV008": "fdw_invalid_column_number",
|
||||
"HV004": "fdw_invalid_data_type",
|
||||
"HV006": "fdw_invalid_data_type_descriptors",
|
||||
"HV091": "fdw_invalid_descriptor_field_identifier",
|
||||
"HV00B": "fdw_invalid_handle",
|
||||
"HV00C": "fdw_invalid_option_index",
|
||||
"HV00D": "fdw_invalid_option_name",
|
||||
"HV090": "fdw_invalid_string_length_or_buffer_length",
|
||||
"HV00A": "fdw_invalid_string_format",
|
||||
"HV009": "fdw_invalid_use_of_null_pointer",
|
||||
"HV014": "fdw_too_many_handles",
|
||||
"HV001": "fdw_out_of_memory",
|
||||
"HV00P": "fdw_no_schemas",
|
||||
"HV00J": "fdw_option_name_not_found",
|
||||
"HV00K": "fdw_reply_handle",
|
||||
"HV00Q": "fdw_schema_not_found",
|
||||
"HV00R": "fdw_table_not_found",
|
||||
"HV00L": "fdw_unable_to_create_execution",
|
||||
"HV00M": "fdw_unable_to_create_reply",
|
||||
"HV00N": "fdw_unable_to_establish_connection",
|
||||
// Class P0 - PL/pgSQL Error
|
||||
"P0000": "plpgsql_error",
|
||||
"P0001": "raise_exception",
|
||||
"P0002": "no_data_found",
|
||||
"P0003": "too_many_rows",
|
||||
// Class XX - Internal Error
|
||||
"XX000": "internal_error",
|
||||
"XX001": "data_corrupted",
|
||||
"XX002": "index_corrupted",
|
||||
}
|
||||
|
||||
func parseError(r *readBuf) *Error {
|
||||
err := new(Error)
|
||||
for t := r.byte(); t != 0; t = r.byte() {
|
||||
msg := r.string()
|
||||
switch t {
|
||||
case 'S':
|
||||
err.Severity = msg
|
||||
case 'C':
|
||||
err.Code = ErrorCode(msg)
|
||||
case 'M':
|
||||
err.Message = msg
|
||||
case 'D':
|
||||
err.Detail = msg
|
||||
case 'H':
|
||||
err.Hint = msg
|
||||
case 'P':
|
||||
err.Position = msg
|
||||
case 'p':
|
||||
err.InternalPosition = msg
|
||||
case 'q':
|
||||
err.InternalQuery = msg
|
||||
case 'W':
|
||||
err.Where = msg
|
||||
case 's':
|
||||
err.Schema = msg
|
||||
case 't':
|
||||
err.Table = msg
|
||||
case 'c':
|
||||
err.Column = msg
|
||||
case 'd':
|
||||
err.DataTypeName = msg
|
||||
case 'n':
|
||||
err.Constraint = msg
|
||||
case 'F':
|
||||
err.File = msg
|
||||
case 'L':
|
||||
err.Line = msg
|
||||
case 'R':
|
||||
err.Routine = msg
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Fatal returns true if the Error Severity is fatal.
|
||||
func (err *Error) Fatal() bool {
|
||||
return err.Severity == Efatal
|
||||
}
|
||||
|
||||
// SQLState returns the SQLState of the error.
|
||||
func (err *Error) SQLState() string {
|
||||
return string(err.Code)
|
||||
}
|
||||
|
||||
// Get implements the legacy PGError interface. New code should use the fields
|
||||
// of the Error struct directly.
|
||||
func (err *Error) Get(k byte) (v string) {
|
||||
switch k {
|
||||
case 'S':
|
||||
return err.Severity
|
||||
case 'C':
|
||||
return string(err.Code)
|
||||
case 'M':
|
||||
return err.Message
|
||||
case 'D':
|
||||
return err.Detail
|
||||
case 'H':
|
||||
return err.Hint
|
||||
case 'P':
|
||||
return err.Position
|
||||
case 'p':
|
||||
return err.InternalPosition
|
||||
case 'q':
|
||||
return err.InternalQuery
|
||||
case 'W':
|
||||
return err.Where
|
||||
case 's':
|
||||
return err.Schema
|
||||
case 't':
|
||||
return err.Table
|
||||
case 'c':
|
||||
return err.Column
|
||||
case 'd':
|
||||
return err.DataTypeName
|
||||
case 'n':
|
||||
return err.Constraint
|
||||
case 'F':
|
||||
return err.File
|
||||
case 'L':
|
||||
return err.Line
|
||||
case 'R':
|
||||
return err.Routine
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (err *Error) Error() string {
|
||||
return "pq: " + err.Message
|
||||
}
|
||||
|
||||
// PGError is an interface used by previous versions of pq. It is provided
|
||||
// only to support legacy code. New code should use the Error type.
|
||||
type PGError interface {
|
||||
Error() string
|
||||
Fatal() bool
|
||||
Get(k byte) (v string)
|
||||
}
|
||||
|
||||
func errorf(s string, args ...interface{}) {
|
||||
panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
|
||||
}
|
||||
|
||||
// TODO(ainar-g) Rename to errorf after removing panics.
|
||||
func fmterrorf(s string, args ...interface{}) error {
|
||||
return fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))
|
||||
}
|
||||
|
||||
func errRecoverNoErrBadConn(err *error) {
|
||||
e := recover()
|
||||
if e == nil {
|
||||
// Do nothing
|
||||
return
|
||||
}
|
||||
var ok bool
|
||||
*err, ok = e.(error)
|
||||
if !ok {
|
||||
*err = fmt.Errorf("pq: unexpected error: %#v", e)
|
||||
}
|
||||
}
|
||||
|
||||
func (cn *conn) errRecover(err *error) {
|
||||
e := recover()
|
||||
switch v := e.(type) {
|
||||
case nil:
|
||||
// Do nothing
|
||||
case runtime.Error:
|
||||
cn.err.set(driver.ErrBadConn)
|
||||
panic(v)
|
||||
case *Error:
|
||||
if v.Fatal() {
|
||||
*err = driver.ErrBadConn
|
||||
} else {
|
||||
*err = v
|
||||
}
|
||||
case *net.OpError:
|
||||
cn.err.set(driver.ErrBadConn)
|
||||
*err = v
|
||||
case *safeRetryError:
|
||||
cn.err.set(driver.ErrBadConn)
|
||||
*err = driver.ErrBadConn
|
||||
case error:
|
||||
if v == io.EOF || v.Error() == "remote error: handshake failure" {
|
||||
*err = driver.ErrBadConn
|
||||
} else {
|
||||
*err = v
|
||||
}
|
||||
|
||||
default:
|
||||
cn.err.set(driver.ErrBadConn)
|
||||
panic(fmt.Sprintf("unknown error: %#v", e))
|
||||
}
|
||||
|
||||
// Any time we return ErrBadConn, we need to remember it since *Tx doesn't
|
||||
// mark the connection bad in database/sql.
|
||||
if *err == driver.ErrBadConn {
|
||||
cn.err.set(driver.ErrBadConn)
|
||||
}
|
||||
}
|
||||
98
third_party/highgo-pq/example/listen/doc.go
vendored
Normal file
98
third_party/highgo-pq/example/listen/doc.go
vendored
Normal file
@@ -0,0 +1,98 @@
|
||||
/*
|
||||
|
||||
Package listen is a self-contained Go program which uses the LISTEN / NOTIFY
|
||||
mechanism to avoid polling the database while waiting for more work to arrive.
|
||||
|
||||
//
|
||||
// You can see the program in action by defining a function similar to
|
||||
// the following:
|
||||
//
|
||||
// CREATE OR REPLACE FUNCTION public.get_work()
|
||||
// RETURNS bigint
|
||||
// LANGUAGE sql
|
||||
// AS $$
|
||||
// SELECT CASE WHEN random() >= 0.2 THEN int8 '1' END
|
||||
// $$
|
||||
// ;
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
func doWork(db *sql.DB, work int64) {
|
||||
// work here
|
||||
}
|
||||
|
||||
func getWork(db *sql.DB) {
|
||||
for {
|
||||
// get work from the database here
|
||||
var work sql.NullInt64
|
||||
err := db.QueryRow("SELECT get_work()").Scan(&work)
|
||||
if err != nil {
|
||||
fmt.Println("call to get_work() failed: ", err)
|
||||
time.Sleep(10 * time.Second)
|
||||
continue
|
||||
}
|
||||
if !work.Valid {
|
||||
// no more work to do
|
||||
fmt.Println("ran out of work")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("starting work on ", work.Int64)
|
||||
go doWork(db, work.Int64)
|
||||
}
|
||||
}
|
||||
|
||||
func waitForNotification(l *pq.Listener) {
|
||||
select {
|
||||
case <-l.Notify:
|
||||
fmt.Println("received notification, new work available")
|
||||
case <-time.After(90 * time.Second):
|
||||
go l.Ping()
|
||||
// Check if there's more work available, just in case it takes
|
||||
// a while for the Listener to notice connection loss and
|
||||
// reconnect.
|
||||
fmt.Println("received no work for 90 seconds, checking for new work")
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
var conninfo string = ""
|
||||
|
||||
db, err := sql.Open("postgres", conninfo)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
reportProblem := func(ev pq.ListenerEventType, err error) {
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
minReconn := 10 * time.Second
|
||||
maxReconn := time.Minute
|
||||
listener := pq.NewListener(conninfo, minReconn, maxReconn, reportProblem)
|
||||
err = listener.Listen("getwork")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Println("entering main loop")
|
||||
for {
|
||||
// process all available work before waiting for notifications
|
||||
getWork(db)
|
||||
waitForNotification(listener)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
*/
|
||||
package listen
|
||||
3
third_party/highgo-pq/go.mod
vendored
Normal file
3
third_party/highgo-pq/go.mod
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
module github.com/highgo/pq-sm3
|
||||
|
||||
go 1.13
|
||||
352
third_party/highgo-pq/go18_test.go
vendored
Normal file
352
third_party/highgo-pq/go18_test.go
vendored
Normal file
@@ -0,0 +1,352 @@
|
||||
package pq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMultipleSimpleQuery(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
rows, err := db.Query("select 1; set time zone default; select 2; select 3")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var i int
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&i); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if i != 1 {
|
||||
t.Fatalf("expected 1, got %d", i)
|
||||
}
|
||||
}
|
||||
if !rows.NextResultSet() {
|
||||
t.Fatal("expected more result sets", rows.Err())
|
||||
}
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&i); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if i != 2 {
|
||||
t.Fatalf("expected 2, got %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure that if we ignore a result we can still query.
|
||||
|
||||
rows, err = db.Query("select 4; select 5")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&i); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if i != 4 {
|
||||
t.Fatalf("expected 4, got %d", i)
|
||||
}
|
||||
}
|
||||
if !rows.NextResultSet() {
|
||||
t.Fatal("expected more result sets", rows.Err())
|
||||
}
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&i); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if i != 5 {
|
||||
t.Fatalf("expected 5, got %d", i)
|
||||
}
|
||||
}
|
||||
if rows.NextResultSet() {
|
||||
t.Fatal("unexpected result set")
|
||||
}
|
||||
}
|
||||
|
||||
const contextRaceIterations = 100
|
||||
|
||||
const cancelErrorCode ErrorCode = "57014"
|
||||
|
||||
func TestContextCancelExec(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Delay execution for just a bit until db.ExecContext has begun.
|
||||
defer time.AfterFunc(time.Millisecond*10, cancel).Stop()
|
||||
|
||||
// Not canceled until after the exec has started.
|
||||
if _, err := db.ExecContext(ctx, "select pg_sleep(1)"); err == nil {
|
||||
t.Fatal("expected error")
|
||||
} else if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
// Context is already canceled, so error should come before execution.
|
||||
if _, err := db.ExecContext(ctx, "select pg_sleep(1)"); err == nil {
|
||||
t.Fatal("expected error")
|
||||
} else if err.Error() != "context canceled" {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
for i := 0; i < contextRaceIterations; i++ {
|
||||
func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
if _, err := db.ExecContext(ctx, "select 1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := db.Exec("select 1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextCancelQuery(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Delay execution for just a bit until db.QueryContext has begun.
|
||||
defer time.AfterFunc(time.Millisecond*10, cancel).Stop()
|
||||
|
||||
// Not canceled until after the exec has started.
|
||||
if _, err := db.QueryContext(ctx, "select pg_sleep(1)"); err == nil {
|
||||
t.Fatal("expected error")
|
||||
} else if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
// Context is already canceled, so error should come before execution.
|
||||
if _, err := db.QueryContext(ctx, "select pg_sleep(1)"); err == nil {
|
||||
t.Fatal("expected error")
|
||||
} else if err.Error() != "context canceled" {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
for i := 0; i < contextRaceIterations; i++ {
|
||||
func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
rows, err := db.QueryContext(ctx, "select 1")
|
||||
cancel()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
} else if err := rows.Close(); err != nil && err != driver.ErrBadConn && err != context.Canceled {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
if rows, err := db.Query("select 1"); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if err := rows.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue617 tests that a failed query in QueryContext doesn't lead to a
|
||||
// goroutine leak.
|
||||
func TestIssue617(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
const N = 10
|
||||
|
||||
numGoroutineStart := runtime.NumGoroutine()
|
||||
for i := 0; i < N; i++ {
|
||||
func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
_, err := db.QueryContext(ctx, `SELECT * FROM DOESNOTEXIST`)
|
||||
pqErr, _ := err.(*Error)
|
||||
// Expecting "pq: relation \"doesnotexist\" does not exist" error.
|
||||
if err == nil || pqErr == nil || pqErr.Code != "42P01" {
|
||||
t.Fatalf("expected undefined table error, got %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Give time for goroutines to terminate
|
||||
delayTime := time.Millisecond * 50
|
||||
waitTime := time.Second
|
||||
iterations := int(waitTime / delayTime)
|
||||
|
||||
var numGoroutineFinish int
|
||||
for i := 0; i < iterations; i++ {
|
||||
time.Sleep(delayTime)
|
||||
|
||||
numGoroutineFinish = runtime.NumGoroutine()
|
||||
|
||||
// We use N/2 and not N because the GC and other actors may increase or
|
||||
// decrease the number of goroutines.
|
||||
if numGoroutineFinish-numGoroutineStart < N/2 {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
t.Errorf("goroutine leak detected, was %d, now %d", numGoroutineStart, numGoroutineFinish)
|
||||
}
|
||||
|
||||
func TestContextCancelBegin(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Delay execution for just a bit until tx.Exec has begun.
|
||||
defer time.AfterFunc(time.Millisecond*10, cancel).Stop()
|
||||
|
||||
// Not canceled until after the exec has started.
|
||||
if _, err := tx.Exec("select pg_sleep(1)"); err == nil {
|
||||
t.Fatal("expected error")
|
||||
} else if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
// Transaction is canceled, so expect an error.
|
||||
if _, err := tx.Query("select pg_sleep(1)"); err == nil {
|
||||
t.Fatal("expected error")
|
||||
} else if err != sql.ErrTxDone {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
// Context is canceled, so cannot begin a transaction.
|
||||
if _, err := db.BeginTx(ctx, nil); err == nil {
|
||||
t.Fatal("expected error")
|
||||
} else if err.Error() != "context canceled" {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
for i := 0; i < contextRaceIterations; i++ {
|
||||
func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
cancel()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
} else if err, pgErr := tx.Rollback(), (*Error)(nil); err != nil &&
|
||||
!(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) &&
|
||||
err != sql.ErrTxDone && err != driver.ErrBadConn && err != context.Canceled {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
if tx, err := db.Begin(); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if err := tx.Rollback(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTxOptions(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
level sql.IsolationLevel
|
||||
isolation string
|
||||
}{
|
||||
{
|
||||
level: sql.LevelDefault,
|
||||
isolation: "",
|
||||
},
|
||||
{
|
||||
level: sql.LevelReadUncommitted,
|
||||
isolation: "read uncommitted",
|
||||
},
|
||||
{
|
||||
level: sql.LevelReadCommitted,
|
||||
isolation: "read committed",
|
||||
},
|
||||
{
|
||||
level: sql.LevelRepeatableRead,
|
||||
isolation: "repeatable read",
|
||||
},
|
||||
{
|
||||
level: sql.LevelSerializable,
|
||||
isolation: "serializable",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
for _, ro := range []bool{true, false} {
|
||||
tx, err := db.BeginTx(ctx, &sql.TxOptions{
|
||||
Isolation: test.level,
|
||||
ReadOnly: ro,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var isolation string
|
||||
err = tx.QueryRow("select current_setting('transaction_isolation')").Scan(&isolation)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if test.isolation != "" && isolation != test.isolation {
|
||||
t.Errorf("wrong isolation level: %s != %s", isolation, test.isolation)
|
||||
}
|
||||
|
||||
var isRO string
|
||||
err = tx.QueryRow("select current_setting('transaction_read_only')").Scan(&isRO)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if ro != (isRO == "on") {
|
||||
t.Errorf("read/[write,only] not set: %t != %s for level %s",
|
||||
ro, isRO, test.isolation)
|
||||
}
|
||||
|
||||
tx.Rollback()
|
||||
}
|
||||
}
|
||||
|
||||
_, err := db.BeginTx(ctx, &sql.TxOptions{
|
||||
Isolation: sql.LevelLinearizable,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected LevelLinearizable to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "isolation level not supported") {
|
||||
t.Errorf("Expected error to mention isolation level, got %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorSQLState(t *testing.T) {
|
||||
r := readBuf([]byte{67, 52, 48, 48, 48, 49, 0, 0}) // 40001
|
||||
err := parseError(&r)
|
||||
var sqlErr errWithSQLState
|
||||
if !errors.As(err, &sqlErr) {
|
||||
t.Fatal("SQLState interface not satisfied")
|
||||
}
|
||||
if state := err.SQLState(); state != "40001" {
|
||||
t.Fatalf("unexpected SQL state %v", state)
|
||||
}
|
||||
}
|
||||
|
||||
type errWithSQLState interface {
|
||||
SQLState() string
|
||||
}
|
||||
99
third_party/highgo-pq/go19_test.go
vendored
Normal file
99
third_party/highgo-pq/go19_test.go
vendored
Normal file
@@ -0,0 +1,99 @@
|
||||
//go:build go1.9
|
||||
// +build go1.9
|
||||
|
||||
package pq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPing(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
if _, ok := reflect.TypeOf(db).MethodByName("Conn"); !ok {
|
||||
t.Skipf("Conn method undefined on type %T, skipping test (requires at least go1.9)", db)
|
||||
}
|
||||
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
t.Fatal("expected Ping to succeed")
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
// grab a connection
|
||||
conn, err := db.Conn(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// start a transaction and read backend pid of our connection
|
||||
tx, err := conn.BeginTx(ctx, &sql.TxOptions{
|
||||
Isolation: sql.LevelDefault,
|
||||
ReadOnly: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rows, err := tx.Query("SELECT pg_backend_pid()")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// read the pid from result
|
||||
var pid int
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&pid); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if rows.Err() != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Fail the transaction and make sure we can still ping.
|
||||
if _, err := tx.Query("INVALID SQL"); err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if err := conn.PingContext(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := tx.Rollback(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// kill the process which handles our connection and test if the ping fails
|
||||
if _, err := db.Exec("SELECT pg_terminate_backend($1)", pid); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := conn.PingContext(ctx); err != driver.ErrBadConn {
|
||||
t.Fatalf("expected error %s, instead got %s", driver.ErrBadConn, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommitInFailedTransactionWithCancelContext(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
txn, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rows, err := txn.Query("SELECT error")
|
||||
if err == nil {
|
||||
rows.Close()
|
||||
t.Fatal("expected failure")
|
||||
}
|
||||
err = txn.Commit()
|
||||
if err != ErrInFailedTransaction {
|
||||
t.Fatalf("expected ErrInFailedTransaction; got %#v", err)
|
||||
}
|
||||
}
|
||||
118
third_party/highgo-pq/hstore/hstore.go
vendored
Normal file
118
third_party/highgo-pq/hstore/hstore.go
vendored
Normal file
@@ -0,0 +1,118 @@
|
||||
package hstore
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Hstore is a wrapper for transferring Hstore values back and forth easily.
|
||||
type Hstore struct {
|
||||
Map map[string]sql.NullString
|
||||
}
|
||||
|
||||
// escapes and quotes hstore keys/values
|
||||
// s should be a sql.NullString or string
|
||||
func hQuote(s interface{}) string {
|
||||
var str string
|
||||
switch v := s.(type) {
|
||||
case sql.NullString:
|
||||
if !v.Valid {
|
||||
return "NULL"
|
||||
}
|
||||
str = v.String
|
||||
case string:
|
||||
str = v
|
||||
default:
|
||||
panic("not a string or sql.NullString")
|
||||
}
|
||||
|
||||
str = strings.Replace(str, "\\", "\\\\", -1)
|
||||
return `"` + strings.Replace(str, "\"", "\\\"", -1) + `"`
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
//
|
||||
// Note h.Map is reallocated before the scan to clear existing values. If the
|
||||
// hstore column's database value is NULL, then h.Map is set to nil instead.
|
||||
func (h *Hstore) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
h.Map = nil
|
||||
return nil
|
||||
}
|
||||
h.Map = make(map[string]sql.NullString)
|
||||
var b byte
|
||||
pair := [][]byte{{}, {}}
|
||||
pi := 0
|
||||
inQuote := false
|
||||
didQuote := false
|
||||
sawSlash := false
|
||||
bindex := 0
|
||||
for bindex, b = range value.([]byte) {
|
||||
if sawSlash {
|
||||
pair[pi] = append(pair[pi], b)
|
||||
sawSlash = false
|
||||
continue
|
||||
}
|
||||
|
||||
switch b {
|
||||
case '\\':
|
||||
sawSlash = true
|
||||
continue
|
||||
case '"':
|
||||
inQuote = !inQuote
|
||||
if !didQuote {
|
||||
didQuote = true
|
||||
}
|
||||
continue
|
||||
default:
|
||||
if !inQuote {
|
||||
switch b {
|
||||
case ' ', '\t', '\n', '\r':
|
||||
continue
|
||||
case '=':
|
||||
continue
|
||||
case '>':
|
||||
pi = 1
|
||||
didQuote = false
|
||||
continue
|
||||
case ',':
|
||||
s := string(pair[1])
|
||||
if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" {
|
||||
h.Map[string(pair[0])] = sql.NullString{String: "", Valid: false}
|
||||
} else {
|
||||
h.Map[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true}
|
||||
}
|
||||
pair[0] = []byte{}
|
||||
pair[1] = []byte{}
|
||||
pi = 0
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
pair[pi] = append(pair[pi], b)
|
||||
}
|
||||
if bindex > 0 {
|
||||
s := string(pair[1])
|
||||
if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" {
|
||||
h.Map[string(pair[0])] = sql.NullString{String: "", Valid: false}
|
||||
} else {
|
||||
h.Map[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface. Note if h.Map is nil, the
|
||||
// database column value will be set to NULL.
|
||||
func (h Hstore) Value() (driver.Value, error) {
|
||||
if h.Map == nil {
|
||||
return nil, nil
|
||||
}
|
||||
parts := []string{}
|
||||
for key, val := range h.Map {
|
||||
thispart := hQuote(key) + "=>" + hQuote(val)
|
||||
parts = append(parts, thispart)
|
||||
}
|
||||
return []byte(strings.Join(parts, ",")), nil
|
||||
}
|
||||
148
third_party/highgo-pq/hstore/hstore_test.go
vendored
Normal file
148
third_party/highgo-pq/hstore/hstore_test.go
vendored
Normal file
@@ -0,0 +1,148 @@
|
||||
package hstore
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
type Fatalistic interface {
|
||||
Fatal(args ...interface{})
|
||||
}
|
||||
|
||||
func openTestConn(t Fatalistic) *sql.DB {
|
||||
datname := os.Getenv("PGDATABASE")
|
||||
sslmode := os.Getenv("PGSSLMODE")
|
||||
|
||||
if datname == "" {
|
||||
os.Setenv("PGDATABASE", "pqgotest")
|
||||
}
|
||||
|
||||
if sslmode == "" {
|
||||
os.Setenv("PGSSLMODE", "disable")
|
||||
}
|
||||
|
||||
conn, err := sql.Open("postgres", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
func TestHstore(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
// quietly create hstore if it doesn't exist
|
||||
_, err := db.Exec("CREATE EXTENSION IF NOT EXISTS hstore")
|
||||
if err != nil {
|
||||
t.Skipf("Skipping hstore tests - hstore extension create failed: %s", err.Error())
|
||||
}
|
||||
|
||||
hs := Hstore{}
|
||||
|
||||
// test for null-valued hstores
|
||||
err = db.QueryRow("SELECT NULL::hstore").Scan(&hs)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if hs.Map != nil {
|
||||
t.Fatalf("expected null map")
|
||||
}
|
||||
|
||||
err = db.QueryRow("SELECT $1::hstore", hs).Scan(&hs)
|
||||
if err != nil {
|
||||
t.Fatalf("re-query null map failed: %s", err.Error())
|
||||
}
|
||||
if hs.Map != nil {
|
||||
t.Fatalf("expected null map")
|
||||
}
|
||||
|
||||
// test for empty hstores
|
||||
err = db.QueryRow("SELECT ''::hstore").Scan(&hs)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if hs.Map == nil {
|
||||
t.Fatalf("expected empty map, got null map")
|
||||
}
|
||||
if len(hs.Map) != 0 {
|
||||
t.Fatalf("expected empty map, got len(map)=%d", len(hs.Map))
|
||||
}
|
||||
|
||||
err = db.QueryRow("SELECT $1::hstore", hs).Scan(&hs)
|
||||
if err != nil {
|
||||
t.Fatalf("re-query empty map failed: %s", err.Error())
|
||||
}
|
||||
if hs.Map == nil {
|
||||
t.Fatalf("expected empty map, got null map")
|
||||
}
|
||||
if len(hs.Map) != 0 {
|
||||
t.Fatalf("expected empty map, got len(map)=%d", len(hs.Map))
|
||||
}
|
||||
|
||||
// a few example maps to test out
|
||||
hsOnePair := Hstore{
|
||||
Map: map[string]sql.NullString{
|
||||
"key1": {String: "value1", Valid: true},
|
||||
},
|
||||
}
|
||||
|
||||
hsThreePairs := Hstore{
|
||||
Map: map[string]sql.NullString{
|
||||
"key1": {String: "value1", Valid: true},
|
||||
"key2": {String: "value2", Valid: true},
|
||||
"key3": {String: "value3", Valid: true},
|
||||
},
|
||||
}
|
||||
|
||||
hsSmorgasbord := Hstore{
|
||||
Map: map[string]sql.NullString{
|
||||
"nullstring": {String: "NULL", Valid: true},
|
||||
"actuallynull": {String: "", Valid: false},
|
||||
"NULL": {String: "NULL string key", Valid: true},
|
||||
"withbracket": {String: "value>42", Valid: true},
|
||||
"withequal": {String: "value=42", Valid: true},
|
||||
`"withquotes1"`: {String: `this "should" be fine`, Valid: true},
|
||||
`"withquotes"2"`: {String: `this "should\" also be fine`, Valid: true},
|
||||
"embedded1": {String: "value1=>x1", Valid: true},
|
||||
"embedded2": {String: `"value2"=>x2`, Valid: true},
|
||||
"withnewlines": {String: "\n\nvalue\t=>2", Valid: true},
|
||||
"<<all sorts of crazy>>": {String: `this, "should,\" also, => be fine`, Valid: true},
|
||||
},
|
||||
}
|
||||
|
||||
// test encoding in query params, then decoding during Scan
|
||||
testBidirectional := func(h Hstore) {
|
||||
err = db.QueryRow("SELECT $1::hstore", h).Scan(&hs)
|
||||
if err != nil {
|
||||
t.Fatalf("re-query %d-pair map failed: %s", len(h.Map), err.Error())
|
||||
}
|
||||
if hs.Map == nil {
|
||||
t.Fatalf("expected %d-pair map, got null map", len(h.Map))
|
||||
}
|
||||
if len(hs.Map) != len(h.Map) {
|
||||
t.Fatalf("expected %d-pair map, got len(map)=%d", len(h.Map), len(hs.Map))
|
||||
}
|
||||
|
||||
for key, val := range hs.Map {
|
||||
otherval, found := h.Map[key]
|
||||
if !found {
|
||||
t.Fatalf(" key '%v' not found in %d-pair map", key, len(h.Map))
|
||||
}
|
||||
if otherval.Valid != val.Valid {
|
||||
t.Fatalf(" value %v <> %v in %d-pair map", otherval, val, len(h.Map))
|
||||
}
|
||||
if otherval.String != val.String {
|
||||
t.Fatalf(" value '%v' <> '%v' in %d-pair map", otherval.String, val.String, len(h.Map))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
testBidirectional(hsOnePair)
|
||||
testBidirectional(hsThreePairs)
|
||||
testBidirectional(hsSmorgasbord)
|
||||
}
|
||||
158
third_party/highgo-pq/issues_test.go
vendored
Normal file
158
third_party/highgo-pq/issues_test.go
vendored
Normal file
@@ -0,0 +1,158 @@
|
||||
package pq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIssue494(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
query := `CREATE TEMP TABLE t (i INT PRIMARY KEY)`
|
||||
if _, err := db.Exec(query); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := txn.Prepare(CopyIn("t", "i")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := txn.Query("SELECT 1"); err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIssue1046(t *testing.T) {
|
||||
ctxTimeout := time.Second * 2
|
||||
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout)
|
||||
defer cancel()
|
||||
|
||||
stmt, err := db.PrepareContext(ctx, `SELECT pg_sleep(10) AS id`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var d []uint8
|
||||
err = stmt.QueryRowContext(ctx).Scan(&d)
|
||||
dl, _ := ctx.Deadline()
|
||||
since := time.Since(dl)
|
||||
if since > ctxTimeout {
|
||||
t.Logf("FAIL %s: query returned after context deadline: %v\n", t.Name(), since)
|
||||
t.Fail()
|
||||
}
|
||||
if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) {
|
||||
t.Logf("ctx.Err(): [%T]%+v\n", ctx.Err(), ctx.Err())
|
||||
t.Logf("got err: [%T] %+v expected errCode: %v", err, err, cancelErrorCode)
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
|
||||
func TestIssue1062(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
// Ensure that cancelling a QueryRowContext does not result in an ErrBadConn.
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go cancel()
|
||||
row := db.QueryRowContext(ctx, "select 1")
|
||||
|
||||
var v int
|
||||
err := row.Scan(&v)
|
||||
if pgErr := (*Error)(nil); err != nil &&
|
||||
err != context.Canceled &&
|
||||
!(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) {
|
||||
t.Fatalf("Scan resulted in unexpected error %v for canceled QueryRowContext at attempt %d", err, i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func connIsValid(t *testing.T, db *sql.DB) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
conn, err := db.Conn(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// the connection must be valid
|
||||
err = conn.PingContext(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("PingContext err=%#v", err)
|
||||
}
|
||||
// close must not return an error
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Close err=%#v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryCancelRace(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
// cancel a query while executing on Postgres: must return the cancelled error code
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
row := db.QueryRowContext(ctx, "select pg_sleep(0.5)")
|
||||
var pgSleepVoid string
|
||||
err := row.Scan(&pgSleepVoid)
|
||||
if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) {
|
||||
t.Fatalf("expected cancelled error; err=%#v", err)
|
||||
}
|
||||
|
||||
// get a connection: it must be a valid
|
||||
connIsValid(t, db)
|
||||
}
|
||||
|
||||
// Test cancelling a scan after it is started. This broke with 1.10.4.
|
||||
func TestQueryCancelledReused(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// run a query that returns a lot of data
|
||||
rows, err := db.QueryContext(ctx, "select generate_series(1, 10000)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// scan the first value
|
||||
if !rows.Next() {
|
||||
t.Error("expected rows.Next() to return true")
|
||||
}
|
||||
var i int
|
||||
err = rows.Scan(&i)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if i != 1 {
|
||||
t.Error(i)
|
||||
}
|
||||
|
||||
// cancel the context and close rows, ignoring errors
|
||||
cancel()
|
||||
rows.Close()
|
||||
|
||||
// get a connection: it must be valid
|
||||
connIsValid(t, db)
|
||||
}
|
||||
27
third_party/highgo-pq/krb.go
vendored
Normal file
27
third_party/highgo-pq/krb.go
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
package pq
|
||||
|
||||
// NewGSSFunc creates a GSS authentication provider, for use with
|
||||
// RegisterGSSProvider.
|
||||
type NewGSSFunc func() (GSS, error)
|
||||
|
||||
var newGss NewGSSFunc
|
||||
|
||||
// RegisterGSSProvider registers a GSS authentication provider. For example, if
|
||||
// you need to use Kerberos to authenticate with your server, add this to your
|
||||
// main package:
|
||||
//
|
||||
// import "github.com/lib/pq/auth/kerberos"
|
||||
//
|
||||
// func init() {
|
||||
// pq.RegisterGSSProvider(func() (pq.GSS, error) { return kerberos.NewGSS() })
|
||||
// }
|
||||
func RegisterGSSProvider(newGssArg NewGSSFunc) {
|
||||
newGss = newGssArg
|
||||
}
|
||||
|
||||
// GSS provides GSSAPI authentication (e.g., Kerberos).
|
||||
type GSS interface {
|
||||
GetInitToken(host string, service string) ([]byte, error)
|
||||
GetInitTokenFromSpn(spn string) ([]byte, error)
|
||||
Continue(inToken []byte) (done bool, outToken []byte, err error)
|
||||
}
|
||||
72
third_party/highgo-pq/notice.go
vendored
Normal file
72
third_party/highgo-pq/notice.go
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
//go:build go1.10
|
||||
// +build go1.10
|
||||
|
||||
package pq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
)
|
||||
|
||||
// NoticeHandler returns the notice handler on the given connection, if any. A
|
||||
// runtime panic occurs if c is not a pq connection. This is rarely used
|
||||
// directly, use ConnectorNoticeHandler and ConnectorWithNoticeHandler instead.
|
||||
func NoticeHandler(c driver.Conn) func(*Error) {
|
||||
return c.(*conn).noticeHandler
|
||||
}
|
||||
|
||||
// SetNoticeHandler sets the given notice handler on the given connection. A
|
||||
// runtime panic occurs if c is not a pq connection. A nil handler may be used
|
||||
// to unset it. This is rarely used directly, use ConnectorNoticeHandler and
|
||||
// ConnectorWithNoticeHandler instead.
|
||||
//
|
||||
// Note: Notice handlers are executed synchronously by pq meaning commands
|
||||
// won't continue to be processed until the handler returns.
|
||||
func SetNoticeHandler(c driver.Conn, handler func(*Error)) {
|
||||
c.(*conn).noticeHandler = handler
|
||||
}
|
||||
|
||||
// NoticeHandlerConnector wraps a regular connector and sets a notice handler
|
||||
// on it.
|
||||
type NoticeHandlerConnector struct {
|
||||
driver.Connector
|
||||
noticeHandler func(*Error)
|
||||
}
|
||||
|
||||
// Connect calls the underlying connector's connect method and then sets the
|
||||
// notice handler.
|
||||
func (n *NoticeHandlerConnector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||
c, err := n.Connector.Connect(ctx)
|
||||
if err == nil {
|
||||
SetNoticeHandler(c, n.noticeHandler)
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
|
||||
// ConnectorNoticeHandler returns the currently set notice handler, if any. If
|
||||
// the given connector is not a result of ConnectorWithNoticeHandler, nil is
|
||||
// returned.
|
||||
func ConnectorNoticeHandler(c driver.Connector) func(*Error) {
|
||||
if c, ok := c.(*NoticeHandlerConnector); ok {
|
||||
return c.noticeHandler
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConnectorWithNoticeHandler creates or sets the given handler for the given
|
||||
// connector. If the given connector is a result of calling this function
|
||||
// previously, it is simply set on the given connector and returned. Otherwise,
|
||||
// this returns a new connector wrapping the given one and setting the notice
|
||||
// handler. A nil notice handler may be used to unset it.
|
||||
//
|
||||
// The returned connector is intended to be used with database/sql.OpenDB.
|
||||
//
|
||||
// Note: Notice handlers are executed synchronously by pq meaning commands
|
||||
// won't continue to be processed until the handler returns.
|
||||
func ConnectorWithNoticeHandler(c driver.Connector, handler func(*Error)) *NoticeHandlerConnector {
|
||||
if c, ok := c.(*NoticeHandlerConnector); ok {
|
||||
c.noticeHandler = handler
|
||||
return c
|
||||
}
|
||||
return &NoticeHandlerConnector{Connector: c, noticeHandler: handler}
|
||||
}
|
||||
34
third_party/highgo-pq/notice_example_test.go
vendored
Normal file
34
third_party/highgo-pq/notice_example_test.go
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
//go:build go1.10
|
||||
// +build go1.10
|
||||
|
||||
package pq_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
func ExampleConnectorWithNoticeHandler() {
|
||||
name := ""
|
||||
// Base connector to wrap
|
||||
base, err := pq.NewConnector(name)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Wrap the connector to simply print out the message
|
||||
connector := pq.ConnectorWithNoticeHandler(base, func(notice *pq.Error) {
|
||||
fmt.Println("Notice sent: " + notice.Message)
|
||||
})
|
||||
db := sql.OpenDB(connector)
|
||||
defer db.Close()
|
||||
// Raise a notice
|
||||
sql := "DO language plpgsql $$ BEGIN RAISE NOTICE 'test notice'; END $$"
|
||||
if _, err := db.Exec(sql); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Output:
|
||||
// Notice sent: test notice
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user