Compare commits


24 Commits

Author SHA1 Message Date
Syngnat
71e5de0cdc ♻️ refactor(database/ssh): rework SSH tunnel architecture and adapt it to multiple data sources
- Architecture: switch from driver-specific dialers to a generic local port-forwarding model
  - Concurrency safety: sync.Once guards Close, RWMutex guards state access, bidirectional errc waits
  - Connection pooling: GetOrCreateLocalForwarder/GetOrCreateSSHClient cache and reuse tunnels
  - SQL safety: kingbase_impl.go adds an esc function to prevent double-quote injection (the ""ldf_server"" issue)
  - Dynamic schema: three-level fallback (parse schema.table → dbName parameter → current_schema())
  - Code reuse: scanRows unifies row scanning; normalizeQueryValueWithDBType improves type handling
  Close #40
2026-02-04 14:35:31 +08:00
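The double-quote hardening described in the commit above can be sketched as follows. This is a hypothetical TypeScript rendering of the idea only; the real esc function lives in kingbase_impl.go on the Go side:

```typescript
// Sketch (assumption: mirrors the intent of the Go-side esc in kingbase_impl.go).
// Any embedded `"` in an identifier is doubled, per the SQL standard, so a
// hostile name cannot break out of the surrounding quoted identifier.
function esc(ident: string): string {
  return ident.replace(/"/g, '""');
}

// Quote an identifier safely for PostgreSQL/Kingbase-style SQL.
function quoteIdent(ident: string): string {
  return `"${esc(ident)}"`;
}

console.log(quoteIdent('ldf_server')); // "ldf_server"
console.log(quoteIdent('bad"name'));   // "bad""name"
```

The key point is that escaping happens exactly once, at quoting time, which avoids the doubled-quote `""ldf_server""` artifact the commit refers to.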
Syngnat
d8656c6c9c 🐛 fix(query-editor): aliased fields now get completions; fix a startup compile error
- In the a.<field> case, offer field completions based on the alias->table mapping
  - Fix duplicate declaration of currentDbRef (TS2451)
  - Keep existing keyword/table/field completion behavior unchanged
2026-02-04 12:37:30 +08:00
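The alias->table mapping the fix above relies on can be sketched like this (an assumed, simplified shape; the project's actual completion provider is not shown in this diff):

```typescript
// Build an alias -> table map from FROM/JOIN clauses, so that typing `a.`
// can suggest the fields of the table aliased as `a`.
function aliasMap(sql: string): Record<string, string> {
  const map: Record<string, string> = {};
  const re = /\b(?:from|join)\s+([\w.]+)\s+(?:as\s+)?(\w+)/gi;
  let m: RegExpExecArray | null;
  while ((m = re.exec(sql)) !== null) {
    const alias = m[2].toLowerCase();
    // Skip keywords that can directly follow a table name without an alias.
    if (!['on', 'where', 'join', 'left', 'right', 'inner', 'group', 'order'].includes(alias)) {
      map[alias] = m[1];
    }
  }
  return map;
}

const m = aliasMap('SELECT a.id FROM users a JOIN orders o ON o.uid = a.id');
// m.a === 'users', m.o === 'orders'
```

A real implementation would also handle subqueries and quoted identifiers; the regex here is only a sketch of the lookup idea.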
Syngnat
443b487a02 Merge pull request #60 from Syngnat/feature/0.2.5
Feature/0.2.5
2026-02-04 12:31:50 +08:00
Syngnat
bac57ebdf0 Merge pull request #59 from Syngnat/dev
🐛 fix(table): restore select-all in virtual tables; improve export/filter capabilities

- Custom header components keep width, so the selection column renders in virtual mode
- Add backend ExportQuery; exporting the current page/selected rows avoids IPC truncation of long fields
- Filters support more operators, with unified WHERE generation
Close #57
Close #56

feat(table-edit): add a whole-row edit panel for faster multi-field/long-text editing

- One click opens the edit panel for the selected row
- All fields are editable, with friendly input and popup editors for long text/JSON
- Applying stages changes locally; committing the transaction persists them

perf(table): open table data faster; run slow operations (primary keys, statistics) asynchronously

- DataViewer fetches primary-key metadata asynchronously; first-screen data renders first
- The query page adds a max-row limit on result sets, reducing full returns from large tables
- DBQuery gains a Context timeout, lowering the risk of long queries blocking the UI
- The query row-limit setting is persisted
Closes #48 

feat(db-ui): fix the Kingbase open-table error and improve result-page editing

- postgres/kingbase queries auto-sanitize illegal ""ident""-style identifiers before running
- Result tables support popup cell editing, making JSON/long text easier to edit
- Fix misaligned widths between result-table headers and data columns
Closes #49
2026-02-04 12:30:42 +08:00
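The "unified WHERE generation" mentioned above can be sketched as a single dispatch over filter conditions. This is an assumed shape; the repository's actual buildWhereSQL in src/utils/sql also takes the database type for dialect-specific quoting:

```typescript
// One condition per filter row; value2 is only used by BETWEEN-style operators.
type Cond = { column: string; op: string; value: string; value2?: string };

const escapeLit = (s: string) => s.replace(/'/g, "''");
const q = (c: string) => `"${c.replace(/"/g, '""')}"`;

function buildWhere(conds: Cond[]): string {
  const parts = conds.map(({ column, op, value, value2 }) => {
    switch (op) {
      case 'IS_NULL':     return `${q(column)} IS NULL`;
      case 'IS_NOT_NULL': return `${q(column)} IS NOT NULL`;
      case 'CONTAINS':    return `${q(column)} LIKE '%${escapeLit(value)}%'`;
      case 'BETWEEN':     return `${q(column)} BETWEEN '${escapeLit(value)}' AND '${escapeLit(value2 ?? '')}'`;
      case 'IN':          return `${q(column)} IN (${value.split(',').map(v => `'${escapeLit(v.trim())}'`).join(', ')})`;
      default:            return `${q(column)} ${op} '${escapeLit(value)}'`; // =, !=, <, <=, >, >=
    }
  });
  return parts.length ? `WHERE ${parts.join(' AND ')}` : '';
}
```

Centralizing generation this way is what lets the table filter, the export-current-page SQL, and the selected-row export all share one code path.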
Syngnat
213a33e4f3 Merge pull request #58 from Syngnat/feature/table-and-database-export-20260203-ygf
Feature/table and database export 20260203 ygf
2026-02-04 12:29:33 +08:00
Syngnat
a00f87582d 🐛 fix(table): restore select-all in virtual tables; improve export/filter capabilities
- Custom header components keep width, so the selection column renders in virtual mode
  - Add backend ExportQuery; exporting the current page/selected rows avoids IPC truncation of long fields
  - Filters support more operators, with unified WHERE generation
  Close #57
  Close #56
2026-02-04 12:23:41 +08:00
Syngnat
f129623000 feat(table-edit): add a whole-row edit panel for faster multi-field/long-text editing
- One click opens the edit panel for the selected row
  - All fields are editable, with friendly input and popup editors for long text/JSON
  - Applying stages changes locally; committing the transaction persists them
2026-02-04 11:43:47 +08:00
Syngnat
8dbc97e466 perf(table): open table data faster; run slow operations (primary keys, statistics) asynchronously
- DataViewer fetches primary-key metadata asynchronously; first-screen data renders first
  - The query page adds a max-row limit on result sets, reducing full returns from large tables
  - DBQuery gains a Context timeout, lowering the risk of long queries blocking the UI
  - The query row-limit setting is persisted
  Closes #48
  Closes #49
2026-02-04 11:01:28 +08:00
Syngnat
4a0db185c0 feat(db-ui): fix the Kingbase open-table error and improve result-page editing
- postgres/kingbase queries auto-sanitize illegal ""ident""-style identifiers before running
  - Result tables support popup cell editing, making JSON/long text easier to edit
  - Fix misaligned widths between result-table headers and data columns
2026-02-04 10:13:02 +08:00
Syngnat
5793f63ac8 optimize(core): multi-statement queries with multiple result sets; large-table interaction and metadata improvements
- Split statements on semicolons (quote/comment/PG dollar-quote aware); show multiple result sets as tabs
- Support running the current selection; result tabs can be closed
- Fix the results area auto-shrinking and last-row clipping; smoother result switching (disable the ink-bar animation, fix hidden panels stacking)
- Add the missing PostgreSQL/SQLite design-table metadata APIs
- Fix Kingbase schema/identifier quoting that broke opening tables
- Tab context menu: close others/close left/close right/close all
2026-02-03 22:48:24 +08:00
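The quote/comment/dollar-quote-aware splitting described above can be sketched as a small state machine. This is an assumed simplification; the project's real splitter is not shown here and presumably handles more edge cases:

```typescript
// Split SQL text on semicolons, but never inside 'strings', "identifiers",
// -- line comments, /* block comments */, or PG $tag$ dollar quotes.
function splitStatements(sql: string): string[] {
  const out: string[] = [];
  let cur = '', i = 0;
  let mode: 'none' | 'squote' | 'dquote' | 'line' | 'block' | 'dollar' = 'none';
  let tag = '';
  while (i < sql.length) {
    const ch = sql[i], two = sql.slice(i, i + 2);
    if (mode === 'none') {
      if (ch === "'") mode = 'squote';
      else if (ch === '"') mode = 'dquote';
      else if (two === '--') mode = 'line';
      else if (two === '/*') mode = 'block';
      else if (ch === '$') {
        const m = /^\$[A-Za-z_]*\$/.exec(sql.slice(i));
        if (m) { mode = 'dollar'; tag = m[0]; cur += tag; i += tag.length; continue; }
      } else if (ch === ';') { out.push(cur.trim()); cur = ''; i++; continue; }
    } else if (mode === 'squote' && ch === "'") mode = 'none';
    else if (mode === 'dquote' && ch === '"') mode = 'none';
    else if (mode === 'line' && ch === '\n') mode = 'none';
    else if (mode === 'block' && two === '*/') { cur += two; i += 2; mode = 'none'; continue; }
    else if (mode === 'dollar' && sql.startsWith(tag, i)) { cur += tag; i += tag.length; mode = 'none'; continue; }
    cur += ch; i++;
  }
  if (cur.trim()) out.push(cur.trim());
  return out;
}
```

A naive `sql.split(';')` would break on literals like `';'` and on `$$ ... ; ... $$` function bodies, which is exactly what the state machine avoids.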
Syngnat
8aabc67634 Merge pull request #46 from Syngnat/feature/table-and-database-export-20260203-ygf
- Split statements on semicolons (quote/comment/PG dollar-quote aware); show multiple result sets as tabs
- Support running the current selection; result tabs can be closed
- Fix the results area auto-shrinking and last-row clipping; smoother result switching (disable the ink-bar animation, fix hidden panels stacking)
- Add the missing PostgreSQL/SQLite design-table metadata APIs
- Fix Kingbase schema/identifier quoting that broke opening tables
- Tab context menu: close others/close left/close right/close all
2026-02-03 22:46:25 +08:00
杨国锋
34c494ce51 optimize(core): multi-statement queries with multiple result sets; large-table interaction and metadata improvements
- Split statements on semicolons (quote/comment/PG dollar-quote aware); show multiple result sets as tabs
  - Support running the current selection; result tabs can be closed
  - Fix the results area auto-shrinking and last-row clipping; smoother result switching (disable the ink-bar animation, fix hidden panels stacking)
  - Add the missing PostgreSQL/SQLite design-table metadata APIs
  - Fix Kingbase schema/identifier quoting that broke opening tables
  - Tab context menu: close others/close left/close right/close all
2026-02-03 22:44:48 +08:00
Syngnat
178de02783 Merge pull request #45 from bengbengbalabalabeng/chore-add-issues-templates
- Add issue templates to standardize issue types
2026-02-03 22:39:39 +08:00
baicaixiaozhan
94e5b8d2c6 chore: add GitHub issue templates 2026-02-03 21:49:43 +08:00
杨国锋
89e2247c05 feat(database): add database/table-level export and backup; improve sidebar interaction
- Database nodes can export all table schemas, or schema plus data, as SQL (ExportDatabaseSQL)
  - Table nodes support right-click export/backup for single or multiple selections (ExportTablesSQL)
  - ExportTable can export SQL (schema plus data)
  - Double-clicking a table only opens its data; it no longer toggles expand/collapse
2026-02-03 19:49:04 +08:00
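The "schema plus data" export idea can be sketched for the data half as generating one INSERT per row. This is only an illustrative shape; the real ExportDatabaseSQL/ExportTablesSQL run on the Go backend and also emit CREATE TABLE statements:

```typescript
// Turn result rows into INSERT statements with escaped string literals.
// Assumption: all values are exported as quoted literals; a real exporter
// would preserve numeric/NULL/binary types per column.
function rowsToInserts(table: string, rows: Record<string, unknown>[]): string[] {
  const lit = (v: unknown) =>
    v === null || v === undefined ? 'NULL' : `'${String(v).replace(/'/g, "''")}'`;
  return rows.map(r => {
    const cols = Object.keys(r);
    return `INSERT INTO "${table}" (${cols.map(c => `"${c}"`).join(', ')}) ` +
           `VALUES (${cols.map(c => lit(r[c])).join(', ')});`;
  });
}
```

Doing this on the backend, as the commit does, also sidesteps the IPC truncation of long fields mentioned in the ExportQuery change above.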
Syngnat
b2ede61b79 Merge pull request #43 from Syngnat/feature/0.2.3
perf(frontend): faster drag and initial load for large data tables; add data-sync diff comparison and row-level selection
2026-02-03 19:23:49 +08:00
Syngnat
db381ae9d1 Merge pull request #42 from Syngnat/dev
perf(frontend): faster drag and initial load for large data tables; add data-sync diff comparison and row-level selection
2026-02-03 19:23:15 +08:00
Syngnat
f946cfd647 Merge pull request #41 from Syngnat/feature/data-sync-optimization-20260203-ygf
perf(frontend): faster drag and initial load for large data tables; add data-sync diff comparison and row-level selection
2026-02-03 19:21:29 +08:00
Syngnat
791425a5a8 🐛 fix(db): support schema/owner-qualified names; fix PG/Kingbase "table does not exist"; fix table data display issues
- Cover Query return-value conversion for mysql/postgres/kingbase/oracle/dameng/sqlite/custom
- Correct the editable table's save scope so stale state no longer affects display
- Table lists return schema.table/owner.table, avoiding "relation does not exist" caused by search_path mismatches
- Metadata, import/export, and commit-changes flows uniformly parse and quote qualified names
- Frontend query and data browsing support quoting qualified names
- Time fields in cell edit mode display uniformly as YYYY-MM-DD HH:mm:ss
close #36
2026-02-03 14:39:05 +08:00
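The qualified-name quoting both commits rely on can be sketched as quoting each dot-separated part independently. This is an assumed simplification of the idea behind quoteQualifiedIdent in src/utils/sql; the real helper also varies the quote character by dialect (e.g. backticks for MySQL):

```typescript
// Quote a possibly schema/owner-qualified name, part by part, so that
// `public.users` becomes "public"."users" rather than "public.users"
// (which would be a single identifier containing a dot).
function quoteQualified(name: string): string {
  return name
    .split('.')
    .map(part => `"${part.replace(/"/g, '""')}"`)
    .join('.');
}

// quoteQualified('public.users') -> "public"."users"
// quoteQualified('users')        -> "users"
```

Quoting part by part is what keeps the table list's schema.table names usable regardless of the session's search_path.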
Syngnat
d7acfd1af9 🐛 fix(db): support schema/owner-qualified names; fix PG/Kingbase "table does not exist"; fix table data display issues
- Cover Query return-value conversion for mysql/postgres/kingbase/oracle/dameng/sqlite/custom
- Correct the editable table's save scope so stale state no longer affects display
- Table lists return schema.table/owner.table, avoiding "relation does not exist" caused by search_path mismatches
- Metadata, import/export, and commit-changes flows uniformly parse and quote qualified names
- Frontend query and data browsing support quoting qualified names
- Time fields in cell edit mode display uniformly as YYYY-MM-DD HH:mm:ss
close #36
2026-02-03 14:38:05 +08:00
Syngnat
88952e87c1 Merge pull request #35 from Syngnat/release/0.2.1
- The frontend now uses the generic DB API, so PostgreSQL and other connections no longer break by being forced through the MySQL interface
- The backend unifies per-datasource timeouts (Ping timeout + connection-parameter injection)
- DSN generation handles passwords with special characters (Postgres/Oracle/Dameng/Kingbase)
- Add file logging and error-chain output; connection failures show the log path to ease troubleshooting
2026-02-03 12:27:39 +08:00
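The special-character password fix above can be sketched for a URL-style DSN by percent-encoding the userinfo before embedding it. This is an assumed approach shown in TypeScript; the real code builds per-driver DSNs on the Go side:

```typescript
// Build a postgres:// DSN; characters like @ : / in the password would
// otherwise be parsed as URL structure and break the connection string.
function postgresDSN(user: string, pass: string, host: string, port: number, db: string): string {
  return `postgres://${encodeURIComponent(user)}:${encodeURIComponent(pass)}@${host}:${port}/${db}`;
}

console.log(postgresDSN('u', 'p@ss', 'h', 5432, 'd')); // postgres://u:p%40ss@h:5432/d
```

Non-URL DSN formats (Oracle, Dameng) need their own quoting rules, which is presumably why the commit lists each driver separately.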
Syngnat
c981a65834 Merge pull request #32 from Syngnat/release/0.2.0
Release/0.2.0
2026-02-03 08:53:42 +08:00
Syngnat
b9d9ab5464 Merge pull request #31 from Syngnat/dev
docs: improve image layout
2026-02-03 08:53:14 +08:00
Syngnat
6b503480cf fix: badge display error in README
- Fix badge rendering in the README
2026-02-02 21:14:33 +08:00
33 changed files with 4435 additions and 598 deletions

View File

@@ -0,0 +1,58 @@
name: 问题反馈
description: 软件问题反馈
title: "[Bug] "
labels: ["bug"]
body:
  - type: checkboxes
    id: searched
    attributes:
      label: 已经搜索过 Issues,未发现重复问题*
      options:
        - label: 我已经搜索过 Issues,没有发现重复问题
    validations:
      required: true
  - type: input
    id: system
    attributes:
      label: 操作系统及版本
      placeholder: Windows 10 22H2 / macOS Mojave / Linux
    validations:
      required: true
  - type: input
    id: version
    attributes:
      label: 软件安装版本
      placeholder: v0.2.3
    validations:
      required: true
  - type: textarea
    id: description
    attributes:
      label: 问题简述及复现流程
      description: 请详细描述你遇到的问题,并提供复现步骤
      placeholder: |
        1. 打开软件
        2. 点击 xxx
        3. 预期结果是 ...
        4. 实际结果是 ...
        5. 截图 ...
    validations:
      required: true
  - type: textarea
    id: extra
    attributes:
      label: 其他补充
      description: 如果你有额外信息,请在此填写
      placeholder: 可选
  - type: checkboxes
    id: pr
    attributes:
      label: 是否愿意提交 PR 修复当前 Issue
      options:
        - label: 我愿意尝试提交 PR

View File

@@ -0,0 +1,37 @@
name: 功能建议
description: 添加全新功能或改进现有功能
title: "[Enhancement] "
labels: ["enhancement"]
body:
  - type: checkboxes
    id: searched
    attributes:
      label: 已经搜索过 Issues,未发现重复问题*
      options:
        - label: 我已经搜索过 Issues,没有发现重复问题
    validations:
      required: true
  - type: textarea
    id: feature
    attributes:
      label: 功能描述
      description: 请详细描述你希望添加或改进的功能
      placeholder: 请描述你想要的功能
    validations:
      required: true
  - type: textarea
    id: extra
    attributes:
      label: 其他补充
      description: 如果你有额外信息,请在此填写
      placeholder: 可选
  - type: checkboxes
    id: pr
    attributes:
      label: 是否愿意提交 PR 实现当前 Issue
      options:
        - label: 我愿意尝试提交 PR

.github/ISSUE_TEMPLATE/03-generic.yml
View File

@@ -0,0 +1,30 @@
name: 其他反馈
description: 其他类型反馈、建议或讨论
title: "[Question] "
labels: ["question"]
body:
  - type: checkboxes
    id: searched
    attributes:
      label: 已经搜索过 Issues,未发现重复问题*
      options:
        - label: 我已经搜索过 Issues,没有发现重复问题
    validations:
      required: true
  - type: textarea
    id: content
    attributes:
      label: 内容
      description: 请填写你的反馈、建议或讨论内容
      placeholder: 请描述你的问题或想法
    validations:
      required: true
  - type: textarea
    id: extra
    attributes:
      label: 其他补充
      description: 如果你有额外信息,请在此填写
      placeholder: 可选

.github/ISSUE_TEMPLATE/config.yml
View File

@@ -0,0 +1 @@
blank_issues_enabled: false

View File

@@ -1 +1 @@
d0f9366af59a6367ad3c7e2d4185ead4
5b8157374dae5f9340e31b2d0bd2c00e

View File

@@ -285,12 +285,12 @@ function App() {
title="拖动调整宽度"
/>
</Sider>
<Content style={{ background: darkMode ? '#141414' : '#fff', overflow: 'hidden', display: 'flex', flexDirection: 'column' }}>
<div style={{ flex: 1, overflow: 'hidden' }}>
<TabManager />
</div>
{isLogPanelOpen && (
<LogPanel
<Content style={{ background: darkMode ? '#141414' : '#fff', overflow: 'hidden', display: 'flex', flexDirection: 'column' }}>
<div style={{ flex: 1, minHeight: 0, overflow: 'hidden', display: 'flex', flexDirection: 'column' }}>
<TabManager />
</div>
{isLogPanelOpen && (
<LogPanel
height={logPanelHeight}
onClose={() => setIsLogPanelOpen(false)}
onResizeStart={handleLogResizeStart}
@@ -343,4 +343,4 @@ function App() {
);
}
export default App;
export default App;

View File

@@ -264,8 +264,8 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
{useSSH && (
<div style={{ padding: '12px', background: '#f5f5f5', borderRadius: 6, marginTop: 12 }}>
<div style={{ display: 'flex', gap: 16 }}>
<Form.Item name="sshHost" label="SSH 主机" rules={[{ required: useSSH, message: '请输入SSH主机' }]} style={{ flex: 1 }}>
<Input placeholder="ssh.example.com" />
<Form.Item name="sshHost" label="SSH 主机 (域名或IP)" rules={[{ required: useSSH, message: '请输入SSH主机' }]} style={{ flex: 1 }}>
<Input placeholder="例如: ssh.example.com 或 192.168.1.100" />
</Form.Item>
<Form.Item name="sshPort" label="端口" rules={[{ required: useSSH, message: '请输入SSH端口' }]} style={{ width: 100 }}>
<InputNumber style={{ width: '100%' }} />

View File

@@ -1,11 +1,13 @@
import React, { useState, useEffect, useRef, useContext, useMemo, useCallback } from 'react';
import { Table, message, Input, Button, Dropdown, MenuProps, Form, Pagination, Select, Modal } from 'antd';
import type { SortOrder } from 'antd/es/table/interface';
import { ReloadOutlined, ImportOutlined, ExportOutlined, DownOutlined, PlusOutlined, DeleteOutlined, SaveOutlined, UndoOutlined, FilterOutlined, CloseOutlined, ConsoleSqlOutlined, FileTextOutlined, CopyOutlined, ClearOutlined } from '@ant-design/icons';
import { ImportData, ExportTable, ExportData, ApplyChanges } from '../../wailsjs/go/app/App';
import { ReloadOutlined, ImportOutlined, ExportOutlined, DownOutlined, PlusOutlined, DeleteOutlined, SaveOutlined, UndoOutlined, FilterOutlined, CloseOutlined, ConsoleSqlOutlined, FileTextOutlined, CopyOutlined, ClearOutlined, EditOutlined } from '@ant-design/icons';
import Editor from '@monaco-editor/react';
import { ImportData, ExportTable, ExportData, ExportQuery, ApplyChanges } from '../../wailsjs/go/app/App';
import { useStore } from '../store';
import { v4 as uuidv4 } from 'uuid';
import 'react-resizable/css/styles.css';
import { buildWhereSQL, escapeLiteral, quoteIdentPart, quoteQualifiedIdent } from '../utils/sql';
// Internal row-key field: avoids clashes with real business columns (e.g. a `key` column).
export const GONAVI_ROW_KEY = '__gonavi_row_key__';
@@ -27,16 +29,47 @@ const formatCellValue = (val: any) => {
return String(val);
};
const toEditableText = (val: any): string => {
if (val === null || val === undefined) return '';
if (typeof val === 'string') return val;
try {
return JSON.stringify(val, null, 2);
} catch {
return String(val);
}
};
const toFormText = (val: any): string => {
if (val === null || val === undefined) return '';
if (typeof val === 'string') return normalizeDateTimeString(val);
return toEditableText(val);
};
const looksLikeJsonText = (text: string): boolean => {
const raw = (text || '').trim();
if (!raw) return false;
const first = raw[0];
const last = raw[raw.length - 1];
return (first === '{' && last === '}') || (first === '[' && last === ']');
};
// --- Resizable Header (Native Implementation) ---
const ResizableTitle = (props: any) => {
const { onResizeStart, width, ...restProps } = props;
if (!width) {
return <th {...restProps} />;
const nextStyle = { ...(restProps.style || {}) } as React.CSSProperties;
if (width) {
nextStyle.width = width;
}
// Note: in virtual table mode, rc-table relies on the header cell's width style to render the selection column.
// Losing width here can make the top-left "select all" checkbox disappear.
if (!width || typeof onResizeStart !== 'function') {
return <th {...restProps} style={nextStyle} />;
}
return (
<th {...restProps} style={{ ...restProps.style, position: 'relative' }}>
<th {...restProps} style={{ ...nextStyle, position: 'relative' }}>
{restProps.children}
<span
className="react-resizable-handle"
@@ -85,6 +118,7 @@ interface EditableCellProps {
dataIndex: string;
record: Item;
handleSave: (record: Item) => void;
focusCell?: (record: Item, dataIndex: string, title: React.ReactNode) => void;
[key: string]: any;
}
@@ -95,6 +129,7 @@ const EditableCell: React.FC<EditableCellProps> = React.memo(({
dataIndex,
record,
handleSave,
focusCell,
...restProps
}) => {
const [editing, setEditing] = useState(false);
@@ -139,7 +174,26 @@ const EditableCell: React.FC<EditableCellProps> = React.memo(({
);
}
return <td {...restProps} onDoubleClick={editable ? toggleEdit : undefined}>{childNode}</td>;
const handleDoubleClick = () => {
if (!editable) return;
toggleEdit();
};
const handleClick = (e: React.MouseEvent) => {
restProps?.onClick?.(e);
if (!editable) return;
if (typeof focusCell === 'function') focusCell(record, dataIndex, title);
};
return (
<td
{...restProps}
onClick={editable ? handleClick : restProps?.onClick}
onDoubleClick={editable ? handleDoubleClick : restProps?.onDoubleClick}
>
{childNode}
</td>
);
});
const ContextMenuRow = React.memo(({ children, record, ...props }: any) => {
@@ -221,9 +275,23 @@ const DataGrid: React.FC<DataGridProps> = ({
}) => {
const { connections } = useStore();
const addSqlLog = useStore(state => state.addSqlLog);
const darkMode = useStore(state => state.darkMode);
const selectionColumnWidth = 46;
const [form] = Form.useForm();
const [modal, contextHolder] = Modal.useModal();
const gridId = useMemo(() => `grid-${uuidv4()}`, []);
const [cellEditorOpen, setCellEditorOpen] = useState(false);
const [cellEditorValue, setCellEditorValue] = useState('');
const [cellEditorIsJson, setCellEditorIsJson] = useState(false);
const [cellEditorMeta, setCellEditorMeta] = useState<{ record: Item; dataIndex: string; title: string } | null>(null);
const cellEditorApplyRef = useRef<((val: string) => void) | null>(null);
const [activeCell, setActiveCell] = useState<{ rowKey: string; dataIndex: string; title: string } | null>(null);
const [rowEditorOpen, setRowEditorOpen] = useState(false);
const [rowEditorRowKey, setRowEditorRowKey] = useState<string>('');
const rowEditorBaseRef = useRef<Record<string, string>>({});
const rowEditorDisplayRef = useRef<Record<string, string>>({});
const rowEditorNullColsRef = useRef<Set<string>>(new Set());
const [rowEditorForm] = Form.useForm();
// Helper to export specific data
const exportData = async (rows: any[], format: string) => {
@@ -237,32 +305,66 @@ const DataGrid: React.FC<DataGridProps> = ({
const [sortInfo, setSortInfo] = useState<{ columnKey: string, order: string } | null>(null);
const [columnWidths, setColumnWidths] = useState<Record<string, number>>({});
const closeCellEditor = useCallback(() => {
setCellEditorOpen(false);
setCellEditorMeta(null);
setCellEditorValue('');
setCellEditorIsJson(false);
cellEditorApplyRef.current = null;
}, []);
const openCellEditor = useCallback((record: Item, dataIndex: string, title: React.ReactNode, onApplyValue?: (val: string) => void) => {
if (!record || !dataIndex) return;
const raw = record?.[dataIndex];
const text = toEditableText(raw);
const isJson = looksLikeJsonText(text);
const titleText = typeof title === 'string' ? title : (typeof title === 'number' ? String(title) : String(dataIndex));
setCellEditorMeta({ record, dataIndex, title: titleText });
setCellEditorValue(text);
setCellEditorIsJson(isJson);
setCellEditorOpen(true);
cellEditorApplyRef.current = typeof onApplyValue === 'function' ? onApplyValue : null;
}, []);
// Dynamic Height
const [tableHeight, setTableHeight] = useState(500);
const containerRef = useRef<HTMLDivElement>(null);
useEffect(() => {
if (!containerRef.current) return;
let rafId: number;
const el = containerRef.current;
if (!el) return;
let rafId: number | null = null;
const resizeObserver = new ResizeObserver(entries => {
if (rafId !== null) cancelAnimationFrame(rafId);
rafId = requestAnimationFrame(() => {
for (let entry of entries) {
// Use boundingClientRect for more accurate render size (including padding if any)
const height = entry.contentRect.height;
if (height < 50) return;
// Subtract header (~42px) and a buffer
const h = Math.max(100, height - 42);
setTableHeight(h);
}
const target = (entries[0]?.target as HTMLElement | undefined) || containerRef.current;
if (!target) return;
const height = target.getBoundingClientRect().height;
if (!Number.isFinite(height) || height < 50) return;
const headerEl =
(target.querySelector('.ant-table-header') as HTMLElement | null) ||
(target.querySelector('.ant-table-thead') as HTMLElement | null);
const rawHeaderHeight = headerEl ? headerEl.getBoundingClientRect().height : NaN;
const headerHeight =
Number.isFinite(rawHeaderHeight) && rawHeaderHeight >= 24 && rawHeaderHeight <= 120 ? rawHeaderHeight : 42;
// Leave some headroom so the bottom (border/scrollbar) does not clip the last row
const extraBottom = 16;
const nextHeight = Math.max(100, Math.floor(height - headerHeight - extraBottom));
setTableHeight(nextHeight);
});
});
resizeObserver.observe(containerRef.current);
resizeObserver.observe(el);
return () => {
resizeObserver.disconnect();
cancelAnimationFrame(rafId);
if (rafId !== null) cancelAnimationFrame(rafId);
};
}, []);
@@ -272,7 +374,7 @@ const DataGrid: React.FC<DataGridProps> = ({
const [deletedRowKeys, setDeletedRowKeys] = useState<Set<string>>(new Set());
// Filter State
const [filterConditions, setFilterConditions] = useState<{ id: number, column: string, op: string, value: string }[]>([]);
const [filterConditions, setFilterConditions] = useState<{ id: number, column: string, op: string, value: string, value2?: string }[]>([]);
const [nextFilterId, setNextFilterId] = useState(1);
const selectedRowKeysRef = useRef(selectedRowKeys);
@@ -286,6 +388,14 @@ const DataGrid: React.FC<DataGridProps> = ({
setModifiedRows({});
setDeletedRowKeys(new Set());
setSelectedRowKeys([]);
setActiveCell(null);
setRowEditorOpen(false);
setRowEditorRowKey('');
rowEditorBaseRef.current = {};
rowEditorDisplayRef.current = {};
rowEditorNullColsRef.current = new Set();
rowEditorForm.resetFields();
closeCellEditor();
form.resetFields();
}, [tableName, dbName, connectionId]); // Reset on context change
@@ -440,6 +550,29 @@ const DataGrid: React.FC<DataGridProps> = ({
}
}, [addedRows]);
const handleCellEditorSave = useCallback(() => {
if (!cellEditorMeta) return;
const apply = cellEditorApplyRef.current;
if (apply) {
apply(cellEditorValue);
closeCellEditor();
return;
}
const nextRow: any = { ...cellEditorMeta.record, [cellEditorMeta.dataIndex]: cellEditorValue };
handleCellSave(nextRow);
closeCellEditor();
}, [cellEditorMeta, cellEditorValue, handleCellSave, closeCellEditor]);
const handleFormatJsonInEditor = useCallback(() => {
if (!cellEditorIsJson) return;
try {
const obj = JSON.parse(cellEditorValue);
setCellEditorValue(JSON.stringify(obj, null, 2));
} catch (e: any) {
message.error("JSON 格式无效:" + (e?.message || String(e)));
}
}, [cellEditorIsJson, cellEditorValue]);
// Merge Data for Display
// 'displayData' already merges addedRows.
// We need to merge modifiedRows into it for rendering.
@@ -453,6 +586,110 @@ const DataGrid: React.FC<DataGridProps> = ({
});
}, [displayData, modifiedRows]);
const focusCell = useCallback((record: Item, dataIndex: string, title: React.ReactNode) => {
const k = record?.[GONAVI_ROW_KEY];
if (k === undefined) return;
const titleText = typeof title === 'string' ? title : (typeof title === 'number' ? String(title) : String(dataIndex));
setActiveCell({ rowKey: rowKeyStr(k), dataIndex, title: titleText });
}, [rowKeyStr]);
const closeRowEditor = useCallback(() => {
setRowEditorOpen(false);
setRowEditorRowKey('');
rowEditorBaseRef.current = {};
rowEditorDisplayRef.current = {};
rowEditorNullColsRef.current = new Set();
rowEditorForm.resetFields();
}, [rowEditorForm]);
const openRowEditor = useCallback(() => {
if (readOnly || !tableName) return;
if (selectedRowKeys.length > 1) {
message.info('一次只能编辑一行,请仅选择一行');
return;
}
const keyStr =
selectedRowKeys.length === 1 ? rowKeyStr(selectedRowKeys[0]) : activeCell?.rowKey;
if (!keyStr) {
message.info('请先选择一行(勾选一行或点击任意单元格)');
return;
}
const displayRow = mergedDisplayData.find(r => rowKeyStr(r?.[GONAVI_ROW_KEY]) === keyStr);
if (!displayRow) {
message.error('未找到目标行,请刷新后重试');
return;
}
const baseRow =
data.find(r => rowKeyStr(r?.[GONAVI_ROW_KEY]) === keyStr) ||
addedRows.find(r => rowKeyStr(r?.[GONAVI_ROW_KEY]) === keyStr) ||
displayRow;
const baseMap: Record<string, string> = {};
const displayMap: Record<string, string> = {};
const nullCols = new Set<string>();
columnNames.forEach((col) => {
const baseVal = (baseRow as any)?.[col];
const displayVal = (displayRow as any)?.[col];
baseMap[col] = toFormText(baseVal);
displayMap[col] = toFormText(displayVal);
if (baseVal === null || baseVal === undefined) nullCols.add(col);
});
rowEditorBaseRef.current = baseMap;
rowEditorDisplayRef.current = displayMap;
rowEditorNullColsRef.current = nullCols;
rowEditorForm.setFieldsValue(displayMap);
setRowEditorRowKey(keyStr);
setRowEditorOpen(true);
}, [readOnly, tableName, selectedRowKeys, activeCell, mergedDisplayData, data, addedRows, columnNames, rowEditorForm, rowKeyStr]);
const openRowEditorFieldEditor = useCallback((dataIndex: string) => {
if (!dataIndex) return;
const val = rowEditorForm.getFieldValue(dataIndex);
openCellEditor(
{ [dataIndex]: val ?? '' },
dataIndex,
dataIndex,
(nextVal) => rowEditorForm.setFieldsValue({ [dataIndex]: nextVal }),
);
}, [rowEditorForm, openCellEditor]);
const applyRowEditor = useCallback(() => {
const keyStr = rowEditorRowKey;
if (!keyStr) return;
const values = rowEditorForm.getFieldsValue(true) || {};
const isAdded = addedRows.some(r => rowKeyStr(r?.[GONAVI_ROW_KEY]) === keyStr);
if (isAdded) {
setAddedRows(prev => prev.map(r => rowKeyStr(r?.[GONAVI_ROW_KEY]) === keyStr ? { ...r, ...values } : r));
closeRowEditor();
return;
}
const baseMap = rowEditorBaseRef.current || {};
const patch: Record<string, any> = {};
columnNames.forEach((col) => {
const nextVal = values[col];
const nextStr = toFormText(nextVal);
const baseStr = baseMap[col] ?? '';
if (nextStr !== baseStr) patch[col] = nextStr;
});
setModifiedRows(prev => {
const next = { ...prev };
if (Object.keys(patch).length === 0) delete next[keyStr];
else next[keyStr] = patch;
return next;
});
closeRowEditor();
}, [rowEditorRowKey, rowEditorForm, addedRows, columnNames, rowKeyStr, closeRowEditor]);
const columns = useMemo(() => {
return columnNames.map(key => ({
title: key,
@@ -481,9 +718,13 @@ const DataGrid: React.FC<DataGridProps> = ({
dataIndex: col.dataIndex,
title: col.title,
handleSave: handleCellSave,
focusCell,
className: (activeCell && rowKeyStr(record?.[GONAVI_ROW_KEY]) === activeCell.rowKey && col.dataIndex === activeCell.dataIndex)
? 'gonavi-active-cell'
: undefined,
}),
};
}), [columns, handleCellSave]);
}), [columns, handleCellSave, openCellEditor, focusCell, activeCell, rowKeyStr]);
const handleAddRow = () => {
const newKey = `new-${Date.now()}`;
@@ -633,11 +874,98 @@ const DataGrid: React.FC<DataGridProps> = ({
copyToClipboard(lines.join('\n'));
}, [getTargets, copyToClipboard]);
const buildConnConfig = useCallback(() => {
if (!connectionId) return null;
const conn = connections.find(c => c.id === connectionId);
if (!conn) return null;
return {
...conn.config,
port: Number(conn.config.port),
password: conn.config.password || "",
database: conn.config.database || "",
useSSH: conn.config.useSSH || false,
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
};
}, [connections, connectionId]);
const exportByQuery = useCallback(async (sql: string, format: string, defaultName: string) => {
const config = buildConnConfig();
if (!config) return;
const hide = message.loading(`正在导出...`, 0);
const res = await ExportQuery(config as any, dbName || '', sql, defaultName || 'export', format);
hide();
if (res.success) {
message.success("导出成功");
} else if (res.message !== "Cancelled") {
message.error("导出失败: " + res.message);
}
}, [buildConnConfig, dbName]);
const buildPkWhereSql = useCallback((rows: any[], dbType: string) => {
if (!tableName || pkColumns.length === 0) return '';
const targets = (rows || []).filter(Boolean);
if (targets.length === 0) return '';
const clauses: string[] = [];
for (const r of targets) {
const andParts: string[] = [];
for (const pk of pkColumns) {
const col = quoteIdentPart(dbType, pk);
const v = r?.[pk];
if (v === null || v === undefined) return '';
andParts.push(`${col} = '${escapeLiteral(String(v))}'`);
}
if (andParts.length === pkColumns.length) {
clauses.push(`(${andParts.join(' AND ')})`);
}
}
if (clauses.length === 0) return '';
return clauses.join(' OR ');
}, [pkColumns, tableName]);
const buildCurrentPageSql = useCallback((dbType: string) => {
if (!tableName || !pagination) return '';
const whereSQL = buildWhereSQL(dbType, filterConditions);
let sql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
if (sortInfo && sortInfo.order) {
sql += ` ORDER BY ${quoteIdentPart(dbType, sortInfo.columnKey)} ${sortInfo.order === 'ascend' ? 'ASC' : 'DESC'}`;
}
const offset = (pagination.current - 1) * pagination.pageSize;
sql += ` LIMIT ${pagination.pageSize} OFFSET ${offset}`;
return sql;
}, [tableName, pagination, filterConditions, sortInfo]);
// Context Menu Export
const handleExportSelected = useCallback(async (format: string, record: any) => {
const records = getTargets(record);
await exportData(records, format);
}, [getTargets]);
if (!connectionId || !tableName) {
await exportData(records, format);
return;
}
// With uncommitted changes, prefer exporting the on-screen data to avoid inconsistency with the database.
if (hasChanges) {
message.warning("当前存在未提交修改,导出将按界面数据生成;如需完整长字段建议先提交后再导出。");
await exportData(records, format);
return;
}
const config = buildConnConfig();
if (!config) {
await exportData(records, format);
return;
}
const dbType = config.type || '';
const pkWhere = buildPkWhereSql(records, dbType);
if (!pkWhere) {
await exportData(records, format);
return;
}
const sql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} WHERE ${pkWhere}`;
await exportByQuery(sql, format, tableName || 'export');
}, [getTargets, connectionId, tableName, hasChanges, exportData, buildConnConfig, buildPkWhereSql, exportByQuery]);
// Export
const handleExport = async (format: string) => {
@@ -646,7 +974,7 @@ const DataGrid: React.FC<DataGridProps> = ({
// 1. Export Selected
if (selectedRowKeys.length > 0) {
const selectedRows = displayData.filter(d => selectedRowKeys.includes(d?.[GONAVI_ROW_KEY]));
await exportData(selectedRows, format);
await handleExportSelected(format, selectedRows[0]);
return;
}
@@ -655,9 +983,8 @@ const DataGrid: React.FC<DataGridProps> = ({
let instance: any;
const handleAll = async () => {
instance.destroy();
const conn = connections.find(c => c.id === connectionId);
if (!conn) return;
const config = { ...conn.config, port: Number(conn.config.port), password: conn.config.password || "", database: conn.config.database || "", useSSH: conn.config.useSSH || false, ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" } };
const config = buildConnConfig();
if (!config) return;
const hide = message.loading(`正在导出全部数据...`, 0);
const res = await ExportTable(config as any, dbName || '', tableName, format);
hide();
@@ -665,7 +992,25 @@ const DataGrid: React.FC<DataGridProps> = ({
};
const handlePage = async () => {
instance.destroy();
await exportData(displayData, format);
if (hasChanges) {
message.warning("当前存在未提交修改,导出将按界面数据生成;如需完整长字段建议先提交后再导出。");
await exportData(displayData, format);
return;
}
const config = buildConnConfig();
if (!config) {
await exportData(displayData, format);
return;
}
const sql = buildCurrentPageSql(config.type || '');
if (!sql) {
await exportData(displayData, format);
return;
}
await exportByQuery(sql, format, tableName || 'export');
};
instance = modal.info({
@@ -688,21 +1033,64 @@ const DataGrid: React.FC<DataGridProps> = ({
const handleImport = async () => {
if (!connectionId || !tableName) return;
const conn = connections.find(c => c.id === connectionId);
if (!conn) return;
const config = { ...conn.config, port: Number(conn.config.port), password: conn.config.password || "", database: conn.config.database || "", useSSH: conn.config.useSSH || false, ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" } };
const config = buildConnConfig();
if (!config) return;
const res = await ImportData(config as any, dbName || '', tableName);
if (res.success) { message.success(res.message); if (onReload) onReload(); } else if (res.message !== "Cancelled") { message.error("Import Failed: " + res.message); }
};
// Filters
const filterOpOptions = useMemo(() => ([
{ value: '=', label: '=' },
{ value: '!=', label: '!=' },
{ value: '<', label: '<' },
{ value: '<=', label: '<=' },
{ value: '>', label: '>' },
{ value: '>=', label: '>=' },
{ value: 'CONTAINS', label: '包含' },
{ value: 'NOT_CONTAINS', label: '不包含' },
{ value: 'STARTS_WITH', label: '开始以' },
{ value: 'NOT_STARTS_WITH', label: '不是开始于' },
{ value: 'ENDS_WITH', label: '结束以' },
{ value: 'NOT_ENDS_WITH', label: '不是结束于' },
{ value: 'IS_NULL', label: '是 null' },
{ value: 'IS_NOT_NULL', label: '不是 null' },
{ value: 'IS_EMPTY', label: '是空的' },
{ value: 'IS_NOT_EMPTY', label: '不是空的' },
{ value: 'BETWEEN', label: '介于' },
{ value: 'NOT_BETWEEN', label: '不介于' },
{ value: 'IN', label: '在列表' },
{ value: 'NOT_IN', label: '不在列表' },
{ value: 'CUSTOM', label: '[自定义]' },
]), []);
const isNoValueOp = useCallback((op: string) => (
op === 'IS_NULL' || op === 'IS_NOT_NULL' || op === 'IS_EMPTY' || op === 'IS_NOT_EMPTY'
), []);
const isBetweenOp = useCallback((op: string) => op === 'BETWEEN' || op === 'NOT_BETWEEN', []);
const isListOp = useCallback((op: string) => op === 'IN' || op === 'NOT_IN', []);
const addFilter = () => {
setFilterConditions([...filterConditions, { id: nextFilterId, column: columnNames[0] || '', op: '=', value: '' }]);
setFilterConditions([...filterConditions, { id: nextFilterId, column: columnNames[0] || '', op: '=', value: '', value2: '' }]);
setNextFilterId(nextFilterId + 1);
};
const updateFilter = (id: number, field: string, val: string) => {
setFilterConditions(prev => prev.map(c => c.id === id ? { ...c, [field]: val } : c));
setFilterConditions(prev => prev.map(c => {
if (c.id !== id) return c;
const next: any = { ...c, [field]: val };
if (field === 'op') {
if (isNoValueOp(val)) {
next.value = '';
next.value2 = '';
} else if (isBetweenOp(val)) {
if (typeof next.value2 !== 'string') next.value2 = '';
} else {
next.value2 = '';
}
}
return next;
}));
};
const removeFilter = (id: number) => {
setFilterConditions(prev => prev.filter(c => c.id !== id));
@@ -723,33 +1111,41 @@ const DataGrid: React.FC<DataGridProps> = ({
header: { cell: ResizableTitle }
}), []);
const totalWidth = columns.reduce((sum, col) => sum + (col.width as number || 200), 0);
const totalWidth = columns.reduce((sum, col) => sum + (Number(col.width) || 200), 0) + selectionColumnWidth;
const enableVirtual = mergedDisplayData.length >= 200;
return (
<div className={gridId} style={{ height: '100%', overflow: 'hidden', padding: 0, display: 'flex', flexDirection: 'column', minHeight: 0 }}>
{/* Toolbar */}
<div style={{ padding: '8px', borderBottom: '1px solid #eee', display: 'flex', gap: 8, alignItems: 'center' }}>
{onReload && <Button icon={<ReloadOutlined />} onClick={() => {
setAddedRows([]);
setModifiedRows({});
setDeletedRowKeys(new Set());
setSelectedRowKeys([]);
onReload();
}}></Button>}
{tableName && <Button icon={<ImportOutlined />} onClick={handleImport}></Button>}
{tableName && <Dropdown menu={{ items: exportMenu }}><Button icon={<ExportOutlined />}> <DownOutlined /></Button></Dropdown>}
{!readOnly && tableName && (
<>
<div style={{ width: 1, background: '#eee', height: 20, margin: '0 8px' }} />
<Button icon={<PlusOutlined />} onClick={handleAddRow}></Button>
<Button icon={<DeleteOutlined />} danger disabled={selectedRowKeys.length === 0} onClick={handleDeleteSelected}></Button>
{selectedRowKeys.length > 0 && <span style={{ fontSize: '12px', color: '#888' }}> {selectedRowKeys.length}</span>}
<div style={{ width: 1, background: '#eee', height: 20, margin: '0 8px' }} />
<Button icon={<SaveOutlined />} type="primary" disabled={!hasChanges} onClick={handleCommit}> ({addedRows.length + Object.keys(modifiedRows).length + deletedRowKeys.size})</Button>
{hasChanges && (<Button icon={<UndoOutlined />} onClick={() => {
setAddedRows([]);
<div className={gridId} style={{ flex: '1 1 auto', height: '100%', overflow: 'hidden', padding: 0, display: 'flex', flexDirection: 'column', minHeight: 0 }}>
{/* Toolbar */}
<div style={{ padding: '8px', borderBottom: '1px solid #eee', display: 'flex', gap: 8, alignItems: 'center' }}>
{onReload && <Button icon={<ReloadOutlined />} onClick={() => {
setAddedRows([]);
setModifiedRows({});
setDeletedRowKeys(new Set());
setSelectedRowKeys([]);
setActiveCell(null);
onReload();
}}></Button>}
{tableName && <Button icon={<ImportOutlined />} onClick={handleImport}></Button>}
{tableName && <Dropdown menu={{ items: exportMenu }}><Button icon={<ExportOutlined />}> <DownOutlined /></Button></Dropdown>}
{!readOnly && tableName && (
<>
<div style={{ width: 1, background: '#eee', height: 20, margin: '0 8px' }} />
<Button icon={<PlusOutlined />} onClick={handleAddRow}></Button>
<Button
icon={<EditOutlined />}
disabled={selectedRowKeys.length > 1 || (selectedRowKeys.length !== 1 && !activeCell)}
onClick={openRowEditor}
>
</Button>
<Button icon={<DeleteOutlined />} danger disabled={selectedRowKeys.length === 0} onClick={handleDeleteSelected}></Button>
{selectedRowKeys.length > 0 && <span style={{ fontSize: '12px', color: '#888' }}> {selectedRowKeys.length}</span>}
<div style={{ width: 1, background: '#eee', height: 20, margin: '0 8px' }} />
<Button icon={<SaveOutlined />} type="primary" disabled={!hasChanges} onClick={handleCommit}> ({addedRows.length + Object.keys(modifiedRows).length + deletedRowKeys.size})</Button>
{hasChanges && (<Button icon={<UndoOutlined />} onClick={() => {
setAddedRows([]);
setModifiedRows({});
setDeletedRowKeys(new Set());
}}></Button>)}
@@ -771,10 +1167,62 @@ const DataGrid: React.FC<DataGridProps> = ({
{showFilter && (
<div style={{ padding: '8px', background: '#f5f5f5', borderBottom: '1px solid #eee' }}>
{filterConditions.map(cond => (
<div key={cond.id} style={{ display: 'flex', gap: 8, marginBottom: 8 }}>
<Select style={{ width: 150 }} value={cond.column} onChange={v => updateFilter(cond.id, 'column', v)} options={columnNames.map(c => ({ value: c, label: c }))} />
<Select style={{ width: 100 }} value={cond.op} onChange={v => updateFilter(cond.id, 'op', v)} options={[{ value: '=', label: '=' }, { value: 'LIKE', label: '包含' }]} />
<Input style={{ width: 200 }} value={cond.value} onChange={e => updateFilter(cond.id, 'value', e.target.value)} />
<div key={cond.id} style={{ display: 'flex', gap: 8, marginBottom: 8, alignItems: 'flex-start' }}>
<Select
style={{ width: 180 }}
value={cond.column}
onChange={v => updateFilter(cond.id, 'column', v)}
options={columnNames.map(c => ({ value: c, label: c }))}
disabled={cond.op === 'CUSTOM'}
/>
<Select
style={{ width: 140 }}
value={cond.op}
onChange={v => updateFilter(cond.id, 'op', v)}
options={filterOpOptions as any}
/>
{cond.op === 'CUSTOM' ? (
<Input.TextArea
style={{ flex: 1 }}
autoSize={{ minRows: 1, maxRows: 4 }}
value={cond.value}
onChange={e => updateFilter(cond.id, 'value', e.target.value)}
placeholder="输入自定义 WHERE 表达式(不需要再写 WHERE),例如:status IN ('A','B')"
/>
) : isListOp(cond.op) ? (
<Input.TextArea
style={{ flex: 1 }}
autoSize={{ minRows: 1, maxRows: 4 }}
value={cond.value}
onChange={e => updateFilter(cond.id, 'value', e.target.value)}
placeholder="多个值用逗号或换行分隔"
/>
) : isBetweenOp(cond.op) ? (
<>
<Input
style={{ width: 220 }}
value={cond.value}
onChange={e => updateFilter(cond.id, 'value', e.target.value)}
placeholder="开始值"
/>
<Input
style={{ width: 220 }}
value={cond.value2 || ''}
onChange={e => updateFilter(cond.id, 'value2', e.target.value)}
placeholder="结束值"
/>
</>
) : isNoValueOp(cond.op) ? (
<Input style={{ width: 220 }} value="" disabled placeholder="无需输入值" />
) : (
<Input
style={{ width: 280 }}
value={cond.value}
onChange={e => updateFilter(cond.id, 'value', e.target.value)}
/>
)}
<Button icon={<CloseOutlined />} onClick={() => removeFilter(cond.id)} type="text" danger />
</div>
))}
@@ -789,8 +1237,90 @@ const DataGrid: React.FC<DataGridProps> = ({
</div>
)}
<div ref={containerRef} style={{ flex: 1, overflow: 'hidden', position: 'relative', minHeight: 0 }}>
{contextHolder}
<div ref={containerRef} style={{ flex: 1, overflow: 'hidden', position: 'relative', minHeight: 0 }}>
{contextHolder}
<Modal
title="编辑行"
open={rowEditorOpen}
onCancel={closeRowEditor}
width={980}
destroyOnClose
maskClosable={false}
footer={[
<Button key="cancel" onClick={closeRowEditor}></Button>,
<Button key="ok" type="primary" onClick={applyRowEditor}></Button>,
]}
>
<div style={{ marginBottom: 8, color: '#888', fontSize: 12, display: 'flex', justifyContent: 'space-between', gap: 8 }}>
<span>{tableName ? `${tableName}` : ''}</span>
<span>{rowEditorRowKey ? `rowKey: ${rowEditorRowKey}` : ''}</span>
</div>
<Form form={rowEditorForm} layout="vertical">
<div style={{ maxHeight: '62vh', overflow: 'auto', paddingRight: 8 }}>
{columnNames.map((col) => {
const sample = rowEditorDisplayRef.current?.[col] ?? '';
const placeholder = rowEditorNullColsRef.current?.has(col) ? '(NULL)' : undefined;
const isJson = looksLikeJsonText(sample);
const useArea = isJson || sample.includes('\n') || sample.length >= 160;
return (
<Form.Item key={col} label={col} style={{ marginBottom: 12 }}>
<div style={{ display: 'flex', gap: 8, alignItems: 'flex-start' }}>
<Form.Item name={col} noStyle>
{useArea ? (
<Input.TextArea
style={{ flex: 1 }}
autoSize={{ minRows: isJson ? 4 : 1, maxRows: 10 }}
placeholder={placeholder}
/>
) : (
<Input style={{ flex: 1 }} placeholder={placeholder} />
)}
</Form.Item>
<Button size="small" onClick={() => openRowEditorFieldEditor(col)} title="弹窗编辑">...</Button>
</div>
</Form.Item>
);
})}
</div>
</Form>
</Modal>
<Modal
title={cellEditorMeta ? `编辑单元格:${cellEditorMeta.title}` : '编辑单元格'}
open={cellEditorOpen}
onCancel={closeCellEditor}
width={960}
destroyOnClose
maskClosable={false}
footer={[
<Button key="format" onClick={handleFormatJsonInEditor} disabled={!cellEditorIsJson}>
JSON
</Button>,
<Button key="cancel" onClick={closeCellEditor}></Button>,
<Button key="ok" type="primary" onClick={handleCellEditorSave}></Button>,
]}
>
<div style={{ marginBottom: 8, color: '#888', fontSize: 12 }}>
{cellEditorMeta ? `${tableName || ''}${tableName ? '.' : ''}${cellEditorMeta.dataIndex}` : ''}
</div>
{cellEditorOpen && (
<Editor
height="56vh"
language={cellEditorIsJson ? "json" : "plaintext"}
theme={darkMode ? "vs-dark" : "light"}
value={cellEditorValue}
onChange={(val) => setCellEditorValue(val || '')}
options={{
minimap: { enabled: false },
scrollBeyondLastLine: false,
wordWrap: "on",
fontSize: 14,
tabSize: 2,
automaticLayout: true,
}}
/>
)}
</Modal>
<Form component={false} form={form}>
<DataContext.Provider value={{ selectedRowKeysRef, displayDataRef, handleCopyInsert, handleCopyJson, handleCopyCsv, handleExportSelected, copyToClipboard, tableName }}>
<EditableContext.Provider value={form}>
@@ -799,6 +1329,7 @@ const DataGrid: React.FC<DataGridProps> = ({
dataSource={mergedDisplayData}
columns={mergedColumns}
size="small"
tableLayout="fixed"
scroll={{ x: Math.max(totalWidth, 1000), y: tableHeight }}
virtual={enableVirtual}
loading={loading}
@@ -809,6 +1340,7 @@ const DataGrid: React.FC<DataGridProps> = ({
rowSelection={{
selectedRowKeys,
onChange: setSelectedRowKeys,
columnWidth: selectionColumnWidth,
}}
rowClassName={(record) => {
const k = record?.[GONAVI_ROW_KEY];
@@ -842,14 +1374,14 @@ const DataGrid: React.FC<DataGridProps> = ({
</div>
)}
<style>{`
.${gridId} .row-added td { background-color: #f6ffed !important; }
.${gridId} .row-modified td { background-color: #e6f7ff !important; }
.${gridId} .ant-table-body {
height: ${tableHeight}px !important;
max-height: ${tableHeight}px !important;
}
`}</style>
<style>{`
.${gridId} .row-added td { background-color: #f6ffed !important; }
.${gridId} .row-modified td { background-color: #e6f7ff !important; }
.${gridId} td.gonavi-active-cell {
outline: 2px solid #1677ff;
outline-offset: -2px;
}
`}</style>
{/* Ghost Resize Line for Columns */}
<div
@@ -4,6 +4,7 @@ import { TabData, ColumnDefinition } from '../types';
import { useStore } from '../store';
import { DBQuery, DBGetColumns } from '../../wailsjs/go/app/App';
import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
import { buildWhereSQL, quoteIdentPart, quoteQualifiedIdent } from '../utils/sql';
const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
const [data, setData] = useState<any[]>([]);
@@ -14,6 +15,8 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
const fetchSeqRef = useRef(0);
const countSeqRef = useRef(0);
const countKeyRef = useRef<string>('');
const pkSeqRef = useRef(0);
const pkKeyRef = useRef<string>('');
const [pagination, setPagination] = useState({
current: 1,
@@ -27,6 +30,13 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
const [showFilter, setShowFilter] = useState(false);
const [filterConditions, setFilterConditions] = useState<any[]>([]);
useEffect(() => {
setPkColumns([]);
pkKeyRef.current = '';
countKeyRef.current = '';
setPagination(prev => ({ ...prev, current: 1, total: 0, totalKnown: false }));
}, [tab.connectionId, tab.dbName, tab.tableName]);
const fetchData = useCallback(async (page = pagination.current, size = pagination.pageSize) => {
const seq = ++fetchSeqRef.current;
setLoading(true);
@@ -46,40 +56,18 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
};
const quoteIdentPart = (ident: string) => {
if (!ident) return ident;
if (config.type === 'mysql') return `\`${ident.replace(/`/g, '``')}\``;
return `"${ident.replace(/"/g, '""')}"`;
};
const quoteQualifiedIdent = (ident: string) => {
const raw = (ident || '').trim();
if (!raw) return raw;
const parts = raw.split('.').filter(Boolean);
if (parts.length <= 1) return quoteIdentPart(raw);
return parts.map(quoteIdentPart).join('.');
};
const escapeLiteral = (val: string) => val.replace(/'/g, "''");
const dbType = config.type || '';
const dbName = tab.dbName || '';
const tableName = tab.tableName || '';
const whereParts: string[] = [];
filterConditions.forEach(cond => {
if (cond.column && cond.value) {
if (cond.op === 'LIKE') {
whereParts.push(`${quoteIdentPart(cond.column)} LIKE '%${escapeLiteral(cond.value)}%'`);
} else {
whereParts.push(`${quoteIdentPart(cond.column)} ${cond.op} '${escapeLiteral(cond.value)}'`);
}
}
});
const whereSQL = whereParts.length > 0 ? `WHERE ${whereParts.join(' AND ')}` : "";
const whereSQL = buildWhereSQL(dbType, filterConditions);
const countSql = `SELECT COUNT(*) as total FROM ${quoteQualifiedIdent(tableName)} ${whereSQL}`;
const countSql = `SELECT COUNT(*) as total FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
let sql = `SELECT * FROM ${quoteQualifiedIdent(tableName)} ${whereSQL}`;
let sql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
if (sortInfo && sortInfo.order) {
sql += ` ORDER BY ${quoteIdentPart(sortInfo.columnKey)} ${sortInfo.order === 'ascend' ? 'ASC' : 'DESC'}`;
sql += ` ORDER BY ${quoteIdentPart(dbType, sortInfo.columnKey)} ${sortInfo.order === 'ascend' ? 'ASC' : 'DESC'}`;
}
const offset = (page - 1) * size;
// Large-table performance: opening a table does not block on COUNT(*). Fetch one extra row to detect whether a next page exists; the total is computed in the background and backfilled asynchronously.
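The comment above describes the open-table fast path: request `pageSize + 1` rows so the presence of an extra row signals a next page, while the exact `COUNT(*)` is backfilled asynchronously. A minimal sketch of that check (sample data and names are illustrative, not from the codebase):

```typescript
// Sketch: ask the database for pageSize + 1 rows (e.g. LIMIT 51 for a 50-row page).
const pageSize = 50;
const fetched = Array.from({ length: pageSize + 1 }, (_, i) => ({ id: i })); // pretend the query returned 51 rows

// The extra row means there is a next page; trim it before rendering.
const hasNextPage = fetched.length > pageSize;
const pageRows = hasNextPage ? fetched.slice(0, pageSize) : fetched;
```

This keeps first paint independent of `COUNT(*)`, which on large tables can be far slower than the page query itself.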
@@ -89,11 +77,6 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
try {
const pData = DBQuery(config as any, dbName, sql);
let pCols: Promise<any> | null = null;
if (pkColumns.length === 0) {
pCols = DBGetColumns(config as any, dbName, tableName);
}
const resData = await pData;
const duration = Date.now() - startTime;
@@ -109,11 +92,23 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
dbName
});
if (pCols) {
const resCols = await pCols;
if (resCols.success) {
const pks = (resCols.data as ColumnDefinition[]).filter(c => c.key === 'PRI').map(c => c.name);
setPkColumns(pks);
if (pkColumns.length === 0) {
const pkKey = `${tab.connectionId}|${dbName}|${tableName}`;
if (pkKeyRef.current !== pkKey) {
pkKeyRef.current = pkKey;
const pkSeq = ++pkSeqRef.current;
DBGetColumns(config as any, dbName, tableName)
.then((resCols: any) => {
if (pkSeqRef.current !== pkSeq) return;
if (pkKeyRef.current !== pkKey) return;
if (!resCols?.success) return;
const pks = (resCols.data as ColumnDefinition[]).filter((c: any) => c.key === 'PRI').map((c: any) => c.name);
setPkColumns(pks);
})
.catch(() => {
if (pkSeqRef.current !== pkSeq) return;
if (pkKeyRef.current !== pkKey) return;
});
}
}
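The primary-key fetch above guards its async callback with a monotonically increasing sequence number so a stale response never overwrites newer state. A stripped-down sketch of the same guard (all names hypothetical):

```typescript
// Each request captures the sequence value at start time; the callback only
// applies its result if no newer request has been issued since.
let pkSeq = 0;
let applied: string[] | null = null;

function onFetchResolved(result: string[], isLatest: () => boolean): void {
  if (!isLatest()) return; // stale response: a newer request superseded this one
  applied = result;
}

const seq1 = ++pkSeq; // first request
const seq2 = ++pkSeq; // second request supersedes the first
onFetchResolved(['old_pk'], () => pkSeq === seq1); // ignored
onFetchResolved(['id'], () => pkSeq === seq2);     // applied
```

The real code additionally checks a connection/db/table key (`pkKeyRef`) so results for a different table are discarded even if the sequence happens to match.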
@@ -227,7 +222,7 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
}, [tab, sortInfo, filterConditions]); // Initial load and re-load on sort/filter
return (
<div style={{ height: '100%', width: '100%', overflow: 'hidden' }}>
<div style={{ flex: '1 1 auto', minHeight: 0, height: '100%', width: '100%', overflow: 'hidden', display: 'flex', flexDirection: 'column' }}>
<DataGrid
data={data}
columnNames={columnNames}
File diff suppressed because it is too large
@@ -47,6 +47,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
const [expandedKeys, setExpandedKeys] = useState<React.Key[]>([]);
const [autoExpandParent, setAutoExpandParent] = useState(true);
const [loadedKeys, setLoadedKeys] = useState<React.Key[]>([]);
const [selectedKeys, setSelectedKeys] = useState<React.Key[]>([]);
const [selectedNodes, setSelectedNodes] = useState<any[]>([]);
const [contextMenu, setContextMenu] = useState<{ x: number, y: number, items: MenuProps['items'] } | null>(null);
// Virtual Scroll State
@@ -283,10 +285,14 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
};
const onSelect = (keys: React.Key[], info: any) => {
if (!info.node.selected) {
setSelectedKeys(keys);
setSelectedNodes(info.selectedNodes || []);
if (keys.length === 0) {
setActiveContext(null);
return;
}
if (!info.selected) return;
const { type, dataRef, key, title } = info.node;
@@ -313,15 +319,6 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
};
const onDoubleClick = (e: any, node: any) => {
const key = node.key;
const isExpanded = expandedKeys.includes(key);
const newExpandedKeys = isExpanded
? expandedKeys.filter(k => k !== key)
: [...expandedKeys, key];
setExpandedKeys(newExpandedKeys);
if (!isExpanded) setAutoExpandParent(false);
if (node.type === 'table') {
const { tableName, dbName, id } = node.dataRef;
addTab({
@@ -332,6 +329,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
dbName,
tableName,
});
return;
} else if (node.type === 'saved-query') {
const q = node.dataRef;
addTab({
@@ -342,7 +340,17 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
dbName: q.dbName,
query: q.sql
});
return;
}
const key = node.key;
const isExpanded = expandedKeys.includes(key);
const newExpandedKeys = isExpanded
? expandedKeys.filter(k => k !== key)
: [...expandedKeys, key];
setExpandedKeys(newExpandedKeys);
if (!isExpanded) setAutoExpandParent(false);
};
const handleCopyStructure = async (node: any) => {
@@ -382,6 +390,60 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
}
};
const normalizeConnConfig = (raw: any) => ({
...raw,
port: Number(raw.port),
password: raw.password || "",
database: raw.database || "",
useSSH: raw.useSSH || false,
ssh: raw.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
});
const handleExportDatabaseSQL = async (node: any, includeData: boolean) => {
const conn = node.dataRef;
const dbName = conn.dbName || node.title;
const hide = message.loading(includeData ? `正在备份数据库 ${dbName} (结构+数据)...` : `正在导出数据库 ${dbName} 表结构...`, 0);
try {
const res = await (window as any).go.app.App.ExportDatabaseSQL(normalizeConnConfig(conn.config), dbName, includeData);
hide();
if (res.success) {
message.success('导出成功');
} else if (res.message !== 'Cancelled') {
message.error('导出失败: ' + res.message);
}
} catch (e: any) {
hide();
message.error('导出失败: ' + (e?.message || String(e)));
}
};
const handleExportTablesSQL = async (nodes: any[], includeData: boolean) => {
if (!nodes || nodes.length === 0) return;
const first = nodes[0].dataRef;
const dbName = first.dbName;
const connId = first.id;
const allSame = nodes.every(n => n?.dataRef?.id === connId && n?.dataRef?.dbName === dbName);
if (!allSame) {
message.error('请在同一连接、同一数据库下选择多张表进行导出');
return;
}
const tableNames = nodes.map(n => n.dataRef.tableName).filter(Boolean);
const hide = message.loading(includeData ? `正在备份选中表 (${tableNames.length})...` : `正在导出选中表结构 (${tableNames.length})...`, 0);
try {
const res = await (window as any).go.app.App.ExportTablesSQL(normalizeConnConfig(first.config), dbName, tableNames, includeData);
hide();
if (res.success) {
message.success('导出成功');
} else if (res.message !== 'Cancelled') {
message.error('导出失败: ' + res.message);
}
} catch (e: any) {
hide();
message.error('导出失败: ' + (e?.message || String(e)));
}
};
const handleRunSQLFile = async (node: any) => {
const res = await (window as any).go.app.App.OpenSQLFile();
if (res.success) {
@@ -550,6 +612,18 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
icon: <ReloadOutlined />,
onClick: () => loadTables(node)
},
{
key: 'export-db-schema',
label: '导出全部表结构 (SQL)',
icon: <ExportOutlined />,
onClick: () => handleExportDatabaseSQL(node, false)
},
{
key: 'backup-db-sql',
label: '备份全部表 (结构+数据 SQL)',
icon: <SaveOutlined />,
onClick: () => handleExportDatabaseSQL(node, true)
},
{ type: 'divider' },
{
key: 'disconnect-db',
@@ -588,7 +662,25 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
}
];
} else if (node.type === 'table') {
const sameContextSelectedTables = (selectedNodes || []).filter((n: any) => n?.type === 'table' && n?.dataRef?.id === node?.dataRef?.id && n?.dataRef?.dbName === node?.dataRef?.dbName);
const selectedForAction = sameContextSelectedTables.some((n: any) => n?.key === node.key) ? sameContextSelectedTables : [node];
return [
...(selectedForAction.length > 1 ? ([
{
key: 'export-selected-schema',
label: `导出选中表结构 (${selectedForAction.length}) (SQL)`,
icon: <ExportOutlined />,
onClick: () => handleExportTablesSQL(selectedForAction, false)
},
{
key: 'backup-selected-sql',
label: `备份选中表 (${selectedForAction.length}) (结构+数据 SQL)`,
icon: <SaveOutlined />,
onClick: () => handleExportTablesSQL(selectedForAction, true)
},
{ type: 'divider' as const }
]) : []),
{
key: 'new-query',
label: '新建查询',
@@ -684,6 +776,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
loadedKeys={loadedKeys}
onLoad={setLoadedKeys}
autoExpandParent={autoExpandParent}
multiple
selectedKeys={selectedKeys}
blockNode
height={treeHeight}
onRightClick={onRightClick}
@@ -1,12 +1,13 @@
import React, { useMemo } from 'react';
import { Tabs, Button } from 'antd';
import { Tabs, Dropdown } from 'antd';
import type { MenuProps } from 'antd';
import { useStore } from '../store';
import DataViewer from './DataViewer';
import QueryEditor from './QueryEditor';
import TableDesigner from './TableDesigner';
const TabManager: React.FC = () => {
const { tabs, activeTabId, setActiveTab, closeTab } = useStore();
const { tabs, activeTabId, setActiveTab, closeTab, closeOtherTabs, closeTabsToLeft, closeTabsToRight, closeAllTabs } = useStore();
const onChange = (newActiveKey: string) => {
setActiveTab(newActiveKey);
@@ -18,7 +19,7 @@ const TabManager: React.FC = () => {
}
};
const items = useMemo(() => tabs.map(tab => {
const items = useMemo(() => tabs.map((tab, index) => {
let content;
if (tab.type === 'query') {
content = <QueryEditor tab={tab} />;
@@ -27,27 +28,95 @@ const TabManager: React.FC = () => {
} else if (tab.type === 'design') {
content = <TableDesigner tab={tab} />;
}
const menuItems: MenuProps['items'] = [
{
key: 'close-other',
label: '关闭其他页',
disabled: tabs.length <= 1,
onClick: () => closeOtherTabs(tab.id),
},
{
key: 'close-left',
label: '关闭左侧',
disabled: index === 0,
onClick: () => closeTabsToLeft(tab.id),
},
{
key: 'close-right',
label: '关闭右侧',
disabled: index === tabs.length - 1,
onClick: () => closeTabsToRight(tab.id),
},
{ type: 'divider' },
{
key: 'close-all',
label: '关闭所有',
disabled: tabs.length === 0,
onClick: () => closeAllTabs(),
},
];
return {
label: tab.title,
label: (
<Dropdown menu={{ items: menuItems }} trigger={['contextMenu']}>
<span onContextMenu={(e) => e.preventDefault()}>{tab.title}</span>
</Dropdown>
),
key: tab.id,
children: content,
};
}), [tabs]);
}), [tabs, closeOtherTabs, closeTabsToLeft, closeTabsToRight, closeAllTabs]);
return (
<>
<style>{`
.ant-tabs-content { height: 100%; }
.ant-tabs-tabpane { height: 100%; }
.main-tabs {
height: 100%;
flex: 1 1 auto;
min-height: 0;
display: flex;
flex-direction: column;
overflow: hidden;
}
.main-tabs .ant-tabs-nav {
flex: 0 0 auto;
}
.main-tabs .ant-tabs-content-holder {
flex: 1 1 auto;
min-height: 0;
overflow: hidden;
display: flex;
flex-direction: column;
}
.main-tabs .ant-tabs-content {
flex: 1 1 auto;
min-height: 0;
display: flex;
flex-direction: column;
}
.main-tabs .ant-tabs-tabpane {
flex: 1 1 auto;
min-height: 0;
display: flex;
flex-direction: column;
overflow: hidden;
}
.main-tabs .ant-tabs-tabpane > div {
flex: 1 1 auto;
min-height: 0;
}
.main-tabs .ant-tabs-tabpane-hidden {
display: none !important;
}
`}</style>
<Tabs
className="main-tabs"
type="editable-card"
onChange={onChange}
activeKey={activeTabId || undefined}
onEdit={onEdit}
items={items}
style={{ height: '100%' }}
hideAdd
/>
</>
@@ -550,7 +550,6 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
<div ref={containerRef} className="table-designer-wrapper" style={{ height: '100%', overflow: 'hidden', position: 'relative' }}>
<style>{`
.table-designer-wrapper .ant-table-body {
height: ${tableHeight}px !important;
max-height: ${tableHeight}px !important;
}
`}</style>
@@ -21,6 +21,7 @@ interface AppState {
savedQueries: SavedQuery[];
darkMode: boolean;
sqlFormatOptions: { keywordCase: 'upper' | 'lower' };
queryOptions: { maxRows: number };
sqlLogs: SqlLog[];
addConnection: (conn: SavedConnection) => void;
@@ -29,6 +30,10 @@ interface AppState {
addTab: (tab: TabData) => void;
closeTab: (id: string) => void;
closeOtherTabs: (id: string) => void;
closeTabsToLeft: (id: string) => void;
closeTabsToRight: (id: string) => void;
closeAllTabs: () => void;
setActiveTab: (id: string) => void;
setActiveContext: (context: { connectionId: string; dbName: string } | null) => void;
@@ -37,6 +42,7 @@ interface AppState {
toggleDarkMode: () => void;
setSqlFormatOptions: (options: { keywordCase: 'upper' | 'lower' }) => void;
setQueryOptions: (options: Partial<{ maxRows: number }>) => void;
addSqlLog: (log: SqlLog) => void;
clearSqlLogs: () => void;
@@ -52,6 +58,7 @@ export const useStore = create<AppState>()(
savedQueries: [],
darkMode: false,
sqlFormatOptions: { keywordCase: 'upper' },
queryOptions: { maxRows: 5000 },
sqlLogs: [],
addConnection: (conn) => set((state) => ({ connections: [...state.connections, conn] })),
@@ -79,6 +86,30 @@ export const useStore = create<AppState>()(
}
return { tabs: newTabs, activeTabId: newActiveId };
}),
closeOtherTabs: (id) => set((state) => {
const keep = state.tabs.find(t => t.id === id);
if (!keep) return state;
return { tabs: [keep], activeTabId: id };
}),
closeTabsToLeft: (id) => set((state) => {
const index = state.tabs.findIndex(t => t.id === id);
if (index === -1) return state;
const newTabs = state.tabs.slice(index);
const activeStillExists = state.activeTabId ? newTabs.some(t => t.id === state.activeTabId) : false;
return { tabs: newTabs, activeTabId: activeStillExists ? state.activeTabId : id };
}),
closeTabsToRight: (id) => set((state) => {
const index = state.tabs.findIndex(t => t.id === id);
if (index === -1) return state;
const newTabs = state.tabs.slice(0, index + 1);
const activeStillExists = state.activeTabId ? newTabs.some(t => t.id === state.activeTabId) : false;
return { tabs: newTabs, activeTabId: activeStillExists ? state.activeTabId : id };
}),
closeAllTabs: () => set(() => ({ tabs: [], activeTabId: null })),
setActiveTab: (id) => set({ activeTabId: id }),
setActiveContext: (context) => set({ activeContext: context }),
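The tab-closing actions above are plain `slice` operations on the ordered tab list. A minimal sketch over an array of ids (assumed shapes, not the actual store):

```typescript
// closeTabsToLeft keeps the target tab and everything after it;
// closeTabsToRight keeps everything up to and including the target tab.
const closeTabsToLeft = (list: string[], id: string): string[] => {
  const index = list.indexOf(id);
  return index === -1 ? list : list.slice(index);
};

const closeTabsToRight = (list: string[], id: string): string[] => {
  const index = list.indexOf(id);
  return index === -1 ? list : list.slice(0, index + 1);
};

const tabs = ['t1', 't2', 't3', 't4'];
```

As in the store code, an unknown id leaves the list untouched, and the caller is responsible for repointing `activeTabId` when the active tab was closed.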
@@ -96,13 +127,14 @@ export const useStore = create<AppState>()(
toggleDarkMode: () => set((state) => ({ darkMode: !state.darkMode })),
setSqlFormatOptions: (options) => set({ sqlFormatOptions: options }),
setQueryOptions: (options) => set((state) => ({ queryOptions: { ...state.queryOptions, ...options } })),
addSqlLog: (log) => set((state) => ({ sqlLogs: [log, ...state.sqlLogs].slice(0, 1000) })), // Keep last 1000 logs
clearSqlLogs: () => set({ sqlLogs: [] }),
}),
{
name: 'lite-db-storage', // name of the item in the storage (must be unique)
partialize: (state) => ({ connections: state.connections, savedQueries: state.savedQueries, darkMode: state.darkMode, sqlFormatOptions: state.sqlFormatOptions }), // Don't persist logs
partialize: (state) => ({ connections: state.connections, savedQueries: state.savedQueries, darkMode: state.darkMode, sqlFormatOptions: state.sqlFormatOptions, queryOptions: state.queryOptions }), // Don't persist logs
}
)
);
);
frontend/src/utils/sql.ts (new file, 173 lines)
@@ -0,0 +1,173 @@
export type FilterCondition = {
id?: number;
column?: string;
op?: string;
value?: string;
value2?: string;
};
const normalizeIdentPart = (ident: string) => {
let raw = (ident || '').trim();
if (!raw) return raw;
const first = raw[0];
const last = raw[raw.length - 1];
if ((first === '"' && last === '"') || (first === '`' && last === '`')) {
raw = raw.slice(1, -1).trim();
}
raw = raw.replace(/["`]/g, '').trim();
return raw;
};
export const quoteIdentPart = (dbType: string, ident: string) => {
const raw = normalizeIdentPart(ident);
if (!raw) return raw;
if ((dbType || '').toLowerCase() === 'mysql') return `\`${raw.replace(/`/g, '``')}\``;
return `"${raw.replace(/"/g, '""')}"`;
};
export const quoteQualifiedIdent = (dbType: string, ident: string) => {
const raw = (ident || '').trim();
if (!raw) return raw;
const parts = raw.split('.').map(normalizeIdentPart).filter(Boolean);
if (parts.length <= 1) return quoteIdentPart(dbType, raw);
return parts.map(p => quoteIdentPart(dbType, p)).join('.');
};
export const escapeLiteral = (val: string) => (val || '').replace(/'/g, "''");
export const parseListValues = (val: string) => {
const raw = (val || '').trim();
if (!raw) return [];
return raw
.split(/[\n,]+/)
.map(s => s.trim())
.filter(Boolean);
};
export const buildWhereSQL = (dbType: string, conditions: FilterCondition[]) => {
const whereParts: string[] = [];
(conditions || []).forEach((cond) => {
const op = (cond?.op || '').trim();
const column = (cond?.column || '').trim();
const value = (cond?.value ?? '').toString();
const value2 = (cond?.value2 ?? '').toString();
if (op === 'CUSTOM') {
const expr = value.trim();
if (expr) whereParts.push(`(${expr})`);
return;
}
if (!column) return;
const col = quoteIdentPart(dbType, column);
switch (op) {
case 'IS_NULL':
whereParts.push(`${col} IS NULL`);
return;
case 'IS_NOT_NULL':
whereParts.push(`${col} IS NOT NULL`);
return;
case 'IS_EMPTY':
// Compatibility: "empty" is commonly understood as NULL or an empty string
whereParts.push(`(${col} IS NULL OR ${col} = '')`);
return;
case 'IS_NOT_EMPTY':
whereParts.push(`(${col} IS NOT NULL AND ${col} <> '')`);
return;
case 'BETWEEN': {
const v1 = value.trim();
const v2 = value2.trim();
if (!v1 || !v2) return;
whereParts.push(`${col} BETWEEN '${escapeLiteral(v1)}' AND '${escapeLiteral(v2)}'`);
return;
}
case 'NOT_BETWEEN': {
const v1 = value.trim();
const v2 = value2.trim();
if (!v1 || !v2) return;
whereParts.push(`${col} NOT BETWEEN '${escapeLiteral(v1)}' AND '${escapeLiteral(v2)}'`);
return;
}
case 'IN': {
const items = parseListValues(value);
if (items.length === 0) return;
const list = items.map(v => `'${escapeLiteral(v)}'`).join(', ');
whereParts.push(`${col} IN (${list})`);
return;
}
case 'NOT_IN': {
const items = parseListValues(value);
if (items.length === 0) return;
const list = items.map(v => `'${escapeLiteral(v)}'`).join(', ');
whereParts.push(`${col} NOT IN (${list})`);
return;
}
case 'CONTAINS': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} LIKE '%${escapeLiteral(v)}%'`);
return;
}
case 'NOT_CONTAINS': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} NOT LIKE '%${escapeLiteral(v)}%'`);
return;
}
case 'STARTS_WITH': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} LIKE '${escapeLiteral(v)}%'`);
return;
}
case 'NOT_STARTS_WITH': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} NOT LIKE '${escapeLiteral(v)}%'`);
return;
}
case 'ENDS_WITH': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} LIKE '%${escapeLiteral(v)}'`);
return;
}
case 'NOT_ENDS_WITH': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} NOT LIKE '%${escapeLiteral(v)}'`);
return;
}
case '=':
case '!=':
case '<':
case '<=':
case '>':
case '>=': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} ${op} '${escapeLiteral(v)}'`);
return;
}
default: {
// Backward compatibility with the legacy op value: LIKE
if (op.toUpperCase() === 'LIKE') {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} LIKE '%${escapeLiteral(v)}%'`);
return;
}
const v = value.trim();
if (!v) return;
whereParts.push(`${col} ${op} '${escapeLiteral(v)}'`);
}
}
});
return whereParts.length > 0 ? `WHERE ${whereParts.join(' AND ')}` : '';
};
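The helpers and `buildWhereSQL` above quote identifiers per dialect (backticks for MySQL, double quotes otherwise) and inline single-quote-escaped literals. A condensed, self-contained sketch covering just the `=`, `IN`, and `IS_NULL` branches — behavior mirrored from the code above, not the full operator set:

```typescript
const escapeLiteral = (val: string) => (val || '').replace(/'/g, "''");

// Strip any embedded quoting characters, then quote for the target dialect.
const quoteIdent = (dbType: string, ident: string) => {
  const raw = ident.trim().replace(/["`]/g, '');
  return dbType.toLowerCase() === 'mysql' ? `\`${raw}\`` : `"${raw}"`;
};

// Values for IN may be separated by commas or newlines.
const parseList = (val: string) =>
  val.split(/[\n,]+/).map(s => s.trim()).filter(Boolean);

type Cond = { column: string; op: string; value?: string };

const buildWhere = (dbType: string, conds: Cond[]): string => {
  const parts = conds.flatMap(({ column, op, value = '' }) => {
    const col = quoteIdent(dbType, column);
    if (op === 'IS_NULL') return [`${col} IS NULL`];
    if (op === 'IN') {
      const items = parseList(value);
      if (items.length === 0) return []; // skip empty lists rather than emit "IN ()"
      return [`${col} IN (${items.map(v => `'${escapeLiteral(v)}'`).join(', ')})`];
    }
    if (op === '=' && value.trim()) return [`${col} = '${escapeLiteral(value.trim())}'`];
    return [];
  });
  return parts.length ? `WHERE ${parts.join(' AND ')}` : '';
};
```

Note the same trade-off as the original: literals are escaped and inlined rather than bound as parameters, which is acceptable for a desktop client building ad-hoc filters but would not be the first choice for a server-side API.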
@@ -35,8 +35,14 @@ export function DataSyncPreview(arg1:sync.SyncConfig,arg2:string,arg3:number):Pr
export function ExportData(arg1:Array<Record<string, any>>,arg2:Array<string>,arg3:string,arg4:string):Promise<connection.QueryResult>;
export function ExportDatabaseSQL(arg1:connection.ConnectionConfig,arg2:string,arg3:boolean):Promise<connection.QueryResult>;
export function ExportQuery(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string,arg5:string):Promise<connection.QueryResult>;
export function ExportTable(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
export function ExportTablesSQL(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>,arg4:boolean):Promise<connection.QueryResult>;
export function ImportConfigFile():Promise<connection.QueryResult>;
export function ImportData(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise<connection.QueryResult>;
@@ -66,10 +66,22 @@ export function ExportData(arg1, arg2, arg3, arg4) {
return window['go']['app']['App']['ExportData'](arg1, arg2, arg3, arg4);
}
export function ExportDatabaseSQL(arg1, arg2, arg3) {
return window['go']['app']['App']['ExportDatabaseSQL'](arg1, arg2, arg3);
}
export function ExportQuery(arg1, arg2, arg3, arg4, arg5) {
return window['go']['app']['App']['ExportQuery'](arg1, arg2, arg3, arg4, arg5);
}
export function ExportTable(arg1, arg2, arg3, arg4) {
return window['go']['app']['App']['ExportTable'](arg1, arg2, arg3, arg4);
}
export function ExportTablesSQL(arg1, arg2, arg3, arg4) {
return window['go']['app']['App']['ExportTablesSQL'](arg1, arg2, arg3, arg4);
}
export function ImportConfigFile() {
return window['go']['app']['App']['ImportConfigFile']();
}
@@ -141,9 +141,6 @@ func (a *App) getDatabase(config connection.ConnectionConfig) (db.Database, erro
if len(shortKey) > 12 {
shortKey = shortKey[:12]
}
if config.UseSSH && config.Type != "mysql" {
logger.Warnf("当前仅 MySQL 支持内置 SSH 直连,其他类型请使用本地端口转发:%s", formatConnSummary(config))
}
logger.Infof("获取数据库连接:%s 缓存Key=%s", formatConnSummary(config), shortKey)
a.mu.Lock()
@@ -1,11 +1,14 @@
package app
import (
"context"
"fmt"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/utils"
)
// Generic DB Methods
@@ -91,16 +94,39 @@ func (a *App) DBQuery(config connection.ConnectionConfig, dbName string, query s
return connection.QueryResult{Success: false, Message: err.Error()}
}
query = sanitizeSQLForPgLike(runConfig.Type, query)
timeoutSeconds := runConfig.Timeout
if timeoutSeconds <= 0 {
timeoutSeconds = 30
}
ctx, cancel := utils.ContextWithTimeout(time.Duration(timeoutSeconds) * time.Second)
defer cancel()
lowerQuery := strings.TrimSpace(strings.ToLower(query))
if strings.HasPrefix(lowerQuery, "select") || strings.HasPrefix(lowerQuery, "show") || strings.HasPrefix(lowerQuery, "describe") || strings.HasPrefix(lowerQuery, "explain") {
var data []map[string]interface{}
var columns []string
if q, ok := dbInst.(interface {
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
}); ok {
data, columns, err = q.QueryContext(ctx, query)
} else {
data, columns, err = dbInst.Query(query)
}
if err != nil {
logger.Error(err, "DBQuery 查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Data: data, Fields: columns}
} else {
var affected int64
if e, ok := dbInst.(interface {
ExecContext(context.Context, string) (int64, error)
}); ok {
affected, err = e.ExecContext(ctx, query)
} else {
affected, err = dbInst.Exec(query)
}
if err != nil {
logger.Error(err, "DBQuery 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: err.Error()}

View File

@@ -1,11 +1,16 @@
package app
import (
"bufio"
"encoding/csv"
"encoding/json"
"fmt"
"math"
"os"
"sort"
"strconv"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/db"
@@ -213,12 +218,36 @@ func (a *App) ExportTable(config connection.ConnectionConfig, dbName string, tab
}
runConfig := normalizeRunConfig(config, dbName)
dbInst, err := a.getDatabase(runConfig)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
format = strings.ToLower(format)
if format == "sql" {
f, err := os.Create(filename)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
defer f.Close()
w := bufio.NewWriterSize(f, 1024*1024)
defer w.Flush()
if err := writeSQLHeader(w, runConfig, dbName); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
if err := dumpTableSQL(w, dbInst, runConfig, dbName, tableName, true); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
if err := writeSQLFooter(w, runConfig); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Message: "Export successful"}
}
query := fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(runConfig.Type, tableName))
data, columns, err := dbInst.Query(query)
@@ -231,71 +260,129 @@ data, columns, err := dbInst.Query(query)
return connection.QueryResult{Success: false, Message: err.Error()}
}
defer f.Close()
if err := writeRowsToFile(f, data, columns, format); err != nil {
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
}
return connection.QueryResult{Success: true, Message: "Export successful"}
}
func (a *App) ExportTablesSQL(config connection.ConnectionConfig, dbName string, tableNames []string, includeData bool) connection.QueryResult {
safeDbName := strings.TrimSpace(dbName)
if safeDbName == "" {
safeDbName = "export"
}
suffix := "schema"
if includeData {
suffix = "backup"
}
defaultFilename := fmt.Sprintf("%s_%s_%dtables.sql", safeDbName, suffix, len(tableNames))
if len(tableNames) == 1 && strings.TrimSpace(tableNames[0]) != "" {
defaultFilename = fmt.Sprintf("%s_%s.sql", strings.TrimSpace(tableNames[0]), suffix)
}
filename, err := runtime.SaveFileDialog(a.ctx, runtime.SaveDialogOptions{
Title: "Export Tables (SQL)",
DefaultFilename: defaultFilename,
})
if err != nil || filename == "" {
return connection.QueryResult{Success: false, Message: "Cancelled"}
}
runConfig := normalizeRunConfig(config, dbName)
dbInst, err := a.getDatabase(runConfig)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
tables := make([]string, 0, len(tableNames))
seen := make(map[string]struct{}, len(tableNames))
for _, t := range tableNames {
t = strings.TrimSpace(t)
if t == "" {
continue
}
if _, ok := seen[t]; ok {
continue
}
seen[t] = struct{}{}
tables = append(tables, t)
}
sort.Strings(tables)
f, err := os.Create(filename)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
defer f.Close()
w := bufio.NewWriterSize(f, 1024*1024)
defer w.Flush()
if err := writeSQLHeader(w, runConfig, dbName); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
for _, t := range tables {
if err := dumpTableSQL(w, dbInst, runConfig, dbName, t, includeData); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
}
if err := writeSQLFooter(w, runConfig); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Message: "Export successful"}
}
func (a *App) ExportDatabaseSQL(config connection.ConnectionConfig, dbName string, includeData bool) connection.QueryResult {
safeDbName := strings.TrimSpace(dbName)
if safeDbName == "" {
return connection.QueryResult{Success: false, Message: "dbName required"}
}
suffix := "schema"
if includeData {
suffix = "backup"
}
filename, err := runtime.SaveFileDialog(a.ctx, runtime.SaveDialogOptions{
Title: fmt.Sprintf("Export %s (SQL)", safeDbName),
DefaultFilename: fmt.Sprintf("%s_%s.sql", safeDbName, suffix),
})
if err != nil || filename == "" {
return connection.QueryResult{Success: false, Message: "Cancelled"}
}
runConfig := normalizeRunConfig(config, dbName)
dbInst, err := a.getDatabase(runConfig)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
tables, err := dbInst.GetTables(dbName)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
sort.Strings(tables)
f, err := os.Create(filename)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
defer f.Close()
w := bufio.NewWriterSize(f, 1024*1024)
defer w.Flush()
if err := writeSQLHeader(w, runConfig, dbName); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
for _, t := range tables {
if err := dumpTableSQL(w, dbInst, runConfig, dbName, t, includeData); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
}
if err := writeSQLFooter(w, runConfig); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Message: "Export successful"}
@@ -340,6 +427,173 @@ func quoteQualifiedIdentByType(dbType string, ident string) string {
return strings.Join(quotedParts, ".")
}
func writeSQLHeader(w *bufio.Writer, config connection.ConnectionConfig, dbName string) error {
now := time.Now().Format("2006-01-02 15:04:05")
if _, err := w.WriteString(fmt.Sprintf("-- GoNavi SQL Export\n-- Time: %s\n", now)); err != nil {
return err
}
if strings.TrimSpace(dbName) != "" {
if _, err := w.WriteString(fmt.Sprintf("-- Database: %s\n\n", dbName)); err != nil {
return err
}
}
if strings.ToLower(strings.TrimSpace(config.Type)) == "mysql" && strings.TrimSpace(dbName) != "" {
if _, err := w.WriteString(fmt.Sprintf("USE %s;\n\n", quoteIdentByType("mysql", dbName))); err != nil {
return err
}
if _, err := w.WriteString("SET FOREIGN_KEY_CHECKS=0;\n\n"); err != nil {
return err
}
}
return nil
}
func writeSQLFooter(w *bufio.Writer, config connection.ConnectionConfig) error {
if strings.ToLower(strings.TrimSpace(config.Type)) == "mysql" {
if _, err := w.WriteString("\nSET FOREIGN_KEY_CHECKS=1;\n"); err != nil {
return err
}
}
return nil
}
func qualifyTable(schemaName, tableName string) string {
schemaName = strings.TrimSpace(schemaName)
tableName = strings.TrimSpace(tableName)
if schemaName == "" {
return tableName
}
return schemaName + "." + tableName
}
func ensureSQLTerminator(sql string) string {
trimmed := strings.TrimSpace(sql)
if trimmed == "" {
return sql
}
if strings.HasSuffix(trimmed, ";") {
return sql
}
return sql + ";"
}
func isMySQLHexLiteral(s string) bool {
if len(s) < 3 || !(strings.HasPrefix(s, "0x") || strings.HasPrefix(s, "0X")) {
return false
}
for i := 2; i < len(s); i++ {
c := s[i]
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')) {
return false
}
}
return true
}
func formatSQLValue(dbType string, v interface{}) string {
if v == nil {
return "NULL"
}
switch val := v.(type) {
case bool:
if val {
return "1"
}
return "0"
case int:
return strconv.Itoa(val)
case int8, int16, int32, int64:
return fmt.Sprintf("%d", val)
case uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("%d", val)
case float32:
f := float64(val)
if math.IsNaN(f) || math.IsInf(f, 0) {
return "NULL"
}
return strconv.FormatFloat(f, 'f', -1, 32)
case float64:
if math.IsNaN(val) || math.IsInf(val, 0) {
return "NULL"
}
return strconv.FormatFloat(val, 'f', -1, 64)
case time.Time:
return "'" + val.Format("2006-01-02 15:04:05") + "'"
case string:
if strings.ToLower(strings.TrimSpace(dbType)) == "mysql" && isMySQLHexLiteral(val) {
return val
}
escaped := strings.ReplaceAll(val, "'", "''")
return "'" + escaped + "'"
default:
escaped := strings.ReplaceAll(fmt.Sprintf("%v", v), "'", "''")
return "'" + escaped + "'"
}
}
func dumpTableSQL(w *bufio.Writer, dbInst db.Database, config connection.ConnectionConfig, dbName, tableName string, includeData bool) error {
schemaName, pureTableName := normalizeSchemaAndTable(config, dbName, tableName)
if _, err := w.WriteString("\n-- ----------------------------\n"); err != nil {
return err
}
if _, err := w.WriteString(fmt.Sprintf("-- Table: %s\n", qualifyTable(schemaName, pureTableName))); err != nil {
return err
}
if _, err := w.WriteString("-- ----------------------------\n\n"); err != nil {
return err
}
createSQL, err := dbInst.GetCreateStatement(schemaName, pureTableName)
if err != nil {
return err
}
if _, err := w.WriteString(ensureSQLTerminator(createSQL)); err != nil {
return err
}
if _, err := w.WriteString("\n\n"); err != nil {
return err
}
if !includeData {
return nil
}
qualified := qualifyTable(schemaName, pureTableName)
selectSQL := fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(config.Type, qualified))
data, columns, err := dbInst.Query(selectSQL)
if err != nil {
return err
}
if len(data) == 0 {
if _, err := w.WriteString("-- (0 rows)\n"); err != nil {
return err
}
return nil
}
quotedCols := make([]string, 0, len(columns))
for _, c := range columns {
quotedCols = append(quotedCols, quoteIdentByType(config.Type, c))
}
quotedTable := quoteQualifiedIdentByType(config.Type, qualified)
for _, row := range data {
values := make([]string, 0, len(columns))
for _, c := range columns {
values = append(values, formatSQLValue(config.Type, row[c]))
}
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s);\n", quotedTable, strings.Join(quotedCols, ", "), strings.Join(values, ", "))); err != nil {
return err
}
}
return nil
}
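dumpTableSQL renders each row as a standalone INSERT statement, with formatSQLValue handling NULLs, booleans, and single-quote doubling. A simplified, self-contained sketch of that row-to-INSERT step (cut-down value formatter; identifier quoting and the type-specific cases from formatSQLValue are omitted):

```go
package main

import (
	"fmt"
	"strings"
)

// sqlValue is a cut-down version of the exporter's formatSQLValue:
// NULL for nil, 1/0 for bool, single quotes doubled inside strings.
func sqlValue(v interface{}) string {
	switch val := v.(type) {
	case nil:
		return "NULL"
	case bool:
		if val {
			return "1"
		}
		return "0"
	case int:
		return fmt.Sprintf("%d", val)
	case string:
		return "'" + strings.ReplaceAll(val, "'", "''") + "'"
	default:
		return "'" + strings.ReplaceAll(fmt.Sprintf("%v", val), "'", "''") + "'"
	}
}

// insertStmt mirrors the INSERT shape emitted inside dumpTableSQL's row loop.
func insertStmt(table string, cols []string, row map[string]interface{}) string {
	vals := make([]string, 0, len(cols))
	for _, c := range cols {
		vals = append(vals, sqlValue(row[c]))
	}
	return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s);",
		table, strings.Join(cols, ", "), strings.Join(vals, ", "))
}

func main() {
	row := map[string]interface{}{"id": 1, "name": "O'Brien", "note": nil}
	fmt.Println(insertStmt("t_user", []string{"id", "name", "note"}, row))
	// INSERT INTO t_user (id, name, note) VALUES (1, 'O''Brien', NULL);
}
```

Note the column slice, not the map, drives output order, which keeps the emitted column list and value list aligned.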
// ExportData exports provided data to a file
func (a *App) ExportData(data []map[string]interface{}, columns []string, defaultName string, format string) connection.QueryResult {
if defaultName == "" {
@@ -359,33 +613,101 @@ func (a *App) ExportData(data []map[string]interface{}, columns []string, defaul
return connection.QueryResult{Success: false, Message: err.Error()}
}
defer f.Close()
if err := writeRowsToFile(f, data, columns, format); err != nil {
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
}
return connection.QueryResult{Success: true, Message: "Export successful"}
}
// ExportQuery exports by executing the provided SELECT query on the backend side.
// This avoids frontend IPC payload limits when exporting very large/long-text columns (e.g. base64).
func (a *App) ExportQuery(config connection.ConnectionConfig, dbName string, query string, defaultName string, format string) connection.QueryResult {
query = strings.TrimSpace(query)
if query == "" {
return connection.QueryResult{Success: false, Message: "query required"}
}
if defaultName == "" {
defaultName = "export"
}
filename, err := runtime.SaveFileDialog(a.ctx, runtime.SaveDialogOptions{
Title: "Export Query Result",
DefaultFilename: fmt.Sprintf("%s.%s", defaultName, strings.ToLower(format)),
})
if err != nil || filename == "" {
return connection.QueryResult{Success: false, Message: "Cancelled"}
}
runConfig := normalizeRunConfig(config, dbName)
dbInst, err := a.getDatabase(runConfig)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
query = sanitizeSQLForPgLike(runConfig.Type, query)
lowerQuery := strings.ToLower(strings.TrimSpace(query))
if !(strings.HasPrefix(lowerQuery, "select") || strings.HasPrefix(lowerQuery, "with")) {
return connection.QueryResult{Success: false, Message: "Only SELECT/WITH queries are supported"}
}
data, columns, err := dbInst.Query(query)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
f, err := os.Create(filename)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
defer f.Close()
if err := writeRowsToFile(f, data, columns, format); err != nil {
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
}
return connection.QueryResult{Success: true, Message: "Export successful"}
}
func writeRowsToFile(f *os.File, data []map[string]interface{}, columns []string, format string) error {
format = strings.ToLower(strings.TrimSpace(format))
if f == nil {
return fmt.Errorf("file required")
}
var csvWriter *csv.Writer
var jsonEncoder *json.Encoder
isJsonFirstRow := true
switch format {
case "csv", "xlsx":
if _, err := f.Write([]byte{0xEF, 0xBB, 0xBF}); err != nil {
return err
}
csvWriter = csv.NewWriter(f)
defer csvWriter.Flush()
if err := csvWriter.Write(columns); err != nil {
return err
}
case "json":
if _, err := f.WriteString("[\n"); err != nil {
return err
}
jsonEncoder = json.NewEncoder(f)
jsonEncoder.SetIndent(" ", " ")
case "md":
if _, err := fmt.Fprintf(f, "| %s |\n", strings.Join(columns, " | ")); err != nil {
return err
}
seps := make([]string, len(columns))
for i := range seps {
seps[i] = "---"
}
if _, err := fmt.Fprintf(f, "| %s |\n", strings.Join(seps, " | ")); err != nil {
return err
}
default:
return fmt.Errorf("unsupported format: %s", format)
}
for _, rowMap := range data {
@@ -394,37 +716,51 @@ func (a *App) ExportData(data []map[string]interface{}, columns []string, defaul
record := make([]string, len(columns))
for i, col := range columns {
val := rowMap[col]
if val == nil {
record[i] = "NULL"
continue
}
s := fmt.Sprintf("%v", val)
if format == "md" {
s = strings.ReplaceAll(s, "|", "\\|")
s = strings.ReplaceAll(s, "\n", "<br>")
}
record[i] = s
}
switch format {
case "csv", "xlsx":
if err := csvWriter.Write(record); err != nil {
return err
}
case "json":
if !isJsonFirstRow {
if _, err := f.WriteString(",\n"); err != nil {
return err
}
}
if err := jsonEncoder.Encode(rowMap); err != nil {
return err
}
isJsonFirstRow = false
case "md":
if _, err := fmt.Fprintf(f, "| %s |\n", strings.Join(record, " | ")); err != nil {
return err
}
}
}
if format == "csv" || format == "xlsx" {
csvWriter.Flush()
if err := csvWriter.Error(); err != nil {
return err
}
}
if format == "json" {
if _, err := f.WriteString("\n]"); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,236 @@
package app
import (
"strings"
"unicode"
)
func sanitizeSQLForPgLike(dbType string, query string) string {
switch strings.ToLower(strings.TrimSpace(dbType)) {
case "postgres", "kingbase":
// In some cases the quoting is nested several levels deep (e.g. """"schema"""" or ""schema"""),
// so a single fix pass may not converge. Iterate a bounded number of times until the output stabilizes.
out := query
for i := 0; i < 3; i++ {
fixed := fixBrokenDoubleDoubleQuotedIdent(out)
if fixed == out {
break
}
out = fixed
}
return out
default:
return query
}
}
// fixBrokenDoubleDoubleQuotedIdent fixes accidental identifiers like:
//
// SELECT * FROM ""schema"".""table""
//
// which can be produced when a quoted identifier gets wrapped by quotes again.
//
// It is intentionally conservative:
// - only runs outside strings/comments/dollar-quoted blocks
// - does not touch valid escaped-quote sequences inside quoted identifiers (e.g. "a""b")
func fixBrokenDoubleDoubleQuotedIdent(query string) string {
if !strings.Contains(query, `""`) {
return query
}
var b strings.Builder
b.Grow(len(query))
inSingle := false
inDoubleIdent := false
inLineComment := false
inBlockComment := false
dollarTag := ""
for i := 0; i < len(query); i++ {
ch := query[i]
next := byte(0)
if i+1 < len(query) {
next = query[i+1]
}
if inLineComment {
b.WriteByte(ch)
if ch == '\n' {
inLineComment = false
}
continue
}
if inBlockComment {
b.WriteByte(ch)
if ch == '*' && next == '/' {
b.WriteByte('/')
i++
inBlockComment = false
}
continue
}
if dollarTag != "" {
if strings.HasPrefix(query[i:], dollarTag) {
b.WriteString(dollarTag)
i += len(dollarTag) - 1
dollarTag = ""
continue
}
b.WriteByte(ch)
continue
}
if inSingle {
b.WriteByte(ch)
if ch == '\'' {
// escaped single quote
if next == '\'' {
b.WriteByte('\'')
i++
continue
}
inSingle = false
}
continue
}
if inDoubleIdent {
b.WriteByte(ch)
if ch == '"' {
// escaped quote inside identifier
if next == '"' {
b.WriteByte('"')
i++
continue
}
inDoubleIdent = false
}
continue
}
// --- Outside of all string/comment blocks ---
if ch == '-' && next == '-' {
b.WriteByte(ch)
b.WriteByte('-')
i++
inLineComment = true
continue
}
if ch == '/' && next == '*' {
b.WriteByte(ch)
b.WriteByte('*')
i++
inBlockComment = true
continue
}
if ch == '\'' {
b.WriteByte(ch)
inSingle = true
continue
}
if ch == '$' {
if tag := parseDollarTag(query[i:]); tag != "" {
b.WriteString(tag)
i += len(tag) - 1
dollarTag = tag
continue
}
}
if ch == '"' {
// Fix: ""ident"" -> "ident" (only when it looks like a plain identifier)
// Also handle variants like ""ident""" / """"ident"""" (extra quotes at either side).
if next == '"' {
if replacement, advance, ok := tryFixDoubleDoubleQuotedIdent(query, i); ok {
b.WriteString(replacement)
i = advance - 1
continue
}
}
b.WriteByte(ch)
inDoubleIdent = true
continue
}
b.WriteByte(ch)
}
return b.String()
}
func tryFixDoubleDoubleQuotedIdent(query string, start int) (replacement string, advance int, ok bool) {
// start points at the first quote of a broken identifier, usually like:
// ""ident"" / ""ident""" / """"ident""""
if start < 0 || start+1 >= len(query) {
return "", 0, false
}
if query[start] != '"' || query[start+1] != '"' {
return "", 0, false
}
if start > 0 && query[start-1] == '"' {
return "", 0, false
}
runLen := 0
for start+runLen < len(query) && query[start+runLen] == '"' {
runLen++
}
if runLen < 2 || runLen%2 == 1 {
// Odd run (e.g. """...) can be a valid quoted identifier with escaped quotes.
return "", 0, false
}
contentStart := start + runLen
j := contentStart
for j < len(query) {
if query[j] == '"' {
endRunLen := 0
for j+endRunLen < len(query) && query[j+endRunLen] == '"' {
endRunLen++
}
if endRunLen >= 2 {
content := strings.TrimSpace(query[contentStart:j])
if looksLikeIdentifierContent(content) {
return `"` + content + `"`, j + endRunLen, true
}
return "", 0, false
}
}
// Fast abort: identifier-like content should not span lines.
if query[j] == '\n' || query[j] == '\r' {
break
}
j++
}
return "", 0, false
}
func looksLikeIdentifierContent(s string) bool {
if strings.TrimSpace(s) == "" {
return false
}
for _, r := range s {
if r == '_' || r == '$' || r == '-' || unicode.IsLetter(r) || unicode.IsDigit(r) {
continue
}
return false
}
return true
}
func parseDollarTag(s string) string {
// Match: $tag$ where tag is [A-Za-z0-9_]* (can be empty => $$)
if len(s) < 2 || s[0] != '$' {
return ""
}
for i := 1; i < len(s); i++ {
c := s[i]
if c == '$' {
return s[:i+1]
}
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') {
return ""
}
}
return ""
}
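Since parseDollarTag is a pure function, its matching rules are easy to probe in isolation. The block below is a standalone copy of the same algorithm with a few probes showing which prefixes count as dollar-quote openers:

```go
package main

import "fmt"

// parseDollarTag matches a PostgreSQL dollar-quote opener at the start of s:
// $tag$ where tag is [A-Za-z0-9_]* (an empty tag gives $$). It returns the
// matched opener, or "" when s does not start with one.
func parseDollarTag(s string) string {
	if len(s) < 2 || s[0] != '$' {
		return ""
	}
	for i := 1; i < len(s); i++ {
		c := s[i]
		if c == '$' {
			return s[:i+1]
		}
		if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') {
			return ""
		}
	}
	return ""
}

func main() {
	fmt.Printf("%q\n", parseDollarTag("$$body$$"))     // "$$"
	fmt.Printf("%q\n", parseDollarTag("$fn$body$fn$")) // "$fn$"
	fmt.Printf("%q\n", parseDollarTag("$1.5"))         // "" ('.' breaks the tag)
}
```

The last probe matters for the sanitizer: a bare `$` followed by non-tag characters (e.g. a positional parameter or a price) must not open a dollar-quoted block, or the rest of the statement would be skipped.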

View File

@@ -0,0 +1,55 @@
package app
import "testing"
func TestSanitizeSQLForPgLike_FixesBrokenDoubleDoubleQuotes(t *testing.T) {
in := `SELECT * FROM ""ldf_server"".""t_user"" LIMIT 1`
out := sanitizeSQLForPgLike("kingbase", in)
want := `SELECT * FROM "ldf_server"."t_user" LIMIT 1`
if out != want {
t.Fatalf("unexpected sanitize output:\nIN: %s\nOUT: %s\nWANT: %s", in, out, want)
}
}
func TestSanitizeSQLForPgLike_FixesBrokenDoubleDoubleQuotes_WithExtraQuotes(t *testing.T) {
in := `SELECT * FROM ""ldf_server""".""t_user"" LIMIT 1`
out := sanitizeSQLForPgLike("kingbase", in)
want := `SELECT * FROM "ldf_server"."t_user" LIMIT 1`
if out != want {
t.Fatalf("unexpected sanitize output:\nIN: %s\nOUT: %s\nWANT: %s", in, out, want)
}
}
func TestSanitizeSQLForPgLike_FixesBrokenDoubleDoubleQuotes_WithQuadQuotes(t *testing.T) {
in := `SELECT * FROM """"ldf_server"""".""t_user"" LIMIT 1`
out := sanitizeSQLForPgLike("kingbase", in)
want := `SELECT * FROM "ldf_server"."t_user" LIMIT 1`
if out != want {
t.Fatalf("unexpected sanitize output:\nIN: %s\nOUT: %s\nWANT: %s", in, out, want)
}
}
func TestSanitizeSQLForPgLike_DoesNotTouchEscapedQuotesInsideIdentifier(t *testing.T) {
in := `SELECT "a""b" FROM "t""x"`
out := sanitizeSQLForPgLike("postgres", in)
if out != in {
t.Fatalf("should keep valid escaped quotes inside identifier:\nIN: %s\nOUT: %s", in, out)
}
}
func TestSanitizeSQLForPgLike_DoesNotTouchDollarQuotedStrings(t *testing.T) {
in := "SELECT $$\"\"ldf_server\"\"$$, \"\"ldf_server\"\""
out := sanitizeSQLForPgLike("postgres", in)
want := "SELECT $$\"\"ldf_server\"\"$$, \"ldf_server\""
if out != want {
t.Fatalf("unexpected sanitize output for dollar quoted string:\nIN: %s\nOUT: %s\nWANT: %s", in, out, want)
}
}
func TestSanitizeSQLForPgLike_DoesNotModifyOtherDBTypes(t *testing.T) {
in := `SELECT * FROM ""ldf_server""`
out := sanitizeSQLForPgLike("mysql", in)
if out != in {
t.Fatalf("non-PG-like db should not be sanitized:\nIN: %s\nOUT: %s", in, out)
}
}

View File

@@ -1,6 +1,7 @@
package db
import (
"context"
"database/sql"
"fmt"
"strings"
@@ -57,6 +58,20 @@ func (c *CustomDB) Ping() error {
return c.conn.PingContext(ctx)
}
func (c *CustomDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if c.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
}
rows, err := c.conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (c *CustomDB) Query(query string) ([]map[string]interface{}, []string, error) {
if c.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
@@ -67,33 +82,18 @@ func (c *CustomDB) Query(query string) ([]map[string]interface{}, []string, erro
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (c *CustomDB) ExecContext(ctx context.Context, query string) (int64, error) {
if c.conn == nil {
return 0, fmt.Errorf("connection not open")
}
res, err := c.conn.ExecContext(ctx, query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (c *CustomDB) Exec(query string) (int64, error) {

View File

@@ -1,6 +1,7 @@
package db
import (
"context"
"database/sql"
"fmt"
"net"
@@ -10,6 +11,7 @@ import (
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
@@ -19,6 +21,7 @@ import (
type DamengDB struct {
conn *sql.DB
pingTimeout time.Duration
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
}
func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
@@ -26,16 +29,6 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
// or dm://user:password@host:port
address := net.JoinHostPort(config.Host, strconv.Itoa(config.Port))
if config.UseSSH {
// SSH logic similar to others, assumes port forwarding
_, err := ssh.RegisterSSHNetwork(config.SSH)
if err == nil {
// DM driver likely uses standard net.Dial, so we might need a local listener
// or assume port forwarding is handled externally or implicitly via "tcp" override if driver allows.
// Similar to Oracle, we skip complex custom dialer injection for now.
}
}
escapedPassword := url.PathEscape(config.Password)
q := url.Values{}
if config.Database != "" {
@@ -55,7 +48,42 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
}
func (d *DamengDB) Connect(config connection.ConnectionConfig) error {
var dsn string
var err error
if config.UseSSH {
// Create SSH tunnel with local port forwarding
logger.Infof("达梦数据库使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
if err != nil {
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
}
d.forwarder = forwarder
// Parse local address
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
if err != nil {
return fmt.Errorf("解析本地转发地址失败:%w", err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("解析本地端口失败:%w", err)
}
// Create a modified config pointing to local forwarder
localConfig := config
localConfig.Host = host
localConfig.Port = port
localConfig.UseSSH = false
dsn = d.getDSN(localConfig)
logger.Infof("达梦数据库通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = d.getDSN(config)
}
db, err := sql.Open("dm", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
@@ -69,6 +97,15 @@ func (d *DamengDB) Connect(config connection.ConnectionConfig) error {
}
func (d *DamengDB) Close() error {
// Close SSH forwarder first if exists
if d.forwarder != nil {
if err := d.forwarder.Close(); err != nil {
logger.Warnf("关闭达梦数据库 SSH 端口转发失败:%v", err)
}
d.forwarder = nil
}
// Then close database connection
if d.conn != nil {
return d.conn.Close()
}
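Connect derives the DSN's host and port from the forwarder's LocalAddr via net.SplitHostPort plus strconv.Atoi. A standalone sketch of just that parsing step (the helper name splitLocalAddr is illustrative, not from the codebase):

```go
package main

import (
	"fmt"
	"net"
	"strconv"
)

// splitLocalAddr turns a local forwarder address such as "127.0.0.1:53211"
// back into a host string and a numeric port, the way Connect rewrites
// localConfig.Host and localConfig.Port before building the DSN.
func splitLocalAddr(addr string) (string, int, error) {
	host, portStr, err := net.SplitHostPort(addr)
	if err != nil {
		return "", 0, fmt.Errorf("parse local forward address: %w", err)
	}
	port, err := strconv.Atoi(portStr)
	if err != nil {
		return "", 0, fmt.Errorf("parse local port: %w", err)
	}
	return host, port, nil
}

func main() {
	host, port, err := splitLocalAddr("127.0.0.1:53211")
	fmt.Println(host, port, err) // 127.0.0.1 53211 <nil>
}
```

Using net.SplitHostPort rather than a manual strings.Split keeps the parsing correct for IPv6 literals like "[::1]:53211" as well.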
@@ -88,6 +125,20 @@ func (d *DamengDB) Ping() error {
return d.conn.PingContext(ctx)
}
func (d *DamengDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if d.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
}
rows, err := d.conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (d *DamengDB) Query(query string) ([]map[string]interface{}, []string, error) {
if d.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
@@ -98,33 +149,18 @@ func (d *DamengDB) Query(query string) ([]map[string]interface{}, []string, erro
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (d *DamengDB) ExecContext(ctx context.Context, query string) (int64, error) {
if d.conn == nil {
return 0, fmt.Errorf("connection not open")
}
res, err := d.conn.ExecContext(ctx, query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (d *DamengDB) Exec(query string) (int64, error) {

View File

@@ -1,12 +1,16 @@
package db
import (
"context"
"database/sql"
"fmt"
"net"
"strconv"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
@@ -16,6 +20,7 @@ import (
type KingbaseDB struct {
conn *sql.DB
pingTimeout time.Duration
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
}
func quoteConnValue(v string) string {
@@ -57,20 +62,6 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
address := config.Host
port := config.Port
if config.UseSSH {
netName, err := ssh.RegisterSSHNetwork(config.SSH)
if err == nil {
// Kingbase/Postgres lib/pq allows custom dialer via "host" if using unix socket,
// but for custom network it's harder.
// Ideally we use a local forwarder.
// For now, we assume standard TCP or handle SSH externally.
// If we implement the net.Dial override for "kingbase" driver (which might use lib/pq internally),
// we might need to check if it supports "cloudsql" style or similar custom dialers.
// Similar to others, skipping SSH deep integration here for now.
_ = netName
}
}
// Construct DSN
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable connect_timeout=%d",
quoteConnValue(address),
@@ -85,7 +76,42 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
}
func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
var dsn string
var err error
if config.UseSSH {
// Create SSH tunnel with local port forwarding
logger.Infof("人大金仓使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
if err != nil {
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
}
k.forwarder = forwarder
// Parse local address
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
if err != nil {
return fmt.Errorf("解析本地转发地址失败:%w", err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("解析本地端口失败:%w", err)
}
// Create a modified config pointing to local forwarder
localConfig := config
localConfig.Host = host
localConfig.Port = port
localConfig.UseSSH = false
dsn = k.getDSN(localConfig)
logger.Infof("人大金仓通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = k.getDSN(config)
}
// Open using "kingbase" driver
db, err := sql.Open("kingbase", dsn)
if err != nil {
@@ -100,6 +126,15 @@ func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
}
func (k *KingbaseDB) Close() error {
// Close the SSH forwarder first, if present
if k.forwarder != nil {
if err := k.forwarder.Close(); err != nil {
logger.Warnf("关闭人大金仓 SSH 端口转发失败:%v", err)
}
k.forwarder = nil
}
// Then close database connection
if k.conn != nil {
return k.conn.Close()
}
@@ -119,6 +154,20 @@ func (k *KingbaseDB) Ping() error {
return k.conn.PingContext(ctx)
}
func (k *KingbaseDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if k.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
}
rows, err := k.conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, error) {
if k.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
@@ -129,33 +178,18 @@ func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, er
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
columns, err := rows.Columns()
func (k *KingbaseDB) ExecContext(ctx context.Context, query string) (int64, error) {
if k.conn == nil {
return 0, fmt.Errorf("connection not open")
}
res, err := k.conn.ExecContext(ctx, query)
if err != nil {
return nil, nil, err
return 0, err
}
var resultData []map[string]interface{}
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
entry := make(map[string]interface{})
for i, col := range columns {
entry[col] = normalizeQueryValue(values[i])
}
resultData = append(resultData, entry)
}
return resultData, columns, nil
return res.RowsAffected()
}
func (k *KingbaseDB) Exec(query string) (int64, error) {
@@ -223,15 +257,84 @@ func (k *KingbaseDB) GetCreateStatement(dbName, tableName string) (string, error
}
func (k *KingbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
schema := "public"
if dbName != "" {
schema = dbName
// Parse the schema.table format
schema := strings.TrimSpace(dbName)
table := strings.TrimSpace(tableName)
// If tableName embeds a schema (format: schema.table)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
parsedSchema := strings.TrimSpace(parts[0])
parsedTable := strings.TrimSpace(parts[1])
if parsedSchema != "" && parsedTable != "" {
schema = parsedSchema
table = parsedTable
}
}
query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_schema = '%s' AND table_name = '%s'
ORDER BY ordinal_position`, schema, tableName)
// If there is still no schema, fall back to current_schema(),
// which automatically matches the connection's search_path.
if schema == "" {
return k.getColumnsWithCurrentSchema(table)
}
if table == "" {
return nil, fmt.Errorf("table name required")
}
// Escape helper: doubles single quotes and strips surrounding double quotes
esc := func(s string) string {
// Strip leading/trailing double quotes, if present
s = strings.Trim(s, "\"")
// Escape single quotes
return strings.ReplaceAll(s, "'", "''")
}
query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_schema = '%s' AND table_name = '%s'
ORDER BY ordinal_position`, esc(schema), esc(table))
data, _, err := k.Query(query)
if err != nil {
return nil, err
}
var columns []connection.ColumnDefinition
for _, row := range data {
col := connection.ColumnDefinition{
Name: fmt.Sprintf("%v", row["column_name"]),
Type: fmt.Sprintf("%v", row["data_type"]),
Nullable: fmt.Sprintf("%v", row["is_nullable"]),
}
if row["column_default"] != nil {
def := fmt.Sprintf("%v", row["column_default"])
col.Default = &def
}
columns = append(columns, col)
}
return columns, nil
}
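The `esc` closure above is the fix for the `""ldf_server""` double-quote artifact called out in the commit message: surrounding double quotes are stripped before the identifier is embedded in a single-quoted SQL literal, and any single quotes are doubled. A standalone copy for illustration:

```go
package main

import (
	"fmt"
	"strings"
)

// esc mirrors the closure in the diff: strip surrounding double quotes,
// then double any single quotes so the value is safe inside '...'.
func esc(s string) string {
	s = strings.Trim(s, "\"")
	return strings.ReplaceAll(s, "'", "''")
}

func main() {
	fmt.Println(esc(`""ldf_server""`)) // quoted-identifier artifact cleaned
	fmt.Println(esc(`o'reilly`))       // single quote doubled for SQL
}
```

Note that `strings.Trim` removes all leading and trailing quote characters, so even the doubled `""..""` form collapses to the bare identifier.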
// getColumnsWithCurrentSchema queries the table's columns under current_schema()
func (k *KingbaseDB) getColumnsWithCurrentSchema(tableName string) ([]connection.ColumnDefinition, error) {
table := strings.TrimSpace(tableName)
if table == "" {
return nil, fmt.Errorf("table name required")
}
// Escape helper
esc := func(s string) string {
s = strings.Trim(s, "\"")
return strings.ReplaceAll(s, "'", "''")
}
// Use current_schema() to resolve the active schema
query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_schema = current_schema() AND table_name = '%s'
ORDER BY ordinal_position`, esc(table))
data, _, err := k.Query(query)
if err != nil {
@@ -257,32 +360,76 @@ func (k *KingbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDe
}
func (k *KingbaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
// Postgres/Kingbase index query
query := fmt.Sprintf(`
SELECT
i.relname as index_name,
a.attname as column_name,
ix.indisunique as is_unique
FROM
pg_class t,
pg_class i,
pg_index ix,
pg_attribute a,
pg_namespace n
WHERE
t.oid = ix.indrelid
AND i.oid = ix.indexrelid
AND a.attrelid = t.oid
AND a.attnum = ANY(ix.indkey)
AND t.relkind = 'r'
AND t.relname = '%s'
AND n.oid = t.relnamespace
AND n.nspname = '%s'
`, tableName, "public") // Default to public if dbName (schema) not clear.
// Parse the schema.table format
schema := strings.TrimSpace(dbName)
table := strings.TrimSpace(tableName)
if dbName != "" {
// Update query to use dbName as schema
query = strings.Replace(query, "'public'", fmt.Sprintf("'%s'", dbName), 1)
// If tableName embeds a schema (format: schema.table)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
parsedSchema := strings.TrimSpace(parts[0])
parsedTable := strings.TrimSpace(parts[1])
if parsedSchema != "" && parsedTable != "" {
schema = parsedSchema
table = parsedTable
}
}
if table == "" {
return nil, fmt.Errorf("table name required")
}
// Escape helper: doubles single quotes and strips surrounding double quotes
esc := func(s string) string {
s = strings.Trim(s, "\"")
return strings.ReplaceAll(s, "'", "''")
}
// Build the query: if no schema was given, use current_schema()
var query string
if schema != "" {
query = fmt.Sprintf(`
SELECT
i.relname as index_name,
a.attname as column_name,
ix.indisunique as is_unique
FROM
pg_class t,
pg_class i,
pg_index ix,
pg_attribute a,
pg_namespace n
WHERE
t.oid = ix.indrelid
AND i.oid = ix.indexrelid
AND a.attrelid = t.oid
AND a.attnum = ANY(ix.indkey)
AND t.relkind = 'r'
AND t.relname = '%s'
AND n.oid = t.relnamespace
AND n.nspname = '%s'
`, esc(table), esc(schema))
} else {
query = fmt.Sprintf(`
SELECT
i.relname as index_name,
a.attname as column_name,
ix.indisunique as is_unique
FROM
pg_class t,
pg_class i,
pg_index ix,
pg_attribute a,
pg_namespace n
WHERE
t.oid = ix.indrelid
AND i.oid = ix.indexrelid
AND a.attrelid = t.oid
AND a.attnum = ANY(ix.indkey)
AND t.relkind = 'r'
AND t.relname = '%s'
AND n.oid = t.relnamespace
AND n.nspname = current_schema()
`, esc(table))
}
data, _, err := k.Query(query)
@@ -311,27 +458,67 @@ func (k *KingbaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDef
}
func (k *KingbaseDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
schema := "public"
if dbName != "" {
schema = dbName
// Parse the schema.table format
schema := strings.TrimSpace(dbName)
table := strings.TrimSpace(tableName)
// If tableName embeds a schema (format: schema.table)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
parsedSchema := strings.TrimSpace(parts[0])
parsedTable := strings.TrimSpace(parts[1])
if parsedSchema != "" && parsedTable != "" {
schema = parsedSchema
table = parsedTable
}
}
query := fmt.Sprintf(`
SELECT
tc.constraint_name,
kcu.column_name,
ccu.table_name AS foreign_table_name,
ccu.column_name AS foreign_column_name
FROM
information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
JOIN information_schema.constraint_column_usage AS ccu
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema='%s'`,
tableName, schema)
if table == "" {
return nil, fmt.Errorf("table name required")
}
// Escape helper: doubles single quotes and strips surrounding double quotes
esc := func(s string) string {
s = strings.Trim(s, "\"")
return strings.ReplaceAll(s, "'", "''")
}
// Build the query: if no schema was given, use current_schema()
var query string
if schema != "" {
query = fmt.Sprintf(`
SELECT
tc.constraint_name,
kcu.column_name,
ccu.table_name AS foreign_table_name,
ccu.column_name AS foreign_column_name
FROM
information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
JOIN information_schema.constraint_column_usage AS ccu
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema='%s'`,
esc(table), esc(schema))
} else {
query = fmt.Sprintf(`
SELECT
tc.constraint_name,
kcu.column_name,
ccu.table_name AS foreign_table_name,
ccu.column_name AS foreign_column_name
FROM
information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
JOIN information_schema.constraint_column_usage AS ccu
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema=current_schema()`,
esc(table))
}
data, _, err := k.Query(query)
if err != nil {
@@ -353,9 +540,43 @@ func (k *KingbaseDB) GetForeignKeys(dbName, tableName string) ([]connection.Fore
}
func (k *KingbaseDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
query := fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation
FROM information_schema.triggers
WHERE event_object_table = '%s'`, tableName)
// Parse the schema.table format
schema := strings.TrimSpace(dbName)
table := strings.TrimSpace(tableName)
// If tableName embeds a schema (format: schema.table)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
parsedSchema := strings.TrimSpace(parts[0])
parsedTable := strings.TrimSpace(parts[1])
if parsedSchema != "" && parsedTable != "" {
schema = parsedSchema
table = parsedTable
}
}
if table == "" {
return nil, fmt.Errorf("table name required")
}
// Escape helper: doubles single quotes and strips surrounding double quotes
esc := func(s string) string {
s = strings.Trim(s, "\"")
return strings.ReplaceAll(s, "'", "''")
}
// Build the query: if a schema was given, add the schema condition as well
var query string
if schema != "" {
query = fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation
FROM information_schema.triggers
WHERE event_object_table = '%s' AND event_object_schema = '%s'`,
esc(table), esc(schema))
} else {
query = fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation
FROM information_schema.triggers
WHERE event_object_table = '%s' AND event_object_schema = current_schema()`,
esc(table))
}
data, _, err := k.Query(query)
if err != nil {
@@ -380,14 +601,13 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet
}
func (k *KingbaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
schema := "public"
if dbName != "" {
schema = dbName
}
query := fmt.Sprintf(`SELECT table_name, column_name, data_type
FROM information_schema.columns
WHERE table_schema = '%s'`, schema)
// In this project's semantics, dbName is the "database"; the schema is determined by table_schema. Return columns from every user schema for query completion.
query := `
SELECT table_schema, table_name, column_name, data_type
FROM information_schema.columns
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
AND table_schema NOT LIKE 'pg_%'
ORDER BY table_schema, table_name, ordinal_position`
data, _, err := k.Query(query)
if err != nil {
@@ -396,8 +616,14 @@ func (k *KingbaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinition
var cols []connection.ColumnDefinitionWithTable
for _, row := range data {
schema := fmt.Sprintf("%v", row["table_schema"])
table := fmt.Sprintf("%v", row["table_name"])
tableName := table
if strings.TrimSpace(schema) != "" {
tableName = fmt.Sprintf("%s.%s", schema, table)
}
col := connection.ColumnDefinitionWithTable{
TableName: fmt.Sprintf("%v", row["table_name"]),
TableName: tableName,
Name: fmt.Sprintf("%v", row["column_name"]),
Type: fmt.Sprintf("%v", row["data_type"]),
}


@@ -1,6 +1,7 @@
package db
import (
"context"
"database/sql"
"fmt"
"strings"
@@ -76,6 +77,20 @@ func (m *MySQLDB) Ping() error {
return m.conn.PingContext(ctx)
}
func (m *MySQLDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if m.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
}
rows, err := m.conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (m *MySQLDB) Query(query string) ([]map[string]interface{}, []string, error) {
if m.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
@@ -86,33 +101,18 @@ func (m *MySQLDB) Query(query string) ([]map[string]interface{}, []string, error
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
columns, err := rows.Columns()
func (m *MySQLDB) ExecContext(ctx context.Context, query string) (int64, error) {
if m.conn == nil {
return 0, fmt.Errorf("connection not open")
}
res, err := m.conn.ExecContext(ctx, query)
if err != nil {
return nil, nil, err
return 0, err
}
var resultData []map[string]interface{}
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
entry := make(map[string]interface{})
for i, col := range columns {
entry[col] = normalizeQueryValue(values[i])
}
resultData = append(resultData, entry)
}
return resultData, columns, nil
return res.RowsAffected()
}
func (m *MySQLDB) Exec(query string) (int64, error) {


@@ -1,6 +1,7 @@
package db
import (
"context"
"database/sql"
"fmt"
"net"
@@ -10,6 +11,7 @@ import (
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
@@ -19,6 +21,7 @@ import (
type OracleDB struct {
conn *sql.DB
pingTimeout time.Duration
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
}
func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
@@ -28,28 +31,6 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
database = config.User // Default to user service/schema if empty?
}
if config.UseSSH {
_, err := ssh.RegisterSSHNetwork(config.SSH)
if err == nil {
// Oracle driver might not support custom dialer via DSN easily without extra config
// But go-ora v2 supports some advanced options.
// For simplicity, we assume standard TCP or we might need a workaround for SSH.
// go-ora v2 is pure Go, so we can potentially use a custom dialer if we manually open.
// But for now, let's just use the address.
// SSH tunneling via net.Dialer override is complex in sql.Open("oracle", ...).
// We might need to forward a local port if using SSH.
// Since ssh.RegisterSSHNetwork creates a custom network "ssh-via-...",
// we need to see if go-ora supports custom networks.
// Checking go-ora docs (simulated): It supports "unix" and "tcp".
// We might need to map the custom network to a local proxy.
// For now, we will assume direct connection or handle SSH separately later.
// We'll leave the protocol implementation as is in MySQL for now, hoping go-ora uses standard net.Dial.
// Note: go-ora connection string: oracle://user:pass@host:port/service
// It parses host/port. It doesn't easily take a custom "network" parameter in URL.
// We will proceed with standard TCP string.
}
}
u := &url.URL{
Scheme: "oracle",
Host: net.JoinHostPort(config.Host, strconv.Itoa(config.Port)),
@@ -61,7 +42,42 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
}
func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
dsn := o.getDSN(config)
var dsn string
var err error
if config.UseSSH {
// Create SSH tunnel with local port forwarding
logger.Infof("Oracle 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
if err != nil {
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
}
o.forwarder = forwarder
// Parse local address
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
if err != nil {
return fmt.Errorf("解析本地转发地址失败:%w", err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("解析本地端口失败:%w", err)
}
// Create a modified config pointing to local forwarder
localConfig := config
localConfig.Host = host
localConfig.Port = port
localConfig.UseSSH = false
dsn = o.getDSN(localConfig)
logger.Infof("Oracle 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = o.getDSN(config)
}
db, err := sql.Open("oracle", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
@@ -75,6 +91,15 @@ func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
}
func (o *OracleDB) Close() error {
// Close the SSH forwarder first, if present
if o.forwarder != nil {
if err := o.forwarder.Close(); err != nil {
logger.Warnf("关闭 Oracle SSH 端口转发失败:%v", err)
}
o.forwarder = nil
}
// Then close database connection
if o.conn != nil {
return o.conn.Close()
}
@@ -94,6 +119,20 @@ func (o *OracleDB) Ping() error {
return o.conn.PingContext(ctx)
}
func (o *OracleDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if o.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
}
rows, err := o.conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (o *OracleDB) Query(query string) ([]map[string]interface{}, []string, error) {
if o.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
@@ -104,33 +143,18 @@ func (o *OracleDB) Query(query string) ([]map[string]interface{}, []string, erro
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
columns, err := rows.Columns()
func (o *OracleDB) ExecContext(ctx context.Context, query string) (int64, error) {
if o.conn == nil {
return 0, fmt.Errorf("connection not open")
}
res, err := o.conn.ExecContext(ctx, query)
if err != nil {
return nil, nil, err
return 0, err
}
var resultData []map[string]interface{}
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
entry := make(map[string]interface{})
for i, col := range columns {
entry[col] = normalizeQueryValue(values[i])
}
resultData = append(resultData, entry)
}
return resultData, columns, nil
return res.RowsAffected()
}
func (o *OracleDB) Exec(query string) (int64, error) {


@@ -1,24 +1,31 @@
package db
import (
"context"
"database/sql"
"fmt"
"net"
"net/url"
"strconv"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
_ "github.com/lib/pq"
)
type PostgresDB struct {
conn *sql.DB
pingTimeout time.Duration
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
}
func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
// postgres://user:password@host:port/dbname?sslmode=disable
dbname := config.Database
@@ -41,7 +48,42 @@ func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
}
func (p *PostgresDB) Connect(config connection.ConnectionConfig) error {
dsn := p.getDSN(config)
var dsn string
var err error
if config.UseSSH {
// Create SSH tunnel with local port forwarding
logger.Infof("PostgreSQL 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
if err != nil {
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
}
p.forwarder = forwarder
// Parse local address
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
if err != nil {
return fmt.Errorf("解析本地转发地址失败:%w", err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("解析本地端口失败:%w", err)
}
// Create a modified config pointing to local forwarder
localConfig := config
localConfig.Host = host
localConfig.Port = port
localConfig.UseSSH = false // Disable SSH flag for DSN generation
dsn = p.getDSN(localConfig)
logger.Infof("PostgreSQL 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = p.getDSN(config)
}
db, err := sql.Open("postgres", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
@@ -56,7 +98,17 @@ func (p *PostgresDB) Connect(config connection.ConnectionConfig) error {
return nil
}
func (p *PostgresDB) Close() error {
// Close the SSH forwarder first, if present
if p.forwarder != nil {
if err := p.forwarder.Close(); err != nil {
logger.Warnf("关闭 PostgreSQL SSH 端口转发失败:%v", err)
}
p.forwarder = nil
}
// Then close database connection
if p.conn != nil {
return p.conn.Close()
}
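Each rewritten `Connect` above points the DSN at the tunnel's local endpoint instead of wiring a driver-specific dialer. The address-splitting step can be shown on a literal address, since `ssh.GetOrCreateLocalForwarder` needs a live SSH server; this is a sketch of that step only:

```go
package main

import (
	"fmt"
	"net"
	"strconv"
)

// localEndpoint parses a forwarder LocalAddr such as "127.0.0.1:54321"
// into the host/port pair that replaces config.Host/config.Port.
func localEndpoint(localAddr string) (string, int, error) {
	host, portStr, err := net.SplitHostPort(localAddr)
	if err != nil {
		return "", 0, fmt.Errorf("解析本地转发地址失败:%w", err)
	}
	port, err := strconv.Atoi(portStr)
	if err != nil {
		return "", 0, fmt.Errorf("解析本地端口失败:%w", err)
	}
	return host, port, nil
}

func main() {
	host, port, err := localEndpoint("127.0.0.1:54321")
	fmt.Println(host, port, err)
}
```

With the local endpoint substituted into a copy of the config (and `UseSSH` cleared), the unmodified `getDSN` path works for every driver.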
@@ -76,6 +128,20 @@ func (p *PostgresDB) Ping() error {
return p.conn.PingContext(ctx)
}
func (p *PostgresDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if p.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
}
rows, err := p.conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (p *PostgresDB) Query(query string) ([]map[string]interface{}, []string, error) {
if p.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
@@ -86,33 +152,18 @@ func (p *PostgresDB) Query(query string) ([]map[string]interface{}, []string, er
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
columns, err := rows.Columns()
func (p *PostgresDB) ExecContext(ctx context.Context, query string) (int64, error) {
if p.conn == nil {
return 0, fmt.Errorf("connection not open")
}
res, err := p.conn.ExecContext(ctx, query)
if err != nil {
return nil, nil, err
return 0, err
}
var resultData []map[string]interface{}
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
entry := make(map[string]interface{})
for i, col := range columns {
entry[col] = normalizeQueryValue(values[i])
}
resultData = append(resultData, entry)
}
return resultData, columns, nil
return res.RowsAffected()
}
func (p *PostgresDB) Exec(query string) (int64, error) {
@@ -167,21 +218,306 @@ func (p *PostgresDB) GetCreateStatement(dbName, tableName string) (string, error
}
func (p *PostgresDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
return []connection.ColumnDefinition{}, nil
schema := strings.TrimSpace(dbName)
if schema == "" {
schema = "public"
}
table := strings.TrimSpace(tableName)
if table == "" {
return nil, fmt.Errorf("table name required")
}
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
query := fmt.Sprintf(`
SELECT
a.attname AS column_name,
pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
col_description(a.attrelid, a.attnum) AS comment,
CASE WHEN pk.attname IS NOT NULL THEN 'PRI' ELSE '' END AS column_key
FROM pg_class c
JOIN pg_namespace n ON n.oid = c.relnamespace
JOIN pg_attribute a ON a.attrelid = c.oid
LEFT JOIN pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum
LEFT JOIN (
SELECT i.indrelid, a3.attname
FROM pg_index i
JOIN pg_attribute a3 ON a3.attrelid = i.indrelid AND a3.attnum = ANY(i.indkey)
WHERE i.indisprimary
) pk ON pk.indrelid = c.oid AND pk.attname = a.attname
WHERE c.relkind IN ('r', 'p')
AND n.nspname = '%s'
AND c.relname = '%s'
AND a.attnum > 0
AND NOT a.attisdropped
ORDER BY a.attnum`, esc(schema), esc(table))
data, _, err := p.Query(query)
if err != nil {
return nil, err
}
var columns []connection.ColumnDefinition
for _, row := range data {
col := connection.ColumnDefinition{
Name: fmt.Sprintf("%v", row["column_name"]),
Type: fmt.Sprintf("%v", row["data_type"]),
Nullable: fmt.Sprintf("%v", row["is_nullable"]),
Key: fmt.Sprintf("%v", row["column_key"]),
Extra: "",
Comment: "",
}
if v, ok := row["comment"]; ok && v != nil {
col.Comment = fmt.Sprintf("%v", v)
}
if v, ok := row["column_default"]; ok && v != nil {
def := fmt.Sprintf("%v", v)
col.Default = &def
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(def)), "nextval(") {
col.Extra = "auto_increment"
}
}
columns = append(columns, col)
}
return columns, nil
}
func (p *PostgresDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
return []connection.IndexDefinition{}, nil
schema := strings.TrimSpace(dbName)
if schema == "" {
schema = "public"
}
table := strings.TrimSpace(tableName)
if table == "" {
return nil, fmt.Errorf("table name required")
}
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
query := fmt.Sprintf(`
SELECT
i.relname AS index_name,
a.attname AS column_name,
ix.indisunique AS is_unique,
x.ordinality AS seq_in_index,
am.amname AS index_type
FROM pg_class t
JOIN pg_namespace n ON n.oid = t.relnamespace
JOIN pg_index ix ON t.oid = ix.indrelid
JOIN pg_class i ON i.oid = ix.indexrelid
JOIN pg_am am ON i.relam = am.oid
JOIN unnest(ix.indkey) WITH ORDINALITY AS x(attnum, ordinality) ON TRUE
JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = x.attnum
WHERE t.relkind IN ('r', 'p')
AND t.relname = '%s'
AND n.nspname = '%s'
ORDER BY i.relname, x.ordinality`, esc(table), esc(schema))
data, _, err := p.Query(query)
if err != nil {
return nil, err
}
parseBool := func(v interface{}) bool {
switch val := v.(type) {
case bool:
return val
case string:
s := strings.ToLower(strings.TrimSpace(val))
return s == "t" || s == "true" || s == "1" || s == "y" || s == "yes"
default:
s := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
return s == "t" || s == "true" || s == "1" || s == "y" || s == "yes"
}
}
parseInt := func(v interface{}) int {
switch val := v.(type) {
case int:
return val
case int64:
return int(val)
case float64:
return int(val)
case string:
// best effort
var n int
_, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n)
return n
default:
var n int
_, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n)
return n
}
}
var indexes []connection.IndexDefinition
for _, row := range data {
isUnique := false
if v, ok := row["is_unique"]; ok && v != nil {
isUnique = parseBool(v)
}
nonUnique := 1
if isUnique {
nonUnique = 0
}
seq := 0
if v, ok := row["seq_in_index"]; ok && v != nil {
seq = parseInt(v)
}
indexType := ""
if v, ok := row["index_type"]; ok && v != nil {
indexType = strings.ToUpper(fmt.Sprintf("%v", v))
}
if indexType == "" {
indexType = "BTREE"
}
idx := connection.IndexDefinition{
Name: fmt.Sprintf("%v", row["index_name"]),
ColumnName: fmt.Sprintf("%v", row["column_name"]),
NonUnique: nonUnique,
SeqInIndex: seq,
IndexType: indexType,
}
indexes = append(indexes, idx)
}
return indexes, nil
}
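The defensive `parseBool` closure above exists because Postgres-family drivers may surface `indisunique` as a Go `bool`, the strings `t`/`f`, or other forms. Extracted as a standalone function for illustration:

```go
package main

import (
	"fmt"
	"strings"
)

// parseBool accepts the boolean spellings Postgres-family drivers emit
// for catalog columns such as pg_index.indisunique.
func parseBool(v interface{}) bool {
	var s string
	switch val := v.(type) {
	case bool:
		return val
	case string:
		s = val
	default:
		s = fmt.Sprintf("%v", v)
	}
	s = strings.ToLower(strings.TrimSpace(s))
	return s == "t" || s == "true" || s == "1" || s == "y" || s == "yes"
}

func main() {
	fmt.Println(parseBool("t"), parseBool(true), parseBool("f"))
}
```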
func (p *PostgresDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
return []connection.ForeignKeyDefinition{}, nil
schema := strings.TrimSpace(dbName)
if schema == "" {
schema = "public"
}
table := strings.TrimSpace(tableName)
if table == "" {
return nil, fmt.Errorf("table name required")
}
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
query := fmt.Sprintf(`
SELECT
tc.constraint_name AS constraint_name,
kcu.column_name AS column_name,
ccu.table_schema AS foreign_table_schema,
ccu.table_name AS foreign_table_name,
ccu.column_name AS foreign_column_name
FROM information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
JOIN information_schema.constraint_column_usage AS ccu
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY'
AND tc.table_name = '%s'
AND tc.table_schema = '%s'
ORDER BY tc.constraint_name, kcu.ordinal_position`, esc(table), esc(schema))
data, _, err := p.Query(query)
if err != nil {
return nil, err
}
var fks []connection.ForeignKeyDefinition
for _, row := range data {
refSchema := ""
if v, ok := row["foreign_table_schema"]; ok && v != nil {
refSchema = fmt.Sprintf("%v", v)
}
refTable := fmt.Sprintf("%v", row["foreign_table_name"])
refTableName := refTable
if strings.TrimSpace(refSchema) != "" {
refTableName = fmt.Sprintf("%s.%s", refSchema, refTable)
}
fk := connection.ForeignKeyDefinition{
Name: fmt.Sprintf("%v", row["constraint_name"]),
ColumnName: fmt.Sprintf("%v", row["column_name"]),
RefTableName: refTableName,
RefColumnName: fmt.Sprintf("%v", row["foreign_column_name"]),
ConstraintName: fmt.Sprintf("%v", row["constraint_name"]),
}
fks = append(fks, fk)
}
return fks, nil
}
func (p *PostgresDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
return []connection.TriggerDefinition{}, nil
schema := strings.TrimSpace(dbName)
if schema == "" {
schema = "public"
}
table := strings.TrimSpace(tableName)
if table == "" {
return nil, fmt.Errorf("table name required")
}
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
query := fmt.Sprintf(`
SELECT trigger_name, action_timing, event_manipulation, action_statement
FROM information_schema.triggers
WHERE event_object_table = '%s'
AND event_object_schema = '%s'
ORDER BY trigger_name, event_manipulation`, esc(table), esc(schema))
data, _, err := p.Query(query)
if err != nil {
return nil, err
}
var triggers []connection.TriggerDefinition
for _, row := range data {
trig := connection.TriggerDefinition{
Name: fmt.Sprintf("%v", row["trigger_name"]),
Timing: fmt.Sprintf("%v", row["action_timing"]),
Event: fmt.Sprintf("%v", row["event_manipulation"]),
Statement: fmt.Sprintf("%v", row["action_statement"]),
}
triggers = append(triggers, trig)
}
return triggers, nil
}
func (p *PostgresDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
return []connection.ColumnDefinitionWithTable{}, nil
query := `
SELECT table_schema, table_name, column_name, data_type
FROM information_schema.columns
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
AND table_schema NOT LIKE 'pg_%'
ORDER BY table_schema, table_name, ordinal_position`
data, _, err := p.Query(query)
if err != nil {
return nil, err
}
var cols []connection.ColumnDefinitionWithTable
for _, row := range data {
schema := fmt.Sprintf("%v", row["table_schema"])
table := fmt.Sprintf("%v", row["table_name"])
tableName := table
if strings.TrimSpace(schema) != "" {
tableName = fmt.Sprintf("%s.%s", schema, table)
}
col := connection.ColumnDefinitionWithTable{
TableName: tableName,
Name: fmt.Sprintf("%v", row["column_name"]),
Type: fmt.Sprintf("%v", row["data_type"]),
}
cols = append(cols, col)
}
return cols, nil
}


@@ -2,6 +2,8 @@ package db
import (
"encoding/hex"
"fmt"
"strings"
"unicode"
"unicode/utf8"
)
@@ -9,13 +11,17 @@ import (
// normalizeQueryValue normalizes driver-returned values for UI/JSON transport.
// Currently this mainly handles []byte: readable text is returned as a string, anything else as a hex string, so the frontend never shows "blank" values.
func normalizeQueryValue(v interface{}) interface{} {
return normalizeQueryValueWithDBType(v, "")
}
func normalizeQueryValueWithDBType(v interface{}, databaseTypeName string) interface{} {
if b, ok := v.([]byte); ok {
return bytesToDisplayValue(b, databaseTypeName)
}
return v
}
func bytesToDisplayValue(b []byte, databaseTypeName string) interface{} {
if b == nil {
return nil
}
@@ -23,6 +29,18 @@ func bytesToReadableString(b []byte) interface{} {
return ""
}
dbType := strings.ToUpper(strings.TrimSpace(databaseTypeName))
if isBitLikeDBType(dbType) {
if u, ok := bytesToUint64(b); ok {
// JS number precision is limited; keep large bitmasks as string.
const maxSafeInteger = 9007199254740991 // 2^53 - 1
if u <= maxSafeInteger {
return int64(u)
}
return fmt.Sprintf("%d", u)
}
}
if utf8.Valid(b) {
s := string(b)
if isMostlyPrintable(s) {
@@ -30,9 +48,47 @@ func bytesToReadableString(b []byte) interface{} {
}
}
// Fallback: some drivers return BIT(1) as []byte{0} / []byte{1} without type info.
if dbType == "" && len(b) == 1 && (b[0] == 0 || b[0] == 1) {
return int64(b[0])
}
return bytesToReadableString(b)
}
func bytesToReadableString(b []byte) interface{} {
if b == nil {
return nil
}
if len(b) == 0 {
return ""
}
return "0x" + hex.EncodeToString(b)
}
func isBitLikeDBType(typeName string) bool {
if typeName == "" {
return false
}
switch typeName {
case "BIT", "VARBIT":
return true
}
return strings.HasPrefix(typeName, "BIT")
}
func bytesToUint64(b []byte) (uint64, bool) {
if len(b) == 0 || len(b) > 8 {
return 0, false
}
var u uint64
for _, v := range b {
u = (u << 8) | uint64(v)
}
return u, true
}
func isMostlyPrintable(s string) bool {
if s == "" {
return true

View File

@@ -0,0 +1,44 @@
package db
import "testing"
func TestNormalizeQueryValueWithDBType_BitBytes(t *testing.T) {
v := normalizeQueryValueWithDBType([]byte{0x00}, "BIT")
if v != int64(0) {
t.Fatalf("BIT 0x00: expected 0, got %v(%T)", v, v)
}
v = normalizeQueryValueWithDBType([]byte{0x01}, "bit")
if v != int64(1) {
t.Fatalf("BIT 0x01: expected 1, got %v(%T)", v, v)
}
v = normalizeQueryValueWithDBType([]byte{0x01, 0x02}, "BIT VARYING")
if v != int64(258) {
t.Fatalf("BIT 0x0102: expected 258, got %v(%T)", v, v)
}
}
func TestNormalizeQueryValueWithDBType_BitLargeAsString(t *testing.T) {
v := normalizeQueryValueWithDBType([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, "BIT")
if s, ok := v.(string); !ok || s != "18446744073709551615" {
t.Fatalf("BIT 0xffffffffffffffff: expected string(18446744073709551615), got %v(%T)", v, v)
}
}
func TestNormalizeQueryValueWithDBType_ByteFallbacks(t *testing.T) {
v := normalizeQueryValueWithDBType([]byte("abc"), "")
if v != "abc" {
t.Fatalf("text []byte: expected a string, got %v(%T)", v, v)
}
v = normalizeQueryValueWithDBType([]byte{0x00}, "")
if v != int64(0) {
t.Fatalf("unknown type 0x00: expected 0, got %v(%T)", v, v)
}
v = normalizeQueryValueWithDBType([]byte{0xff}, "")
if v != "0xff" {
t.Fatalf("unknown type 0xff: expected 0xff, got %v(%T)", v, v)
}
}

46
internal/db/scan_rows.go Normal file
View File

@@ -0,0 +1,46 @@
package db
import (
"database/sql"
)
func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) {
columns, err := rows.Columns()
if err != nil {
return nil, nil, err
}
colTypes, err := rows.ColumnTypes()
if err != nil || len(colTypes) != len(columns) {
colTypes = nil
}
resultData := make([]map[string]interface{}, 0)
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
entry := make(map[string]interface{}, len(columns))
for i, col := range columns {
dbTypeName := ""
if colTypes != nil && i < len(colTypes) && colTypes[i] != nil {
dbTypeName = colTypes[i].DatabaseTypeName()
}
entry[col] = normalizeQueryValueWithDBType(values[i], dbTypeName)
}
resultData = append(resultData, entry)
}
if err := rows.Err(); err != nil {
return resultData, columns, err
}
return resultData, columns, nil
}

View File

@@ -1,8 +1,10 @@
package db
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"GoNavi-Wails/internal/connection"
@@ -52,6 +54,20 @@ func (s *SQLiteDB) Ping() error {
return s.conn.PingContext(ctx)
}
func (s *SQLiteDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if s.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
}
rows, err := s.conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (s *SQLiteDB) Query(query string) ([]map[string]interface{}, []string, error) {
if s.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
@@ -62,33 +78,18 @@ func (s *SQLiteDB) Query(query string) ([]map[string]interface{}, []string, erro
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (s *SQLiteDB) ExecContext(ctx context.Context, query string) (int64, error) {
if s.conn == nil {
return 0, fmt.Errorf("connection not open")
}
res, err := s.conn.ExecContext(ctx, query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (s *SQLiteDB) Exec(query string) (int64, error) {
@@ -137,21 +138,336 @@ func (s *SQLiteDB) GetCreateStatement(dbName, tableName string) (string, error)
}
func (s *SQLiteDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
table := strings.TrimSpace(tableName)
if table == "" {
return nil, fmt.Errorf("table name required")
}
esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") }
// cid, name, type, notnull, dflt_value, pk
data, _, err := s.Query(fmt.Sprintf("PRAGMA table_info('%s')", esc(table)))
if err != nil {
return nil, err
}
parseInt := func(v interface{}) int {
switch val := v.(type) {
case int:
return val
case int64:
return int(val)
case float64:
return int(val)
case string:
var n int
_, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n)
return n
default:
var n int
_, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n)
return n
}
}
getStr := func(row map[string]interface{}, key string) string {
if v, ok := row[key]; ok && v != nil {
return fmt.Sprintf("%v", v)
}
if v, ok := row[strings.ToUpper(key)]; ok && v != nil {
return fmt.Sprintf("%v", v)
}
return ""
}
var columns []connection.ColumnDefinition
for _, row := range data {
notnull := 0
if v, ok := row["notnull"]; ok && v != nil {
notnull = parseInt(v)
} else if v, ok := row["NOTNULL"]; ok && v != nil {
notnull = parseInt(v)
}
pk := 0
if v, ok := row["pk"]; ok && v != nil {
pk = parseInt(v)
} else if v, ok := row["PK"]; ok && v != nil {
pk = parseInt(v)
}
nullable := "YES"
if notnull == 1 {
nullable = "NO"
}
key := ""
if pk == 1 {
key = "PRI"
}
col := connection.ColumnDefinition{
Name: getStr(row, "name"),
Type: getStr(row, "type"),
Nullable: nullable,
Key: key,
Extra: "",
Comment: "",
}
if v, ok := row["dflt_value"]; ok && v != nil {
def := fmt.Sprintf("%v", v)
col.Default = &def
} else if v, ok := row["DFLT_VALUE"]; ok && v != nil {
def := fmt.Sprintf("%v", v)
col.Default = &def
}
columns = append(columns, col)
}
return columns, nil
}
func (s *SQLiteDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
table := strings.TrimSpace(tableName)
if table == "" {
return nil, fmt.Errorf("table name required")
}
esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") }
parseInt := func(v interface{}) int {
switch val := v.(type) {
case int:
return val
case int64:
return int(val)
case float64:
return int(val)
case string:
var n int
_, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n)
return n
default:
var n int
_, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n)
return n
}
}
data, _, err := s.Query(fmt.Sprintf("PRAGMA index_list('%s')", esc(table)))
if err != nil {
return nil, err
}
var indexes []connection.IndexDefinition
for _, row := range data {
indexName := ""
if v, ok := row["name"]; ok && v != nil {
indexName = fmt.Sprintf("%v", v)
} else if v, ok := row["NAME"]; ok && v != nil {
indexName = fmt.Sprintf("%v", v)
}
if strings.TrimSpace(indexName) == "" {
continue
}
unique := 0
if v, ok := row["unique"]; ok && v != nil {
unique = parseInt(v)
} else if v, ok := row["UNIQUE"]; ok && v != nil {
unique = parseInt(v)
}
nonUnique := 1
if unique == 1 {
nonUnique = 0
}
cols, _, err := s.Query(fmt.Sprintf("PRAGMA index_info('%s')", esc(indexName)))
if err != nil {
// skip broken index
continue
}
for _, c := range cols {
colName := ""
if v, ok := c["name"]; ok && v != nil {
colName = fmt.Sprintf("%v", v)
} else if v, ok := c["NAME"]; ok && v != nil {
colName = fmt.Sprintf("%v", v)
}
if strings.TrimSpace(colName) == "" {
continue
}
seq := 0
if v, ok := c["seqno"]; ok && v != nil {
seq = parseInt(v) + 1
} else if v, ok := c["SEQNO"]; ok && v != nil {
seq = parseInt(v) + 1
}
indexes = append(indexes, connection.IndexDefinition{
Name: indexName,
ColumnName: colName,
NonUnique: nonUnique,
SeqInIndex: seq,
IndexType: "BTREE",
})
}
}
return indexes, nil
}
func (s *SQLiteDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
table := strings.TrimSpace(tableName)
if table == "" {
return nil, fmt.Errorf("table name required")
}
esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") }
data, _, err := s.Query(fmt.Sprintf("PRAGMA foreign_key_list('%s')", esc(table)))
if err != nil {
return nil, err
}
parseInt := func(v interface{}) int {
switch val := v.(type) {
case int:
return val
case int64:
return int(val)
case float64:
return int(val)
case string:
var n int
_, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n)
return n
default:
var n int
_, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n)
return n
}
}
var fks []connection.ForeignKeyDefinition
for _, row := range data {
id := 0
if v, ok := row["id"]; ok && v != nil {
id = parseInt(v)
} else if v, ok := row["ID"]; ok && v != nil {
id = parseInt(v)
}
refTable := ""
if v, ok := row["table"]; ok && v != nil {
refTable = fmt.Sprintf("%v", v)
} else if v, ok := row["TABLE"]; ok && v != nil {
refTable = fmt.Sprintf("%v", v)
}
fromCol := ""
if v, ok := row["from"]; ok && v != nil {
fromCol = fmt.Sprintf("%v", v)
} else if v, ok := row["FROM"]; ok && v != nil {
fromCol = fmt.Sprintf("%v", v)
}
toCol := ""
if v, ok := row["to"]; ok && v != nil {
toCol = fmt.Sprintf("%v", v)
} else if v, ok := row["TO"]; ok && v != nil {
toCol = fmt.Sprintf("%v", v)
}
name := fmt.Sprintf("fk_%s_%d", table, id)
fks = append(fks, connection.ForeignKeyDefinition{
Name: name,
ColumnName: fromCol,
RefTableName: refTable,
RefColumnName: toCol,
ConstraintName: name,
})
}
return fks, nil
}
func (s *SQLiteDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
table := strings.TrimSpace(tableName)
if table == "" {
return nil, fmt.Errorf("table name required")
}
esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") }
data, _, err := s.Query(fmt.Sprintf("SELECT name AS trigger_name, sql AS statement FROM sqlite_master WHERE type='trigger' AND tbl_name='%s' ORDER BY name", esc(table)))
if err != nil {
return nil, err
}
var triggers []connection.TriggerDefinition
for _, row := range data {
name := fmt.Sprintf("%v", row["trigger_name"])
stmt := ""
if v, ok := row["statement"]; ok && v != nil {
stmt = fmt.Sprintf("%v", v)
}
upper := strings.ToUpper(stmt)
timing := ""
switch {
case strings.Contains(upper, " BEFORE "):
timing = "BEFORE"
case strings.Contains(upper, " AFTER "):
timing = "AFTER"
case strings.Contains(upper, " INSTEAD OF "):
timing = "INSTEAD OF"
}
event := ""
switch {
case strings.Contains(upper, " INSERT "):
event = "INSERT"
case strings.Contains(upper, " UPDATE "):
event = "UPDATE"
case strings.Contains(upper, " DELETE "):
event = "DELETE"
}
triggers = append(triggers, connection.TriggerDefinition{
Name: name,
Timing: timing,
Event: event,
Statement: stmt,
})
}
return triggers, nil
}
func (s *SQLiteDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
tables, err := s.GetTables(dbName)
if err != nil {
return nil, err
}
var cols []connection.ColumnDefinitionWithTable
for _, table := range tables {
// Skip internal tables
if strings.HasPrefix(strings.ToLower(table), "sqlite_") {
continue
}
columns, err := s.GetColumns("", table)
if err != nil {
continue
}
for _, col := range columns {
cols = append(cols, connection.ColumnDefinitionWithTable{
TableName: table,
Name: col.Name,
Type: col.Type,
})
}
}
return cols, nil
}

View File

@@ -3,8 +3,10 @@ package ssh
import (
"context"
"fmt"
"io"
"net"
"os"
"sync"
"time"
"GoNavi-Wails/internal/connection"
@@ -110,3 +112,264 @@ func RegisterSSHNetwork(sshConfig connection.SSHConfig) (string, error) {
return netName, nil
}
// sshClientCache stores SSH clients to avoid creating multiple connections
var (
sshClientCache = make(map[string]*ssh.Client)
sshClientCacheMu sync.RWMutex
localForwarders = make(map[string]*LocalForwarder)
forwarderMu sync.RWMutex
)
// LocalForwarder represents a local port forwarder through SSH
type LocalForwarder struct {
LocalAddr string
RemoteAddr string
SSHClient *ssh.Client
listener net.Listener
closeChan chan struct{}
closeOnce sync.Once // guards against double Close
closed bool // closed-state flag
closedMu sync.RWMutex
}
// NewLocalForwarder creates a new local port forwarder
// It listens on a random local port and forwards all connections through SSH tunnel
func NewLocalForwarder(sshConfig connection.SSHConfig, remoteHost string, remotePort int) (*LocalForwarder, error) {
client, err := GetOrCreateSSHClient(sshConfig)
if err != nil {
return nil, fmt.Errorf("failed to establish SSH connection: %w", err)
}
// Listen on localhost with a random port
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, fmt.Errorf("failed to create local listener: %w", err)
}
localAddr := listener.Addr().String()
remoteAddr := fmt.Sprintf("%s:%d", remoteHost, remotePort)
forwarder := &LocalForwarder{
LocalAddr: localAddr,
RemoteAddr: remoteAddr,
SSHClient: client,
listener: listener,
closeChan: make(chan struct{}),
}
// Start forwarding in background
go forwarder.forward()
logger.Infof("created SSH port forward: local %s -> remote %s", localAddr, remoteAddr)
return forwarder, nil
}
// forward handles the port forwarding
func (f *LocalForwarder) forward() {
for {
localConn, err := f.listener.Accept()
if err != nil {
// Check if we're shutting down
select {
case <-f.closeChan:
return
default:
logger.Warnf("failed to accept local connection: %v", err)
// the listener has likely been closed; exit the loop
return
}
}
go f.handleConnection(localConn)
}
}
// handleConnection handles a single connection
func (f *LocalForwarder) handleConnection(localConn net.Conn) {
defer localConn.Close()
// Connect to remote through SSH with timeout
remoteConn, err := f.SSHClient.Dial("tcp", f.RemoteAddr)
if err != nil {
logger.Warnf("failed to connect to remote %s over SSH: %v", f.RemoteAddr, err)
return
}
defer remoteConn.Close()
// Bidirectional copy with error channel
errc := make(chan error, 2)
// Copy from local to remote
go func() {
_, err := io.Copy(remoteConn, localConn)
if err != nil {
logger.Warnf("local->remote copy error: %v", err)
}
errc <- err
}()
// Copy from remote to local
go func() {
_, err := io.Copy(localConn, remoteConn)
if err != nil {
logger.Warnf("remote->local copy error: %v", err)
}
errc <- err
}()
// Wait for BOTH goroutines to complete
<-errc
<-errc
}
// Close closes the forwarder (thread-safe, can be called multiple times)
func (f *LocalForwarder) Close() error {
var err error
f.closeOnce.Do(func() {
f.closedMu.Lock()
f.closed = true
f.closedMu.Unlock()
close(f.closeChan)
err = f.listener.Close()
if err != nil {
logger.Warnf("failed to close forwarder listener: %v", err)
}
})
return err
}
// IsClosed returns whether the forwarder is closed
func (f *LocalForwarder) IsClosed() bool {
f.closedMu.RLock()
defer f.closedMu.RUnlock()
return f.closed
}
// GetOrCreateLocalForwarder returns a cached forwarder or creates a new one
func GetOrCreateLocalForwarder(sshConfig connection.SSHConfig, remoteHost string, remotePort int) (*LocalForwarder, error) {
key := fmt.Sprintf("%s:%d:%s->%s:%d",
sshConfig.Host, sshConfig.Port, sshConfig.User,
remoteHost, remotePort)
forwarderMu.RLock()
forwarder, exists := localForwarders[key]
forwarderMu.RUnlock()
// Check if exists and is still valid
if exists && forwarder != nil && !forwarder.IsClosed() {
logger.Infof("reusing existing port forward: %s", key)
return forwarder, nil
}
// Remove stale forwarder from cache
if exists {
forwarderMu.Lock()
delete(localForwarders, key)
forwarderMu.Unlock()
}
forwarder, err := NewLocalForwarder(sshConfig, remoteHost, remotePort)
if err != nil {
return nil, err
}
forwarderMu.Lock()
localForwarders[key] = forwarder
forwarderMu.Unlock()
return forwarder, nil
}
// CloseAllForwarders closes all local forwarders
func CloseAllForwarders() {
forwarderMu.Lock()
defer forwarderMu.Unlock()
for key, forwarder := range localForwarders {
if forwarder != nil {
_ = forwarder.Close()
logger.Infof("closed port forward: %s", key)
}
}
localForwarders = make(map[string]*LocalForwarder)
}
// getSSHClientCacheKey generates a unique cache key for SSH config
func getSSHClientCacheKey(config connection.SSHConfig) string {
return fmt.Sprintf("%s:%d:%s", config.Host, config.Port, config.User)
}
// GetOrCreateSSHClient returns a cached SSH client or creates a new one
func GetOrCreateSSHClient(config connection.SSHConfig) (*ssh.Client, error) {
key := getSSHClientCacheKey(config)
sshClientCacheMu.RLock()
client, exists := sshClientCache[key]
sshClientCacheMu.RUnlock()
if exists && client != nil {
// Test if connection is still alive by creating a test session
session, err := client.NewSession()
if err == nil {
session.Close()
logger.Infof("reusing existing SSH connection: %s", key)
return client, nil
}
// Connection is dead, remove from cache
logger.Warnf("SSH connection is dead, reconnecting: %s (error: %v)", key, err)
sshClientCacheMu.Lock()
delete(sshClientCache, key)
sshClientCacheMu.Unlock()
// Try to close the dead client
_ = client.Close()
}
// Create new SSH client
client, err := connectSSH(config)
if err != nil {
return nil, err
}
// Cache the client
sshClientCacheMu.Lock()
sshClientCache[key] = client
sshClientCacheMu.Unlock()
logger.Infof("cached SSH connection: %s", key)
return client, nil
}
// DialThroughSSH creates a connection through SSH tunnel
// This is a generic dialer that can be used by any database driver
func DialThroughSSH(config connection.SSHConfig, network, address string) (net.Conn, error) {
client, err := GetOrCreateSSHClient(config)
if err != nil {
return nil, fmt.Errorf("failed to establish SSH connection: %w", err)
}
conn, err := client.Dial(network, address)
if err != nil {
return nil, fmt.Errorf("failed to connect to %s through SSH tunnel: %w", address, err)
}
logger.Infof("connected to %s through SSH tunnel", address)
return conn, nil
}
// CloseAllSSHClients closes all cached SSH clients
func CloseAllSSHClients() {
sshClientCacheMu.Lock()
defer sshClientCacheMu.Unlock()
for key, client := range sshClientCache {
if client != nil {
_ = client.Close()
logger.Infof("closed SSH connection: %s", key)
}
}
sshClientCache = make(map[string]*ssh.Client)
}