Mirror of https://github.com/Syngnat/GoNavi.git (synced 2026-05-12 20:29:43 +08:00)

Compare commits: v0.2.2 ... feature/ta (20 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 71e5de0cdc | |
| | d8656c6c9c | |
| | 443b487a02 | |
| | bac57ebdf0 | |
| | 213a33e4f3 | |
| | a00f87582d | |
| | f129623000 | |
| | 8dbc97e466 | |
| | 4a0db185c0 | |
| | 5793f63ac8 | |
| | 8aabc67634 | |
| | 34c494ce51 | |
| | 178de02783 | |
| | 94e5b8d2c6 | |
| | 89e2247c05 | |
| | b2ede61b79 | |
| | db381ae9d1 | |
| | f946cfd647 | |
| | 46c48c5ea8 | |
| | e3bf160072 | |
.github/ISSUE_TEMPLATE/01-bug_report.yml (vendored, new file, 58 lines)
@@ -0,0 +1,58 @@
name: 问题反馈
description: 软件问题反馈
title: "[Bug] "
labels: ["bug"]

body:
  - type: checkboxes
    id: searched
    attributes:
      label: 已经搜索过 Issues,未发现重复问题*
      options:
        - label: 我已经搜索过 Issues,没有发现重复问题
    validations:
      required: true

  - type: input
    id: system
    attributes:
      label: 操作系统及版本
      placeholder: Windows 10 22H2 / macOS Mojave / Linux
    validations:
      required: true

  - type: input
    id: version
    attributes:
      label: 软件安装版本
      placeholder: v0.2.3
    validations:
      required: true

  - type: textarea
    id: description
    attributes:
      label: 问题简述及复现流程
      description: 请详细描述你遇到的问题,并提供复现步骤
      placeholder: |
        1. 打开软件
        2. 点击 xxx
        3. 预期结果是 ...
        4. 实际结果是 ...
        5. 截图 ...
    validations:
      required: true

  - type: textarea
    id: extra
    attributes:
      label: 其他补充
      description: 如果你有额外信息,请在此填写
      placeholder: 可选

  - type: checkboxes
    id: pr
    attributes:
      label: 是否愿意提交 PR 修复当前 Issue
      options:
        - label: 我愿意尝试提交 PR
.github/ISSUE_TEMPLATE/02-feature_request.yml (vendored, new file, 37 lines)
@@ -0,0 +1,37 @@
name: 功能建议
description: 添加全新功能或改进现有功能
title: "[Enhancement] "
labels: ["enhancement"]

body:
  - type: checkboxes
    id: searched
    attributes:
      label: 已经搜索过 Issues,未发现重复问题*
      options:
        - label: 我已经搜索过 Issues,没有发现重复问题
    validations:
      required: true

  - type: textarea
    id: feature
    attributes:
      label: 功能描述
      description: 请详细描述你希望添加或改进的功能
      placeholder: 请描述你想要的功能
    validations:
      required: true

  - type: textarea
    id: extra
    attributes:
      label: 其他补充
      description: 如果你有额外信息,请在此填写
      placeholder: 可选

  - type: checkboxes
    id: pr
    attributes:
      label: 是否愿意提交 PR 实现当前 Issue
      options:
        - label: 我愿意尝试提交 PR
.github/ISSUE_TEMPLATE/03-generic.yml (vendored, new file, 30 lines)
@@ -0,0 +1,30 @@
name: 其他反馈
description: 其他类型反馈、建议或讨论
title: "[Question] "
labels: ["question"]

body:
  - type: checkboxes
    id: searched
    attributes:
      label: 已经搜索过 Issues,未发现重复问题*
      options:
        - label: 我已经搜索过 Issues,没有发现重复问题
    validations:
      required: true

  - type: textarea
    id: content
    attributes:
      label: 内容
      description: 请填写你的反馈、建议或讨论内容
      placeholder: 请描述你的问题或想法
    validations:
      required: true

  - type: textarea
    id: extra
    attributes:
      label: 其他补充
      description: 如果你有额外信息,请在此填写
      placeholder: 可选
.github/ISSUE_TEMPLATE/config.yml (vendored, new file, 1 line)
@@ -0,0 +1 @@
blank_issues_enabled: false
@@ -285,12 +285,12 @@ function App() {
          title="拖动调整宽度"
        />
      </Sider>
      <Content style={{ background: darkMode ? '#141414' : '#fff', overflow: 'hidden', display: 'flex', flexDirection: 'column' }}>
        <div style={{ flex: 1, overflow: 'hidden' }}>
          <TabManager />
        </div>
        {isLogPanelOpen && (
          <LogPanel
      <Content style={{ background: darkMode ? '#141414' : '#fff', overflow: 'hidden', display: 'flex', flexDirection: 'column' }}>
        <div style={{ flex: 1, minHeight: 0, overflow: 'hidden', display: 'flex', flexDirection: 'column' }}>
          <TabManager />
        </div>
        {isLogPanelOpen && (
          <LogPanel
            height={logPanelHeight}
            onClose={() => setIsLogPanelOpen(false)}
            onResizeStart={handleLogResizeStart}
@@ -343,4 +343,4 @@ function App() {
  );
}

export default App;
export default App;
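The App hunk above replaces `flex: 1, overflow: 'hidden'` on the wrapper around TabManager with a version that adds `minHeight: 0` and makes the wrapper itself a flex column. The reason is a flexbox detail: flex items default to min-height: auto, so a tall child stretches the layout instead of letting its content scroll. A minimal standalone sketch of the pattern (illustrative component, not the project's actual tree):

import React from 'react';

// A flex column that fills its parent; the inner pane gets minHeight: 0 so its
// own content can overflow and scroll instead of growing the column.
export const ScrollableFill: React.FC<{ children: React.ReactNode }> = ({ children }) => (
  <div style={{ display: 'flex', flexDirection: 'column', height: '100%' }}>
    <div style={{ flex: '1 1 auto', minHeight: 0, overflow: 'auto' }}>
      {children}
    </div>
  </div>
);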
@@ -264,8 +264,8 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
        {useSSH && (
          <div style={{ padding: '12px', background: '#f5f5f5', borderRadius: 6, marginTop: 12 }}>
            <div style={{ display: 'flex', gap: 16 }}>
              <Form.Item name="sshHost" label="SSH 主机" rules={[{ required: useSSH, message: '请输入SSH主机' }]} style={{ flex: 1 }}>
                <Input placeholder="ssh.example.com" />
              <Form.Item name="sshHost" label="SSH 主机 (域名或IP)" rules={[{ required: useSSH, message: '请输入SSH主机' }]} style={{ flex: 1 }}>
                <Input placeholder="例如: ssh.example.com 或 192.168.1.100" />
              </Form.Item>
              <Form.Item name="sshPort" label="端口" rules={[{ required: useSSH, message: '请输入SSH端口' }]} style={{ width: 100 }}>
                <InputNumber style={{ width: '100%' }} />
File diff suppressed because it is too large
@@ -1,14 +1,36 @@
import React, { useState, useEffect } from 'react';
import { Modal, Form, Select, Button, message, Steps, Transfer, Card, Alert, Divider, Typography } from 'antd';
import React, { useState, useEffect, useRef } from 'react';
import { Modal, Form, Select, Button, message, Steps, Transfer, Card, Alert, Divider, Typography, Progress, Checkbox, Table, Drawer, Tabs } from 'antd';
import { useStore } from '../store';
import { DBGetDatabases, DBGetTables, DataSync } from '../../wailsjs/go/app/App';
import { DBGetDatabases, DBGetTables, DataSync, DataSyncAnalyze, DataSyncPreview } from '../../wailsjs/go/app/App';
import { SavedConnection } from '../types';
import { connection } from '../../wailsjs/go/models';
import { EventsOn } from '../../wailsjs/runtime/runtime';

const { Title, Text } = Typography;
const { Step } = Steps;
const { Option } = Select;

type SyncLogEvent = { jobId: string; level?: string; message?: string; ts?: number };
type SyncProgressEvent = { jobId: string; percent?: number; current?: number; total?: number; table?: string; stage?: string };
type SyncLogItem = { level: string; message: string; ts?: number };
type TableDiffSummary = {
  table: string;
  pkColumn?: string;
  canSync?: boolean;
  inserts?: number;
  updates?: number;
  deletes?: number;
  same?: number;
  message?: string;
};
type TableOps = {
  insert: boolean;
  update: boolean;
  delete: boolean;
  selectedInsertPks?: string[];
  selectedUpdatePks?: string[];
  selectedDeletePks?: string[];
};

const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open, onClose }) => {
  const connections = useStore((state) => state.connections);
  const [currentStep, setCurrentStep] = useState(0);
@@ -27,8 +49,76 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
  const [allTables, setAllTables] = useState<string[]>([]);
  const [selectedTables, setSelectedTables] = useState<string[]>([]);

  // Options
  const [syncContent, setSyncContent] = useState<'data' | 'schema' | 'both'>('data');
  const [syncMode, setSyncMode] = useState<string>('insert_update');
  const [autoAddColumns, setAutoAddColumns] = useState<boolean>(true);
  const [showSameTables, setShowSameTables] = useState<boolean>(false);
  const [analyzing, setAnalyzing] = useState<boolean>(false);
  const [diffTables, setDiffTables] = useState<TableDiffSummary[]>([]);
  const [tableOptions, setTableOptions] = useState<Record<string, TableOps>>({});

  const [previewOpen, setPreviewOpen] = useState(false);
  const [previewTable, setPreviewTable] = useState<string>('');
  const [previewLoading, setPreviewLoading] = useState(false);
  const [previewData, setPreviewData] = useState<any>(null);

  // Step 3: Result
  const [syncResult, setSyncResult] = useState<any>(null);
  const [syncing, setSyncing] = useState(false);
  const [syncLogs, setSyncLogs] = useState<SyncLogItem[]>([]);
  const [syncProgress, setSyncProgress] = useState<{ percent: number; current: number; total: number; table: string; stage: string }>({
    percent: 0,
    current: 0,
    total: 0,
    table: '',
    stage: ''
  });
  const jobIdRef = useRef<string>('');
  const logBoxRef = useRef<HTMLDivElement>(null);
  const autoScrollRef = useRef(true);

  const normalizeConnConfig = (conn: SavedConnection, database?: string) => ({
    ...conn.config,
    port: Number((conn.config as any).port),
    password: conn.config.password || "",
    useSSH: conn.config.useSSH || false,
    ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" },
    database: typeof database === 'string' ? database : (conn.config.database || ""),
  });

  useEffect(() => {
    if (!open) return;

    const offLog = EventsOn('sync:log', (event: SyncLogEvent) => {
      if (!event || event.jobId !== jobIdRef.current) return;
      const msg = String(event.message || '').trim();
      if (!msg) return;
      setSyncLogs(prev => [...prev, { level: String(event.level || 'info'), message: msg, ts: event.ts }]);
    });

    const offProgress = EventsOn('sync:progress', (event: SyncProgressEvent) => {
      if (!event || event.jobId !== jobIdRef.current) return;
      setSyncProgress(prev => ({
        percent: typeof event.percent === 'number' ? event.percent : prev.percent,
        current: typeof event.current === 'number' ? event.current : prev.current,
        total: typeof event.total === 'number' ? event.total : prev.total,
        table: typeof event.table === 'string' ? event.table : prev.table,
        stage: typeof event.stage === 'string' ? event.stage : prev.stage,
      }));
    });

    return () => {
      offLog();
      offProgress();
    };
  }, [open]);

  useEffect(() => {
    if (!logBoxRef.current) return;
    if (!autoScrollRef.current) return;
    logBoxRef.current.scrollTop = logBoxRef.current.scrollHeight;
  }, [syncLogs]);

  useEffect(() => {
    if (open) {
@@ -38,7 +128,23 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
      setSourceDb('');
      setTargetDb('');
      setSelectedTables([]);
      setSyncContent('data');
      setSyncMode('insert_update');
      setAutoAddColumns(true);
      setShowSameTables(false);
      setAnalyzing(false);
      setDiffTables([]);
      setTableOptions({});
      setPreviewOpen(false);
      setPreviewTable('');
      setPreviewLoading(false);
      setPreviewData(null);
      setSyncResult(null);
      setSyncing(false);
      setSyncLogs([]);
      setSyncProgress({ percent: 0, current: 0, total: 0, table: '', stage: '' });
      jobIdRef.current = '';
      autoScrollRef.current = true;
    }
  }, [open]);
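The first effect above wires the modal to backend push events: it subscribes to 'sync:log' and 'sync:progress', drops any event whose jobId does not match the job this modal started (tracked in jobIdRef), and unsubscribes on cleanup. A condensed sketch of that subscription pattern, assuming the Wails v2 runtime in which EventsOn returns an unsubscribe function (the code above relies on exactly that), with an illustrative hook name:

import { useEffect, useRef } from 'react';
import { EventsOn } from '../../wailsjs/runtime/runtime';

// Subscribe while `open` is true, ignore events from other jobs, detach on close.
export function useSyncLog(open: boolean, onLine: (line: string) => void) {
  const jobIdRef = useRef('');   // set to something like `sync-${Date.now()}` when a job starts
  useEffect(() => {
    if (!open) return;
    const off = EventsOn('sync:log', (event: { jobId: string; message?: string }) => {
      if (!event || event.jobId !== jobIdRef.current) return;   // stale or unrelated emitter
      const msg = String(event.message || '').trim();
      if (msg) onLine(msg);
    });
    return () => off();           // remove the listener when the modal closes or re-opens
  }, [open, onLine]);
  return jobIdRef;
}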
@@ -49,7 +155,7 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
    if (conn) {
      setLoading(true);
      try {
        const res = await DBGetDatabases(conn.config as any);
        const res = await DBGetDatabases(normalizeConnConfig(conn) as any);
        if (res.success) {
          setSourceDbs((res.data as any[]).map((r: any) => r.Database || r.database || r.username));
        }
@@ -65,7 +171,7 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
    if (conn) {
      setLoading(true);
      try {
        const res = await DBGetDatabases(conn.config as any);
        const res = await DBGetDatabases(normalizeConnConfig(conn) as any);
        if (res.success) {
          setTargetDbs((res.data as any[]).map((r: any) => r.Database || r.database || r.username));
        }
@@ -83,7 +189,7 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
      try {
        const conn = connections.find(c => c.id === sourceConnId);
        if (conn) {
          const config = { ...conn.config, database: sourceDb };
          const config = normalizeConnConfig(conn, sourceDb);
          const res = await DBGetTables(config as any, sourceDb);
          if (res.success) {
            // DBGetTables returns [{Table: "name"}, ...]
@@ -98,36 +204,221 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
setLoading(false);
|
||||
};
|
||||
|
||||
const runSync = async () => {
|
||||
const updateTableOption = (table: string, key: keyof TableOps, value: any) => {
|
||||
setTableOptions(prev => ({
|
||||
...prev,
|
||||
[table]: { ...(prev[table] || { insert: true, update: true, delete: false }), [key]: value }
|
||||
}));
|
||||
};
|
||||
|
||||
const analyzeDiff = async () => {
|
||||
if (selectedTables.length === 0) return;
|
||||
if (!sourceConnId || !targetConnId) return message.error("Select connections first");
|
||||
if (!sourceDb || !targetDb) return message.error("Select databases first");
|
||||
|
||||
setLoading(true);
|
||||
setAnalyzing(true);
|
||||
setDiffTables([]);
|
||||
setTableOptions({});
|
||||
setSyncLogs([]);
|
||||
|
||||
const sConn = connections.find(c => c.id === sourceConnId)!;
|
||||
const tConn = connections.find(c => c.id === targetConnId)!;
|
||||
const jobId = `analyze-${Date.now()}-${Math.random().toString(16).slice(2, 8)}`;
|
||||
jobIdRef.current = jobId;
|
||||
autoScrollRef.current = true;
|
||||
setSyncProgress({ percent: 0, current: 0, total: selectedTables.length, table: '', stage: '差异分析' });
|
||||
|
||||
const config = {
|
||||
sourceConfig: normalizeConnConfig(sConn, sourceDb),
|
||||
targetConfig: normalizeConnConfig(tConn, targetDb),
|
||||
tables: selectedTables,
|
||||
content: syncContent,
|
||||
mode: "insert_update",
|
||||
autoAddColumns,
|
||||
jobId,
|
||||
};
|
||||
|
||||
try {
|
||||
const res = await DataSyncAnalyze(config as any);
|
||||
if (res.success) {
|
||||
const tables = ((res.data as any)?.tables || []) as TableDiffSummary[];
|
||||
setDiffTables(tables);
|
||||
const init: Record<string, TableOps> = {};
|
||||
tables.forEach(t => {
|
||||
const can = !!t.canSync;
|
||||
init[t.table] = {
|
||||
insert: can,
|
||||
update: can,
|
||||
delete: false,
|
||||
selectedInsertPks: [],
|
||||
selectedUpdatePks: [],
|
||||
selectedDeletePks: [],
|
||||
};
|
||||
});
|
||||
setTableOptions(init);
|
||||
message.success("差异分析完成");
|
||||
} else {
|
||||
message.error(res.message || "差异分析失败");
|
||||
}
|
||||
} catch (e: any) {
|
||||
message.error("差异分析失败: " + (e?.message || ""));
|
||||
}
|
||||
|
||||
setLoading(false);
|
||||
setAnalyzing(false);
|
||||
};
|
||||
|
||||
const openPreview = async (table: string) => {
|
||||
if (!table) return;
|
||||
const sConn = connections.find(c => c.id === sourceConnId)!;
|
||||
const tConn = connections.find(c => c.id === targetConnId)!;
|
||||
|
||||
setPreviewOpen(true);
|
||||
setPreviewTable(table);
|
||||
setPreviewLoading(true);
|
||||
setPreviewData(null);
|
||||
|
||||
const config = {
|
||||
sourceConfig: normalizeConnConfig(sConn, sourceDb),
|
||||
targetConfig: normalizeConnConfig(tConn, targetDb),
|
||||
tables: selectedTables,
|
||||
content: "data",
|
||||
mode: "insert_update",
|
||||
autoAddColumns,
|
||||
};
|
||||
|
||||
try {
|
||||
const res = await DataSyncPreview(config as any, table, 200);
|
||||
if (res.success) {
|
||||
setPreviewData(res.data);
|
||||
} else {
|
||||
message.error(res.message || "加载差异预览失败");
|
||||
}
|
||||
} catch (e: any) {
|
||||
message.error("加载差异预览失败: " + (e?.message || ""));
|
||||
}
|
||||
|
||||
setPreviewLoading(false);
|
||||
};
|
||||
|
||||
const runSync = async () => {
|
||||
if (syncContent !== 'schema' && diffTables.length === 0) {
|
||||
message.error("请先对比差异,再开始同步");
|
||||
return;
|
||||
}
|
||||
if (syncContent !== 'schema' && syncMode === 'full_overwrite') {
|
||||
const ok = await new Promise<boolean>((resolve) => {
|
||||
Modal.confirm({
|
||||
title: '确认全量覆盖',
|
||||
content: '全量覆盖会清空目标表数据后再插入,请确认已备份目标库。',
|
||||
okText: '继续执行',
|
||||
cancelText: '取消',
|
||||
onOk: () => resolve(true),
|
||||
onCancel: () => resolve(false),
|
||||
});
|
||||
});
|
||||
if (!ok) return;
|
||||
}
|
||||
|
||||
setLoading(true);
|
||||
setSyncing(true);
|
||||
setCurrentStep(2);
|
||||
setSyncResult(null);
|
||||
setSyncLogs([]);
|
||||
|
||||
const sConn = connections.find(c => c.id === sourceConnId)!;
|
||||
const tConn = connections.find(c => c.id === targetConnId)!;
|
||||
|
||||
const jobId = `sync-${Date.now()}-${Math.random().toString(16).slice(2, 8)}`;
|
||||
jobIdRef.current = jobId;
|
||||
autoScrollRef.current = true;
|
||||
setSyncProgress({
|
||||
percent: 0,
|
||||
current: 0,
|
||||
total: selectedTables.length,
|
||||
table: '',
|
||||
stage: '准备开始',
|
||||
});
|
||||
|
||||
const config = {
|
||||
sourceConfig: { ...sConn.config, database: sourceDb },
|
||||
targetConfig: { ...tConn.config, database: targetDb },
|
||||
sourceConfig: {
|
||||
...sConn.config,
|
||||
port: Number((sConn.config as any).port),
|
||||
password: sConn.config.password || "",
|
||||
useSSH: sConn.config.useSSH || false,
|
||||
ssh: sConn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" },
|
||||
database: sourceDb,
|
||||
},
|
||||
targetConfig: {
|
||||
...tConn.config,
|
||||
port: Number((tConn.config as any).port),
|
||||
password: tConn.config.password || "",
|
||||
useSSH: tConn.config.useSSH || false,
|
||||
ssh: tConn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" },
|
||||
database: targetDb,
|
||||
},
|
||||
tables: selectedTables,
|
||||
mode: "insert_update"
|
||||
content: syncContent,
|
||||
mode: syncMode,
|
||||
autoAddColumns,
|
||||
tableOptions,
|
||||
jobId,
|
||||
};
|
||||
|
||||
try {
|
||||
const res = await DataSync(config as any);
|
||||
setSyncResult(res);
|
||||
setCurrentStep(2);
|
||||
if (Array.isArray(res?.logs) && res.logs.length > 0) {
|
||||
setSyncLogs(prev => {
|
||||
if (prev.length > 0) return prev;
|
||||
return (res.logs as string[]).map((log) => {
|
||||
const msg = String(log || '').trim();
|
||||
if (msg.includes('致命错误') || msg.includes('失败')) return { level: 'error', message: msg };
|
||||
if (msg.includes('跳过') || msg.includes('警告')) return { level: 'warn', message: msg };
|
||||
return { level: 'info', message: msg };
|
||||
});
|
||||
});
|
||||
}
|
||||
} catch (e) {
|
||||
message.error("Sync execution failed");
|
||||
setSyncResult({ success: false, message: "同步执行失败", logs: [] });
|
||||
}
|
||||
setLoading(false);
|
||||
setSyncing(false);
|
||||
};
|
||||
|
||||
const renderSyncLogItem = (item: SyncLogItem) => {
|
||||
const level = String(item.level || 'info').toLowerCase();
|
||||
const color = level === 'error' ? '#ff4d4f' : (level === 'warn' ? '#faad14' : '#595959');
|
||||
const label = level === 'error' ? '错误' : (level === 'warn' ? '警告' : '信息');
|
||||
const timeText = typeof item.ts === 'number' ? new Date(item.ts).toLocaleTimeString('zh-CN', { hour12: false }) : '';
|
||||
return (
|
||||
<div style={{ display: 'flex', gap: 8, alignItems: 'flex-start' }}>
|
||||
<span style={{ color, flex: '0 0 auto' }}>● {label}</span>
|
||||
{timeText && <span style={{ color: '#8c8c8c', flex: '0 0 auto' }}>{timeText}</span>}
|
||||
<span style={{ whiteSpace: 'pre-wrap', wordBreak: 'break-word' }}>{item.message}</span>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<Modal
|
||||
title="数据同步"
|
||||
open={open}
|
||||
onCancel={onClose}
|
||||
width={800}
|
||||
footer={null}
|
||||
destroyOnHidden
|
||||
title="数据同步"
|
||||
open={open}
|
||||
onCancel={() => {
|
||||
if (syncing) {
|
||||
message.warning("同步执行中,暂不支持关闭");
|
||||
return;
|
||||
}
|
||||
onClose();
|
||||
}}
|
||||
width={800}
|
||||
footer={null}
|
||||
destroyOnHidden
|
||||
closable={!syncing}
|
||||
maskClosable={!syncing}
|
||||
>
|
||||
<Steps current={currentStep} style={{ marginBottom: 24 }}>
|
||||
<Step title="配置源与目标" />
|
||||
@@ -137,34 +428,67 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
|
||||
{/* STEP 1: CONFIG */}
|
||||
{currentStep === 0 && (
|
||||
<div style={{ display: 'flex', gap: 24, justifyContent: 'center' }}>
|
||||
<Card title="源数据库" style={{ width: 350 }}>
|
||||
<div>
|
||||
<div style={{ display: 'flex', gap: 24, justifyContent: 'center' }}>
|
||||
<Card title="源数据库" style={{ width: 350 }}>
|
||||
<Form layout="vertical">
|
||||
<Form.Item label="连接">
|
||||
<Select value={sourceConnId} onChange={handleSourceConnChange}>
|
||||
{connections.map(c => <Option key={c.id} value={c.id}>{c.name} ({c.config.type})</Option>)}
|
||||
</Select>
|
||||
</Form.Item>
|
||||
<Form.Item label="数据库">
|
||||
<Select value={sourceDb} onChange={setSourceDb} showSearch>
|
||||
{sourceDbs.map(d => <Option key={d} value={d}>{d}</Option>)}
|
||||
</Select>
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</Card>
|
||||
<div style={{ display: 'flex', alignItems: 'center' }}>至</div>
|
||||
<Card title="目标数据库" style={{ width: 350 }}>
|
||||
<Form layout="vertical">
|
||||
<Form.Item label="连接">
|
||||
<Select value={targetConnId} onChange={handleTargetConnChange}>
|
||||
{connections.map(c => <Option key={c.id} value={c.id}>{c.name} ({c.config.type})</Option>)}
|
||||
</Select>
|
||||
</Form.Item>
|
||||
<Form.Item label="数据库">
|
||||
<Select value={targetDb} onChange={setTargetDb} showSearch>
|
||||
{targetDbs.map(d => <Option key={d} value={d}>{d}</Option>)}
|
||||
</Select>
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</Card>
|
||||
</div>
|
||||
|
||||
<Card title="同步选项" style={{ marginTop: 16 }}>
|
||||
<Form layout="vertical">
|
||||
<Form.Item label="连接">
|
||||
<Select value={sourceConnId} onChange={handleSourceConnChange}>
|
||||
{connections.map(c => <Option key={c.id} value={c.id}>{c.name} ({c.config.type})</Option>)}
|
||||
<Form.Item label="同步内容">
|
||||
<Select value={syncContent} onChange={setSyncContent}>
|
||||
<Option value="data">仅同步数据</Option>
|
||||
<Option value="schema">仅同步结构</Option>
|
||||
<Option value="both">同步结构 + 数据</Option>
|
||||
</Select>
|
||||
</Form.Item>
|
||||
<Form.Item label="数据库">
|
||||
<Select value={sourceDb} onChange={setSourceDb} showSearch>
|
||||
{sourceDbs.map(d => <Option key={d} value={d}>{d}</Option>)}
|
||||
<Form.Item label="同步模式">
|
||||
<Select value={syncMode} onChange={setSyncMode} disabled={syncContent === 'schema'}>
|
||||
<Option value="insert_update">增量同步(对比差异,按插入/更新/删除勾选执行)</Option>
|
||||
<Option value="insert_only">仅插入(不对比目标;无主键表将跳过)</Option>
|
||||
<Option value="full_overwrite">全量覆盖(清空目标表后插入)</Option>
|
||||
</Select>
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</Card>
|
||||
<div style={{ display: 'flex', alignItems: 'center' }}>至</div>
|
||||
<Card title="目标数据库" style={{ width: 350 }}>
|
||||
<Form layout="vertical">
|
||||
<Form.Item label="连接">
|
||||
<Select value={targetConnId} onChange={handleTargetConnChange}>
|
||||
{connections.map(c => <Option key={c.id} value={c.id}>{c.name} ({c.config.type})</Option>)}
|
||||
</Select>
|
||||
</Form.Item>
|
||||
<Form.Item label="数据库">
|
||||
<Select value={targetDb} onChange={setTargetDb} showSearch>
|
||||
{targetDbs.map(d => <Option key={d} value={d}>{d}</Option>)}
|
||||
</Select>
|
||||
<Form.Item>
|
||||
<Checkbox checked={autoAddColumns} onChange={(e) => setAutoAddColumns(e.target.checked)}>
|
||||
自动补齐目标表缺失字段(仅 MySQL 目标)
|
||||
</Checkbox>
|
||||
</Form.Item>
|
||||
{syncContent !== 'schema' && syncMode === 'full_overwrite' && (
|
||||
<Alert
|
||||
type="warning"
|
||||
showIcon
|
||||
message="全量覆盖会清空目标表数据,请谨慎使用。"
|
||||
/>
|
||||
)}
|
||||
</Form>
|
||||
</Card>
|
||||
</div>
|
||||
@@ -172,32 +496,155 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
|
||||
{/* STEP 2: TABLES */}
|
||||
{currentStep === 1 && (
|
||||
<div style={{ height: 400 }}>
|
||||
<Text type="secondary">请选择需要同步的表:</Text>
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 12 }}>
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center' }}>
|
||||
<Text type="secondary">请选择需要同步的表:</Text>
|
||||
<Checkbox checked={showSameTables} onChange={(e) => setShowSameTables(e.target.checked)}>
|
||||
显示相同表
|
||||
</Checkbox>
|
||||
</div>
|
||||
<Transfer
|
||||
dataSource={allTables.map(t => ({ key: t, title: t }))}
|
||||
titles={['源表', '已选表']}
|
||||
targetKeys={selectedTables}
|
||||
onChange={(keys) => setSelectedTables(keys as string[])}
|
||||
render={item => item.title}
|
||||
listStyle={{ width: 350, height: 350, marginTop: 12 }}
|
||||
listStyle={{ width: 350, height: 280, marginTop: 0 }}
|
||||
locale={{ itemUnit: '项', itemsUnit: '项', searchPlaceholder: '搜索表', notFoundContent: '暂无数据' }}
|
||||
/>
|
||||
|
||||
{diffTables.length > 0 && (
|
||||
<div>
|
||||
<Divider orientation="left">对比结果</Divider>
|
||||
<Table
|
||||
size="small"
|
||||
pagination={false}
|
||||
rowKey={(r: any) => r.table}
|
||||
dataSource={diffTables.filter(t => {
|
||||
const ins = Number(t.inserts || 0);
|
||||
const upd = Number(t.updates || 0);
|
||||
const del = Number(t.deletes || 0);
|
||||
const same = Number(t.same || 0);
|
||||
const msg = String(t.message || '').trim();
|
||||
const can = !!t.canSync;
|
||||
if (showSameTables) return true;
|
||||
if (!can) return true;
|
||||
if (msg) return true;
|
||||
return ins > 0 || upd > 0 || del > 0 || same === 0;
|
||||
})}
|
||||
columns={[
|
||||
{ title: '表名', dataIndex: 'table', key: 'table', ellipsis: true },
|
||||
{
|
||||
title: '插入',
|
||||
key: 'inserts',
|
||||
width: 90,
|
||||
render: (_: any, r: any) => {
|
||||
const ops = tableOptions[r.table] || { insert: true, update: true, delete: false };
|
||||
const disabled = !r.canSync || analyzing || Number(r.inserts || 0) === 0;
|
||||
return (
|
||||
<Checkbox
|
||||
checked={!!ops.insert}
|
||||
disabled={disabled}
|
||||
onChange={(e) => updateTableOption(r.table, 'insert', e.target.checked)}
|
||||
>
|
||||
{Number(r.inserts || 0)}
|
||||
</Checkbox>
|
||||
);
|
||||
}
|
||||
},
|
||||
{
|
||||
title: '更新',
|
||||
key: 'updates',
|
||||
width: 90,
|
||||
render: (_: any, r: any) => {
|
||||
const ops = tableOptions[r.table] || { insert: true, update: true, delete: false };
|
||||
const disabled = !r.canSync || analyzing || Number(r.updates || 0) === 0;
|
||||
return (
|
||||
<Checkbox
|
||||
checked={!!ops.update}
|
||||
disabled={disabled}
|
||||
onChange={(e) => updateTableOption(r.table, 'update', e.target.checked)}
|
||||
>
|
||||
{Number(r.updates || 0)}
|
||||
</Checkbox>
|
||||
);
|
||||
}
|
||||
},
|
||||
{
|
||||
title: '删除',
|
||||
key: 'deletes',
|
||||
width: 90,
|
||||
render: (_: any, r: any) => {
|
||||
const ops = tableOptions[r.table] || { insert: true, update: true, delete: false };
|
||||
const disabled = !r.canSync || analyzing || Number(r.deletes || 0) === 0;
|
||||
return (
|
||||
<Checkbox
|
||||
checked={!!ops.delete}
|
||||
disabled={disabled}
|
||||
onChange={(e) => updateTableOption(r.table, 'delete', e.target.checked)}
|
||||
>
|
||||
{Number(r.deletes || 0)}
|
||||
</Checkbox>
|
||||
);
|
||||
}
|
||||
},
|
||||
{ title: '相同', dataIndex: 'same', key: 'same', width: 70, render: (v: any) => Number(v || 0) },
|
||||
{ title: '消息', dataIndex: 'message', key: 'message', ellipsis: true, render: (v: any) => (v ? String(v) : '') },
|
||||
{
|
||||
title: '预览',
|
||||
key: 'preview',
|
||||
width: 80,
|
||||
render: (_: any, r: any) => {
|
||||
const can = !!r.canSync;
|
||||
const hasDiff = Number(r.inserts || 0) + Number(r.updates || 0) + Number(r.deletes || 0) > 0;
|
||||
return (
|
||||
<Button size="small" disabled={!can || !hasDiff || analyzing} onClick={() => openPreview(r.table)}>
|
||||
查看
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
}
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* STEP 3: RESULT */}
|
||||
{currentStep === 2 && syncResult && (
|
||||
{currentStep === 2 && (
|
||||
<div>
|
||||
<Alert
|
||||
message={syncResult.success ? "同步完成" : "同步失败"}
|
||||
description={syncResult.message || `成功同步 ${syncResult.tablesSynced} 张表. 插入: ${syncResult.rowsInserted}, 更新: ${syncResult.rowsUpdated}`}
|
||||
type={syncResult.success ? "success" : "error"}
|
||||
showIcon
|
||||
<Alert
|
||||
message={syncing ? "正在同步" : (syncResult?.success ? "同步完成" : "同步失败")}
|
||||
description={
|
||||
syncing
|
||||
? `当前阶段:${syncProgress.stage || '执行中'}${syncProgress.table ? `,表:${syncProgress.table}` : ''}`
|
||||
: (syncResult?.message || `成功同步 ${syncResult?.tablesSynced || 0} 张表. 插入: ${syncResult?.rowsInserted || 0}, 更新: ${syncResult?.rowsUpdated || 0}`)
|
||||
}
|
||||
type={syncing ? "info" : (syncResult?.success ? "success" : "error")}
|
||||
showIcon
|
||||
/>
|
||||
|
||||
<div style={{ marginTop: 12 }}>
|
||||
<Progress
|
||||
percent={syncProgress.percent}
|
||||
status={syncing ? "active" : (syncResult?.success ? "success" : "exception")}
|
||||
format={() => `${syncProgress.current}/${syncProgress.total}`}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<Divider orientation="left">日志</Divider>
|
||||
<div style={{ background: '#f5f5f5', padding: 12, height: 300, overflowY: 'auto', fontFamily: 'monospace' }}>
|
||||
{syncResult.logs.map((log: string, i: number) => <div key={i}>{log}</div>)}
|
||||
<div
|
||||
ref={logBoxRef}
|
||||
onScroll={() => {
|
||||
const el = logBoxRef.current;
|
||||
if (!el) return;
|
||||
const nearBottom = el.scrollHeight - el.scrollTop - el.clientHeight < 40;
|
||||
autoScrollRef.current = nearBottom;
|
||||
}}
|
||||
style={{ background: '#f5f5f5', padding: 12, height: 300, overflowY: 'auto', fontFamily: 'monospace' }}
|
||||
>
|
||||
{syncLogs.map((item, i: number) => <div key={i}>{renderSyncLogItem(item)}</div>)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
@@ -206,20 +653,154 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
{currentStep === 0 && (
|
||||
<Button type="primary" onClick={nextToTables} loading={loading}>下一步</Button>
|
||||
)}
|
||||
{currentStep === 1 && (
|
||||
<>
|
||||
<Button onClick={() => setCurrentStep(0)} style={{ marginRight: 8 }}>上一步</Button>
|
||||
<Button type="primary" onClick={runSync} loading={loading} disabled={selectedTables.length === 0}>开始同步</Button>
|
||||
{currentStep === 1 && (
|
||||
<>
|
||||
<Button onClick={() => setCurrentStep(0)} style={{ marginRight: 8 }}>上一步</Button>
|
||||
<Button onClick={analyzeDiff} loading={loading} disabled={syncContent === 'schema' || selectedTables.length === 0 || analyzing} style={{ marginRight: 8 }}>
|
||||
对比差异
|
||||
</Button>
|
||||
<Button
|
||||
type="primary"
|
||||
onClick={runSync}
|
||||
loading={loading}
|
||||
disabled={selectedTables.length === 0 || (syncContent !== 'schema' && diffTables.length === 0)}
|
||||
>
|
||||
开始同步
|
||||
</Button>
|
||||
</>
|
||||
)}
|
||||
{currentStep === 2 && (
|
||||
<>
|
||||
<Button onClick={() => setCurrentStep(1)} style={{ marginRight: 8 }}>继续同步</Button>
|
||||
<Button type="primary" onClick={onClose}>关闭</Button>
|
||||
<Button disabled={syncing} onClick={() => setCurrentStep(1)} style={{ marginRight: 8 }}>继续同步</Button>
|
||||
<Button type="primary" disabled={syncing} onClick={onClose}>关闭</Button>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</Modal>
|
||||
<Drawer
|
||||
title={`差异预览:${previewTable}`}
|
||||
open={previewOpen}
|
||||
onClose={() => { setPreviewOpen(false); setPreviewTable(''); setPreviewData(null); }}
|
||||
width={900}
|
||||
>
|
||||
{previewLoading && <Alert type="info" showIcon message="正在加载差异预览..." />}
|
||||
{!previewLoading && previewData && (
|
||||
<div>
|
||||
<Alert
|
||||
type="info"
|
||||
showIcon
|
||||
message={`插入 ${previewData.totalInserts || 0},更新 ${previewData.totalUpdates || 0},删除 ${previewData.totalDeletes || 0}(预览最多展示 200 条/类型)`}
|
||||
/>
|
||||
<Divider />
|
||||
<Tabs
|
||||
items={[
|
||||
{
|
||||
key: 'insert',
|
||||
label: `插入(${previewData.totalInserts || 0})`,
|
||||
children: (
|
||||
<div>
|
||||
<Text type="secondary">未勾选任何行表示“同步全部插入差异”;如不想执行插入请在对比结果中取消勾选“插入”。</Text>
|
||||
<Table
|
||||
size="small"
|
||||
style={{ marginTop: 8 }}
|
||||
rowKey={(r: any) => r.pk}
|
||||
dataSource={(previewData.inserts || []).map((r: any) => ({ ...r, key: r.pk }))}
|
||||
pagination={false}
|
||||
rowSelection={{
|
||||
selectedRowKeys: (tableOptions[previewTable]?.selectedInsertPks || []) as any,
|
||||
onChange: (keys) => updateTableOption(previewTable, 'selectedInsertPks', keys as string[]),
|
||||
getCheckboxProps: () => ({ disabled: !tableOptions[previewTable]?.insert }),
|
||||
}}
|
||||
columns={[
|
||||
{ title: previewData.pkColumn || '主键', dataIndex: 'pk', key: 'pk', width: 200, ellipsis: true },
|
||||
{ title: '数据', dataIndex: 'row', key: 'row', render: (v: any) => <pre style={{ margin: 0, maxHeight: 140, overflow: 'auto' }}>{JSON.stringify(v, null, 2)}</pre> }
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
},
|
||||
{
|
||||
key: 'update',
|
||||
label: `更新(${previewData.totalUpdates || 0})`,
|
||||
children: (
|
||||
<div>
|
||||
<Text type="secondary">未勾选任何行表示“同步全部更新差异”;如不想执行更新请在对比结果中取消勾选“更新”。</Text>
|
||||
<Table
|
||||
size="small"
|
||||
style={{ marginTop: 8 }}
|
||||
rowKey={(r: any) => r.pk}
|
||||
dataSource={(previewData.updates || []).map((r: any) => ({ ...r, key: r.pk }))}
|
||||
pagination={false}
|
||||
rowSelection={{
|
||||
selectedRowKeys: (tableOptions[previewTable]?.selectedUpdatePks || []) as any,
|
||||
onChange: (keys) => updateTableOption(previewTable, 'selectedUpdatePks', keys as string[]),
|
||||
getCheckboxProps: () => ({ disabled: !tableOptions[previewTable]?.update }),
|
||||
}}
|
||||
columns={[
|
||||
{ title: previewData.pkColumn || '主键', dataIndex: 'pk', key: 'pk', width: 200, ellipsis: true },
|
||||
{ title: '变更字段', dataIndex: 'changedColumns', key: 'changedColumns', render: (v: any) => Array.isArray(v) ? v.join(', ') : '' },
|
||||
{
|
||||
title: '详情',
|
||||
key: 'detail',
|
||||
width: 80,
|
||||
render: (_: any, r: any) => (
|
||||
<Button size="small" onClick={() => {
|
||||
Modal.info({
|
||||
title: `更新详情:${previewTable} / ${r.pk}`,
|
||||
width: 900,
|
||||
content: (
|
||||
<div style={{ display: 'flex', gap: 12 }}>
|
||||
<div style={{ flex: 1 }}>
|
||||
<Title level={5}>源</Title>
|
||||
<pre style={{ maxHeight: 360, overflow: 'auto', background: '#f5f5f5', padding: 8 }}>{JSON.stringify(r.source, null, 2)}</pre>
|
||||
</div>
|
||||
<div style={{ flex: 1 }}>
|
||||
<Title level={5}>目标</Title>
|
||||
<pre style={{ maxHeight: 360, overflow: 'auto', background: '#f5f5f5', padding: 8 }}>{JSON.stringify(r.target, null, 2)}</pre>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
});
|
||||
}}>查看</Button>
|
||||
)
|
||||
}
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
},
|
||||
{
|
||||
key: 'delete',
|
||||
label: `删除(${previewData.totalDeletes || 0})`,
|
||||
children: (
|
||||
<div>
|
||||
<Alert type="warning" showIcon message="删除默认不勾选。请确认业务允许后再开启删除操作。" />
|
||||
<Text type="secondary">未勾选任何行表示“同步全部删除差异”;如不想执行删除请在对比结果中取消勾选“删除”。</Text>
|
||||
<Table
|
||||
size="small"
|
||||
style={{ marginTop: 8 }}
|
||||
rowKey={(r: any) => r.pk}
|
||||
dataSource={(previewData.deletes || []).map((r: any) => ({ ...r, key: r.pk }))}
|
||||
pagination={false}
|
||||
rowSelection={{
|
||||
selectedRowKeys: (tableOptions[previewTable]?.selectedDeletePks || []) as any,
|
||||
onChange: (keys) => updateTableOption(previewTable, 'selectedDeletePks', keys as string[]),
|
||||
getCheckboxProps: () => ({ disabled: !tableOptions[previewTable]?.delete }),
|
||||
}}
|
||||
columns={[
|
||||
{ title: previewData.pkColumn || '主键', dataIndex: 'pk', key: 'pk', width: 200, ellipsis: true },
|
||||
{ title: '数据', dataIndex: 'row', key: 'row', render: (v: any) => <pre style={{ margin: 0, maxHeight: 140, overflow: 'auto' }}>{JSON.stringify(v, null, 2)}</pre> }
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</Drawer>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
import React, { useEffect, useState, useCallback } from 'react';
import React, { useEffect, useState, useCallback, useRef } from 'react';
import { message } from 'antd';
import { TabData, ColumnDefinition } from '../types';
import { useStore } from '../store';
import { DBQuery, DBGetColumns } from '../../wailsjs/go/app/App';
import DataGrid from './DataGrid';
import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
import { buildWhereSQL, quoteIdentPart, quoteQualifiedIdent } from '../utils/sql';

const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
  const [data, setData] = useState<any[]>([]);
@@ -11,11 +12,17 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
  const [pkColumns, setPkColumns] = useState<string[]>([]);
  const [loading, setLoading] = useState(false);
  const { connections, addSqlLog } = useStore();
  const fetchSeqRef = useRef(0);
  const countSeqRef = useRef(0);
  const countKeyRef = useRef<string>('');
  const pkSeqRef = useRef(0);
  const pkKeyRef = useRef<string>('');

  const [pagination, setPagination] = useState({
    current: 1,
    pageSize: 100,
    total: 0
    total: 0,
    totalKnown: false
  });

  const [sortInfo, setSortInfo] = useState<{ columnKey: string, order: string } | null>(null);
@@ -23,12 +30,20 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
  const [showFilter, setShowFilter] = useState(false);
  const [filterConditions, setFilterConditions] = useState<any[]>([]);

  useEffect(() => {
    setPkColumns([]);
    pkKeyRef.current = '';
    countKeyRef.current = '';
    setPagination(prev => ({ ...prev, current: 1, total: 0, totalKnown: false }));
  }, [tab.connectionId, tab.dbName, tab.tableName]);

  const fetchData = useCallback(async (page = pagination.current, size = pagination.pageSize) => {
    const seq = ++fetchSeqRef.current;
    setLoading(true);
    const conn = connections.find(c => c.id === tab.connectionId);
    if (!conn) {
      message.error("Connection not found");
      setLoading(false);
      if (fetchSeqRef.current === seq) setLoading(false);
      return;
    }
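The refs added above (fetchSeqRef, countSeqRef, pkSeqRef) are monotonically increasing request tokens: each async call records its own sequence number and later checks that it is still the newest before touching state, so a slow response for a table the user has already navigated away from cannot overwrite fresher data. A minimal standalone sketch of the idea (illustrative hook name, not the component's actual code):

import { useCallback, useRef, useState } from 'react';

export function useLatestOnly<T>(load: () => Promise<T>) {
  const seqRef = useRef(0);
  const [data, setData] = useState<T | null>(null);

  const run = useCallback(async () => {
    const seq = ++seqRef.current;          // token owned by this particular call
    const result = await load();
    if (seqRef.current !== seq) return;    // a newer call started meanwhile: drop this result
    setData(result);
  }, [load]);

  return { data, run };
}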
@@ -41,68 +56,31 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
      ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
    };

    const quoteIdentPart = (ident: string) => {
      if (!ident) return ident;
      if (config.type === 'mysql') return `\`${ident.replace(/`/g, '``')}\``;
      return `"${ident.replace(/"/g, '""')}"`;
    };
    const quoteQualifiedIdent = (ident: string) => {
      const raw = (ident || '').trim();
      if (!raw) return raw;
      const parts = raw.split('.').filter(Boolean);
      if (parts.length <= 1) return quoteIdentPart(raw);
      return parts.map(quoteIdentPart).join('.');
    };
    const escapeLiteral = (val: string) => val.replace(/'/g, "''");
    const dbType = config.type || '';

    const dbName = tab.dbName || '';
    const tableName = tab.tableName || '';

    const whereParts: string[] = [];
    filterConditions.forEach(cond => {
      if (cond.column && cond.value) {
        if (cond.op === 'LIKE') {
          whereParts.push(`${quoteIdentPart(cond.column)} LIKE '%${escapeLiteral(cond.value)}%'`);
        } else {
          whereParts.push(`${quoteIdentPart(cond.column)} ${cond.op} '${escapeLiteral(cond.value)}'`);
        }
      }
    });
    const whereSQL = whereParts.length > 0 ? `WHERE ${whereParts.join(' AND ')}` : "";
    const whereSQL = buildWhereSQL(dbType, filterConditions);

    const countSql = `SELECT COUNT(*) as total FROM ${quoteQualifiedIdent(tableName)} ${whereSQL}`;
    const countSql = `SELECT COUNT(*) as total FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;

    let sql = `SELECT * FROM ${quoteQualifiedIdent(tableName)} ${whereSQL}`;
    let sql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
    if (sortInfo && sortInfo.order) {
      sql += ` ORDER BY ${quoteIdentPart(sortInfo.columnKey)} ${sortInfo.order === 'ascend' ? 'ASC' : 'DESC'}`;
      sql += ` ORDER BY ${quoteIdentPart(dbType, sortInfo.columnKey)} ${sortInfo.order === 'ascend' ? 'ASC' : 'DESC'}`;
    }
    const offset = (page - 1) * size;
    sql += ` LIMIT ${size} OFFSET ${offset}`;
    // 大表性能:打开表不阻塞在 COUNT(*),先通过多取 1 条判断是否还有下一页;总数在后台统计并异步回填。
    sql += ` LIMIT ${size + 1} OFFSET ${offset}`;
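The hunk above is the large-table optimization the in-code comment describes: opening a table no longer blocks on COUNT(*). The page query asks for size + 1 rows, so the presence of an extra row proves another page exists, and the exact total is backfilled asynchronously by a background COUNT(*). A rough standalone sketch of the probe-row part (illustrative helper names; the real component quotes identifiers and builds WHERE clauses as shown above):

// runQuery stands in for the DBQuery binding; `table` is assumed to be already quoted.
async function fetchPage(runQuery: (sql: string) => Promise<any[]>, table: string, page: number, size: number) {
  const offset = (page - 1) * size;
  const rows = await runQuery(`SELECT * FROM ${table} LIMIT ${size + 1} OFFSET ${offset}`);
  const hasMore = rows.length > size;           // the extra row is only a probe
  return {
    rows: hasMore ? rows.slice(0, size) : rows, // drop the probe row before rendering
    hasMore,                                    // true => a next page exists even though total is unknown
    // when !hasMore the total is simply offset + rows.length; otherwise a
    // background SELECT COUNT(*) can backfill it later, as the component does.
  };
}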
const startTime = Date.now();
|
||||
try {
|
||||
const pCount = DBQuery(config as any, dbName, countSql);
|
||||
const pData = DBQuery(config as any, dbName, sql);
|
||||
|
||||
let pCols = null;
|
||||
if (pkColumns.length === 0) {
|
||||
pCols = DBGetColumns(config as any, dbName, tableName);
|
||||
}
|
||||
|
||||
const [resCount, resData] = await Promise.all([pCount, pData]);
|
||||
const resData = await pData;
|
||||
const duration = Date.now() - startTime;
|
||||
|
||||
// Log Execution
|
||||
addSqlLog({
|
||||
id: `log-${Date.now()}-count`,
|
||||
timestamp: Date.now(),
|
||||
sql: countSql,
|
||||
status: resCount.success ? 'success' : 'error',
|
||||
duration: duration / 2, // Estimate
|
||||
message: resCount.success ? '' : resCount.message,
|
||||
dbName
|
||||
});
|
||||
|
||||
addSqlLog({
|
||||
id: `log-${Date.now()}-data`,
|
||||
timestamp: Date.now(),
|
||||
@@ -114,36 +92,101 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
dbName
|
||||
});
|
||||
|
||||
if (pCols) {
|
||||
const resCols = await pCols;
|
||||
if (resCols.success) {
|
||||
const pks = (resCols.data as ColumnDefinition[]).filter(c => c.key === 'PRI').map(c => c.name);
|
||||
setPkColumns(pks);
|
||||
if (pkColumns.length === 0) {
|
||||
const pkKey = `${tab.connectionId}|${dbName}|${tableName}`;
|
||||
if (pkKeyRef.current !== pkKey) {
|
||||
pkKeyRef.current = pkKey;
|
||||
const pkSeq = ++pkSeqRef.current;
|
||||
DBGetColumns(config as any, dbName, tableName)
|
||||
.then((resCols: any) => {
|
||||
if (pkSeqRef.current !== pkSeq) return;
|
||||
if (pkKeyRef.current !== pkKey) return;
|
||||
if (!resCols?.success) return;
|
||||
const pks = (resCols.data as ColumnDefinition[]).filter((c: any) => c.key === 'PRI').map((c: any) => c.name);
|
||||
setPkColumns(pks);
|
||||
})
|
||||
.catch(() => {
|
||||
if (pkSeqRef.current !== pkSeq) return;
|
||||
if (pkKeyRef.current !== pkKey) return;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let totalRecords = 0;
|
||||
if (resCount.success && Array.isArray(resCount.data) && resCount.data.length > 0) {
|
||||
totalRecords = Number(resCount.data[0]['total']);
|
||||
}
|
||||
|
||||
if (resData.success) {
|
||||
let resultData = resData.data as any[];
|
||||
if (!Array.isArray(resultData)) resultData = [];
|
||||
|
||||
const hasMore = resultData.length > size;
|
||||
if (hasMore) resultData = resultData.slice(0, size);
|
||||
|
||||
let fieldNames = resData.fields || [];
|
||||
if (fieldNames.length === 0 && resultData.length > 0) {
|
||||
fieldNames = Object.keys(resultData[0]);
|
||||
}
|
||||
if (fetchSeqRef.current !== seq) return;
|
||||
setColumnNames(fieldNames);
|
||||
|
||||
setData(resultData.map((row: any, i: number) => ({ ...row, key: `row-${i}` })));
|
||||
|
||||
setPagination(prev => ({ ...prev, current: page, pageSize: size, total: totalRecords }));
|
||||
resultData.forEach((row: any, i: number) => {
|
||||
if (row && typeof row === 'object') row[GONAVI_ROW_KEY] = `row-${offset + i}`;
|
||||
});
|
||||
setData(resultData);
|
||||
const countKey = `${tab.connectionId}|${dbName}|${tableName}|${whereSQL}`;
|
||||
const derivedTotalKnown = !hasMore;
|
||||
const derivedTotal = derivedTotalKnown ? offset + resultData.length : page * size + 1;
|
||||
if (derivedTotalKnown) countKeyRef.current = countKey;
|
||||
|
||||
setPagination(prev => {
|
||||
if (derivedTotalKnown) {
|
||||
return { ...prev, current: page, pageSize: size, total: derivedTotal, totalKnown: true };
|
||||
}
|
||||
if (prev.totalKnown && countKeyRef.current === countKey) {
|
||||
return { ...prev, current: page, pageSize: size };
|
||||
}
|
||||
return { ...prev, current: page, pageSize: size, total: derivedTotal, totalKnown: false };
|
||||
});
|
||||
|
||||
if (!derivedTotalKnown) {
|
||||
if (countKeyRef.current !== countKey) {
|
||||
countKeyRef.current = countKey;
|
||||
const countSeq = ++countSeqRef.current;
|
||||
const countStart = Date.now();
|
||||
|
||||
DBQuery(config as any, dbName, countSql)
|
||||
.then((resCount: any) => {
|
||||
const countDuration = Date.now() - countStart;
|
||||
|
||||
addSqlLog({
|
||||
id: `log-${Date.now()}-count`,
|
||||
timestamp: Date.now(),
|
||||
sql: countSql,
|
||||
status: resCount.success ? 'success' : 'error',
|
||||
duration: countDuration,
|
||||
message: resCount.success ? '' : resCount.message,
|
||||
dbName
|
||||
});
|
||||
|
||||
if (countSeqRef.current !== countSeq) return;
|
||||
if (countKeyRef.current !== countKey) return;
|
||||
|
||||
if (!resCount.success) return;
|
||||
if (!Array.isArray(resCount.data) || resCount.data.length === 0) return;
|
||||
|
||||
const total = Number(resCount.data[0]?.['total']);
|
||||
if (!Number.isFinite(total) || total < 0) return;
|
||||
|
||||
setPagination(prev => ({ ...prev, total, totalKnown: true }));
|
||||
})
|
||||
.catch(() => {
|
||||
if (countSeqRef.current !== countSeq) return;
|
||||
if (countKeyRef.current !== countKey) return;
|
||||
// 统计失败不影响主流程,不弹窗;可在日志里查看。
|
||||
});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
message.error(resData.message);
|
||||
}
|
||||
} catch (e: any) {
|
||||
if (fetchSeqRef.current !== seq) return;
|
||||
message.error("Error fetching data: " + e.message);
|
||||
addSqlLog({
|
||||
id: `log-${Date.now()}-error`,
|
||||
@@ -155,7 +198,7 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
dbName
|
||||
});
|
||||
}
|
||||
setLoading(false);
|
||||
if (fetchSeqRef.current === seq) setLoading(false);
|
||||
}, [connections, tab, sortInfo, filterConditions, pkColumns.length]);
|
||||
// Depend on pkColumns.length to avoid loop? No, pkColumns is updated inside.
|
||||
// Actually, 'pkColumns' state shouldn't trigger re-fetch.
|
||||
@@ -165,7 +208,10 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
// So it's fine.
|
||||
|
||||
// Handlers memoized
|
||||
const handleReload = useCallback(() => fetchData(), [fetchData]);
|
||||
const handleReload = useCallback(() => {
|
||||
countKeyRef.current = '';
|
||||
fetchData(pagination.current, pagination.pageSize);
|
||||
}, [fetchData, pagination.current, pagination.pageSize]);
|
||||
const handleSort = useCallback((field: string, order: string) => setSortInfo({ columnKey: field, order }), []);
|
||||
const handlePageChange = useCallback((page: number, size: number) => fetchData(page, size), [fetchData]);
|
||||
const handleToggleFilter = useCallback(() => setShowFilter(prev => !prev), []);
|
||||
@@ -176,7 +222,7 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
}, [tab, sortInfo, filterConditions]); // Initial load and re-load on sort/filter
|
||||
|
||||
return (
|
||||
<div style={{ height: '100%', width: '100%', overflow: 'hidden' }}>
|
||||
<div style={{ flex: '1 1 auto', minHeight: 0, height: '100%', width: '100%', overflow: 'hidden', display: 'flex', flexDirection: 'column' }}>
|
||||
<DataGrid
|
||||
data={data}
|
||||
columnNames={columnNames}
|
||||
|
||||
File diff suppressed because it is too large
@@ -47,6 +47,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
const [expandedKeys, setExpandedKeys] = useState<React.Key[]>([]);
|
||||
const [autoExpandParent, setAutoExpandParent] = useState(true);
|
||||
const [loadedKeys, setLoadedKeys] = useState<React.Key[]>([]);
|
||||
const [selectedKeys, setSelectedKeys] = useState<React.Key[]>([]);
|
||||
const [selectedNodes, setSelectedNodes] = useState<any[]>([]);
|
||||
const [contextMenu, setContextMenu] = useState<{ x: number, y: number, items: MenuProps['items'] } | null>(null);
|
||||
|
||||
// Virtual Scroll State
|
||||
@@ -283,10 +285,14 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
};
|
||||
|
||||
const onSelect = (keys: React.Key[], info: any) => {
|
||||
if (!info.node.selected) {
|
||||
setSelectedKeys(keys);
|
||||
setSelectedNodes(info.selectedNodes || []);
|
||||
|
||||
if (keys.length === 0) {
|
||||
setActiveContext(null);
|
||||
return;
|
||||
}
|
||||
if (!info.selected) return;
|
||||
|
||||
const { type, dataRef, key, title } = info.node;
|
||||
|
||||
@@ -313,15 +319,6 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
};
|
||||
|
||||
const onDoubleClick = (e: any, node: any) => {
|
||||
const key = node.key;
|
||||
const isExpanded = expandedKeys.includes(key);
|
||||
const newExpandedKeys = isExpanded
|
||||
? expandedKeys.filter(k => k !== key)
|
||||
: [...expandedKeys, key];
|
||||
|
||||
setExpandedKeys(newExpandedKeys);
|
||||
if (!isExpanded) setAutoExpandParent(false);
|
||||
|
||||
if (node.type === 'table') {
|
||||
const { tableName, dbName, id } = node.dataRef;
|
||||
addTab({
|
||||
@@ -332,6 +329,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
dbName,
|
||||
tableName,
|
||||
});
|
||||
return;
|
||||
} else if (node.type === 'saved-query') {
|
||||
const q = node.dataRef;
|
||||
addTab({
|
||||
@@ -342,7 +340,17 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
dbName: q.dbName,
|
||||
query: q.sql
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const key = node.key;
|
||||
const isExpanded = expandedKeys.includes(key);
|
||||
const newExpandedKeys = isExpanded
|
||||
? expandedKeys.filter(k => k !== key)
|
||||
: [...expandedKeys, key];
|
||||
|
||||
setExpandedKeys(newExpandedKeys);
|
||||
if (!isExpanded) setAutoExpandParent(false);
|
||||
};
|
||||
|
||||
const handleCopyStructure = async (node: any) => {
|
||||
@@ -382,6 +390,60 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
}
|
||||
};
|
||||
|
||||
const normalizeConnConfig = (raw: any) => ({
|
||||
...raw,
|
||||
port: Number(raw.port),
|
||||
password: raw.password || "",
|
||||
database: raw.database || "",
|
||||
useSSH: raw.useSSH || false,
|
||||
ssh: raw.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
});
|
||||
|
||||
const handleExportDatabaseSQL = async (node: any, includeData: boolean) => {
|
||||
const conn = node.dataRef;
|
||||
const dbName = conn.dbName || node.title;
|
||||
const hide = message.loading(includeData ? `正在备份数据库 ${dbName} (结构+数据)...` : `正在导出数据库 ${dbName} 表结构...`, 0);
|
||||
try {
|
||||
const res = await (window as any).go.app.App.ExportDatabaseSQL(normalizeConnConfig(conn.config), dbName, includeData);
|
||||
hide();
|
||||
if (res.success) {
|
||||
message.success('导出成功');
|
||||
} else if (res.message !== 'Cancelled') {
|
||||
message.error('导出失败: ' + res.message);
|
||||
}
|
||||
} catch (e: any) {
|
||||
hide();
|
||||
message.error('导出失败: ' + (e?.message || String(e)));
|
||||
}
|
||||
};
|
||||
|
||||
const handleExportTablesSQL = async (nodes: any[], includeData: boolean) => {
|
||||
if (!nodes || nodes.length === 0) return;
|
||||
const first = nodes[0].dataRef;
|
||||
const dbName = first.dbName;
|
||||
const connId = first.id;
|
||||
const allSame = nodes.every(n => n?.dataRef?.id === connId && n?.dataRef?.dbName === dbName);
|
||||
if (!allSame) {
|
||||
message.error('请在同一连接、同一数据库下选择多张表进行导出');
|
||||
return;
|
||||
}
|
||||
|
||||
const tableNames = nodes.map(n => n.dataRef.tableName).filter(Boolean);
|
||||
const hide = message.loading(includeData ? `正在备份选中表 (${tableNames.length})...` : `正在导出选中表结构 (${tableNames.length})...`, 0);
|
||||
try {
|
||||
const res = await (window as any).go.app.App.ExportTablesSQL(normalizeConnConfig(first.config), dbName, tableNames, includeData);
|
||||
hide();
|
||||
if (res.success) {
|
||||
message.success('导出成功');
|
||||
} else if (res.message !== 'Cancelled') {
|
||||
message.error('导出失败: ' + res.message);
|
||||
}
|
||||
} catch (e: any) {
|
||||
hide();
|
||||
message.error('导出失败: ' + (e?.message || String(e)));
|
||||
}
|
||||
};
|
||||
|
||||
const handleRunSQLFile = async (node: any) => {
|
||||
const res = await (window as any).go.app.App.OpenSQLFile();
|
||||
if (res.success) {
|
||||
@@ -550,6 +612,18 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
icon: <ReloadOutlined />,
|
||||
onClick: () => loadTables(node)
|
||||
},
|
||||
{
|
||||
key: 'export-db-schema',
|
||||
label: '导出全部表结构 (SQL)',
|
||||
icon: <ExportOutlined />,
|
||||
onClick: () => handleExportDatabaseSQL(node, false)
|
||||
},
|
||||
{
|
||||
key: 'backup-db-sql',
|
||||
label: '备份全部表 (结构+数据 SQL)',
|
||||
icon: <SaveOutlined />,
|
||||
onClick: () => handleExportDatabaseSQL(node, true)
|
||||
},
|
||||
{ type: 'divider' },
|
||||
{
|
||||
key: 'disconnect-db',
|
||||
@@ -588,7 +662,25 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
}
|
||||
];
|
||||
} else if (node.type === 'table') {
|
||||
const sameContextSelectedTables = (selectedNodes || []).filter((n: any) => n?.type === 'table' && n?.dataRef?.id === node?.dataRef?.id && n?.dataRef?.dbName === node?.dataRef?.dbName);
|
||||
const selectedForAction = sameContextSelectedTables.some((n: any) => n?.key === node.key) ? sameContextSelectedTables : [node];
|
||||
|
||||
return [
|
||||
...(selectedForAction.length > 1 ? ([
|
||||
{
|
||||
key: 'export-selected-schema',
|
||||
label: `导出选中表结构 (${selectedForAction.length}) (SQL)`,
|
||||
icon: <ExportOutlined />,
|
||||
onClick: () => handleExportTablesSQL(selectedForAction, false)
|
||||
},
|
||||
{
|
||||
key: 'backup-selected-sql',
|
||||
label: `备份选中表 (${selectedForAction.length}) (结构+数据 SQL)`,
|
||||
icon: <SaveOutlined />,
|
||||
onClick: () => handleExportTablesSQL(selectedForAction, true)
|
||||
},
|
||||
{ type: 'divider' as const }
|
||||
]) : []),
|
||||
{
|
||||
key: 'new-query',
|
||||
label: '新建查询',
|
||||
@@ -684,6 +776,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
loadedKeys={loadedKeys}
|
||||
onLoad={setLoadedKeys}
|
||||
autoExpandParent={autoExpandParent}
|
||||
multiple
|
||||
selectedKeys={selectedKeys}
|
||||
blockNode
|
||||
height={treeHeight}
|
||||
onRightClick={onRightClick}
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import React, { useMemo } from 'react';
|
||||
import { Tabs, Button } from 'antd';
|
||||
import { Tabs, Dropdown } from 'antd';
|
||||
import type { MenuProps } from 'antd';
|
||||
import { useStore } from '../store';
|
||||
import DataViewer from './DataViewer';
|
||||
import QueryEditor from './QueryEditor';
|
||||
import TableDesigner from './TableDesigner';
|
||||
|
||||
const TabManager: React.FC = () => {
|
||||
const { tabs, activeTabId, setActiveTab, closeTab } = useStore();
|
||||
const { tabs, activeTabId, setActiveTab, closeTab, closeOtherTabs, closeTabsToLeft, closeTabsToRight, closeAllTabs } = useStore();
|
||||
|
||||
const onChange = (newActiveKey: string) => {
|
||||
setActiveTab(newActiveKey);
|
||||
@@ -18,7 +19,7 @@ const TabManager: React.FC = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const items = useMemo(() => tabs.map(tab => {
|
||||
const items = useMemo(() => tabs.map((tab, index) => {
|
||||
let content;
|
||||
if (tab.type === 'query') {
|
||||
content = <QueryEditor tab={tab} />;
|
||||
@@ -27,27 +28,95 @@ const TabManager: React.FC = () => {
|
||||
} else if (tab.type === 'design') {
|
||||
content = <TableDesigner tab={tab} />;
|
||||
}
|
||||
|
||||
const menuItems: MenuProps['items'] = [
|
||||
{
|
||||
key: 'close-other',
|
||||
label: '关闭其他页',
|
||||
disabled: tabs.length <= 1,
|
||||
onClick: () => closeOtherTabs(tab.id),
|
||||
},
|
||||
{
|
||||
key: 'close-left',
|
||||
label: '关闭左侧',
|
||||
disabled: index === 0,
|
||||
onClick: () => closeTabsToLeft(tab.id),
|
||||
},
|
||||
{
|
||||
key: 'close-right',
|
||||
label: '关闭右侧',
|
||||
disabled: index === tabs.length - 1,
|
||||
onClick: () => closeTabsToRight(tab.id),
|
||||
},
|
||||
{ type: 'divider' },
|
||||
{
|
||||
key: 'close-all',
|
||||
label: '关闭所有',
|
||||
disabled: tabs.length === 0,
|
||||
onClick: () => closeAllTabs(),
|
||||
},
|
||||
];
|
||||
|
||||
return {
|
||||
label: tab.title,
|
||||
label: (
|
||||
<Dropdown menu={{ items: menuItems }} trigger={['contextMenu']}>
|
||||
<span onContextMenu={(e) => e.preventDefault()}>{tab.title}</span>
|
||||
</Dropdown>
|
||||
),
|
||||
key: tab.id,
|
||||
children: content,
|
||||
};
|
||||
}), [tabs]);
|
||||
}), [tabs, closeOtherTabs, closeTabsToLeft, closeTabsToRight, closeAllTabs]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<style>{`
|
||||
.ant-tabs-content { height: 100%; }
|
||||
.ant-tabs-tabpane { height: 100%; }
|
||||
.main-tabs {
|
||||
height: 100%;
|
||||
flex: 1 1 auto;
|
||||
min-height: 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
overflow: hidden;
|
||||
}
|
||||
.main-tabs .ant-tabs-nav {
|
||||
flex: 0 0 auto;
|
||||
}
|
||||
.main-tabs .ant-tabs-content-holder {
|
||||
flex: 1 1 auto;
|
||||
min-height: 0;
|
||||
overflow: hidden;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
.main-tabs .ant-tabs-content {
|
||||
flex: 1 1 auto;
|
||||
min-height: 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
.main-tabs .ant-tabs-tabpane {
|
||||
flex: 1 1 auto;
|
||||
min-height: 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
overflow: hidden;
|
||||
}
|
||||
.main-tabs .ant-tabs-tabpane > div {
|
||||
flex: 1 1 auto;
|
||||
min-height: 0;
|
||||
}
|
||||
.main-tabs .ant-tabs-tabpane-hidden {
|
||||
display: none !important;
|
||||
}
|
||||
`}</style>
|
||||
<Tabs
|
||||
className="main-tabs"
|
||||
type="editable-card"
|
||||
onChange={onChange}
|
||||
activeKey={activeTabId || undefined}
|
||||
onEdit={onEdit}
|
||||
items={items}
|
||||
style={{ height: '100%' }}
|
||||
hideAdd
|
||||
/>
|
||||
</>
|
||||
|
||||
@@ -550,7 +550,6 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
<div ref={containerRef} className="table-designer-wrapper" style={{ height: '100%', overflow: 'hidden', position: 'relative' }}>
|
||||
<style>{`
|
||||
.table-designer-wrapper .ant-table-body {
|
||||
height: ${tableHeight}px !important;
|
||||
max-height: ${tableHeight}px !important;
|
||||
}
|
||||
`}</style>
|
||||
|
||||
@@ -21,6 +21,7 @@ interface AppState {
|
||||
savedQueries: SavedQuery[];
|
||||
darkMode: boolean;
|
||||
sqlFormatOptions: { keywordCase: 'upper' | 'lower' };
|
||||
queryOptions: { maxRows: number };
|
||||
sqlLogs: SqlLog[];
|
||||
|
||||
addConnection: (conn: SavedConnection) => void;
|
||||
@@ -29,6 +30,10 @@ interface AppState {
|
||||
|
||||
addTab: (tab: TabData) => void;
|
||||
closeTab: (id: string) => void;
|
||||
closeOtherTabs: (id: string) => void;
|
||||
closeTabsToLeft: (id: string) => void;
|
||||
closeTabsToRight: (id: string) => void;
|
||||
closeAllTabs: () => void;
|
||||
setActiveTab: (id: string) => void;
|
||||
setActiveContext: (context: { connectionId: string; dbName: string } | null) => void;
|
||||
|
||||
@@ -37,6 +42,7 @@ interface AppState {
|
||||
|
||||
toggleDarkMode: () => void;
|
||||
setSqlFormatOptions: (options: { keywordCase: 'upper' | 'lower' }) => void;
|
||||
setQueryOptions: (options: Partial<{ maxRows: number }>) => void;
|
||||
|
||||
addSqlLog: (log: SqlLog) => void;
|
||||
clearSqlLogs: () => void;
|
||||
@@ -52,6 +58,7 @@ export const useStore = create<AppState>()(
|
||||
savedQueries: [],
|
||||
darkMode: false,
|
||||
sqlFormatOptions: { keywordCase: 'upper' },
|
||||
queryOptions: { maxRows: 5000 },
|
||||
sqlLogs: [],
|
||||
|
||||
addConnection: (conn) => set((state) => ({ connections: [...state.connections, conn] })),
|
||||
@@ -79,6 +86,30 @@ export const useStore = create<AppState>()(
|
||||
}
|
||||
return { tabs: newTabs, activeTabId: newActiveId };
|
||||
}),
|
||||
|
||||
closeOtherTabs: (id) => set((state) => {
|
||||
const keep = state.tabs.find(t => t.id === id);
|
||||
if (!keep) return state;
|
||||
return { tabs: [keep], activeTabId: id };
|
||||
}),
|
||||
|
||||
closeTabsToLeft: (id) => set((state) => {
|
||||
const index = state.tabs.findIndex(t => t.id === id);
|
||||
if (index === -1) return state;
|
||||
const newTabs = state.tabs.slice(index);
|
||||
const activeStillExists = state.activeTabId ? newTabs.some(t => t.id === state.activeTabId) : false;
|
||||
return { tabs: newTabs, activeTabId: activeStillExists ? state.activeTabId : id };
|
||||
}),
|
||||
|
||||
closeTabsToRight: (id) => set((state) => {
|
||||
const index = state.tabs.findIndex(t => t.id === id);
|
||||
if (index === -1) return state;
|
||||
const newTabs = state.tabs.slice(0, index + 1);
|
||||
const activeStillExists = state.activeTabId ? newTabs.some(t => t.id === state.activeTabId) : false;
|
||||
return { tabs: newTabs, activeTabId: activeStillExists ? state.activeTabId : id };
|
||||
}),
|
||||
|
||||
closeAllTabs: () => set(() => ({ tabs: [], activeTabId: null })),
|
||||
|
||||
setActiveTab: (id) => set({ activeTabId: id }),
|
||||
setActiveContext: (context) => set({ activeContext: context }),
|
||||
@@ -96,13 +127,14 @@ export const useStore = create<AppState>()(
|
||||
|
||||
toggleDarkMode: () => set((state) => ({ darkMode: !state.darkMode })),
|
||||
setSqlFormatOptions: (options) => set({ sqlFormatOptions: options }),
|
||||
setQueryOptions: (options) => set((state) => ({ queryOptions: { ...state.queryOptions, ...options } })),
|
||||
|
||||
addSqlLog: (log) => set((state) => ({ sqlLogs: [log, ...state.sqlLogs].slice(0, 1000) })), // Keep last 1000 logs
|
||||
clearSqlLogs: () => set({ sqlLogs: [] }),
|
||||
}),
|
||||
{
|
||||
name: 'lite-db-storage', // name of the item in the storage (must be unique)
|
||||
partialize: (state) => ({ connections: state.connections, savedQueries: state.savedQueries, darkMode: state.darkMode, sqlFormatOptions: state.sqlFormatOptions }), // Don't persist logs
|
||||
partialize: (state) => ({ connections: state.connections, savedQueries: state.savedQueries, darkMode: state.darkMode, sqlFormatOptions: state.sqlFormatOptions, queryOptions: state.queryOptions }), // Don't persist logs
|
||||
}
|
||||
)
|
||||
);
|
||||
);
|
||||
|
||||
173
frontend/src/utils/sql.ts
Normal file
@@ -0,0 +1,173 @@
|
||||
export type FilterCondition = {
|
||||
id?: number;
|
||||
column?: string;
|
||||
op?: string;
|
||||
value?: string;
|
||||
value2?: string;
|
||||
};
|
||||
|
||||
const normalizeIdentPart = (ident: string) => {
|
||||
let raw = (ident || '').trim();
|
||||
if (!raw) return raw;
|
||||
const first = raw[0];
|
||||
const last = raw[raw.length - 1];
|
||||
if ((first === '"' && last === '"') || (first === '`' && last === '`')) {
|
||||
raw = raw.slice(1, -1).trim();
|
||||
}
|
||||
raw = raw.replace(/["`]/g, '').trim();
|
||||
return raw;
|
||||
};
|
||||
|
||||
export const quoteIdentPart = (dbType: string, ident: string) => {
|
||||
const raw = normalizeIdentPart(ident);
|
||||
if (!raw) return raw;
|
||||
if ((dbType || '').toLowerCase() === 'mysql') return `\`${raw.replace(/`/g, '``')}\``;
|
||||
return `"${raw.replace(/"/g, '""')}"`;
|
||||
};
|
||||
|
||||
export const quoteQualifiedIdent = (dbType: string, ident: string) => {
|
||||
const raw = (ident || '').trim();
|
||||
if (!raw) return raw;
|
||||
const parts = raw.split('.').map(normalizeIdentPart).filter(Boolean);
|
||||
if (parts.length <= 1) return quoteIdentPart(dbType, raw);
|
||||
return parts.map(p => quoteIdentPart(dbType, p)).join('.');
|
||||
};
|
||||
|
||||
export const escapeLiteral = (val: string) => (val || '').replace(/'/g, "''");
|
||||
|
||||
export const parseListValues = (val: string) => {
|
||||
const raw = (val || '').trim();
|
||||
if (!raw) return [];
|
||||
return raw
|
||||
.split(/[\n,,]+/)
|
||||
.map(s => s.trim())
|
||||
.filter(Boolean);
|
||||
};
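As a quick illustration (the sample inputs below are made up), parseListValues splits its input on newlines and on both ASCII and full-width commas:

// Illustrative inputs only.
parseListValues('1, 2,3');    // ['1', '2', '3']  (ASCII and full-width comma)
parseListValues('a\nb\n c ');  // ['a', 'b', 'c']
parseListValues('');           // []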
|
||||
|
||||
export const buildWhereSQL = (dbType: string, conditions: FilterCondition[]) => {
|
||||
const whereParts: string[] = [];
|
||||
|
||||
(conditions || []).forEach((cond) => {
|
||||
const op = (cond?.op || '').trim();
|
||||
const column = (cond?.column || '').trim();
|
||||
const value = (cond?.value ?? '').toString();
|
||||
const value2 = (cond?.value2 ?? '').toString();
|
||||
|
||||
if (op === 'CUSTOM') {
|
||||
const expr = value.trim();
|
||||
if (expr) whereParts.push(`(${expr})`);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!column) return;
|
||||
|
||||
const col = quoteIdentPart(dbType, column);
|
||||
|
||||
switch (op) {
|
||||
case 'IS_NULL':
|
||||
whereParts.push(`${col} IS NULL`);
|
||||
return;
|
||||
case 'IS_NOT_NULL':
|
||||
whereParts.push(`${col} IS NOT NULL`);
|
||||
return;
|
||||
case 'IS_EMPTY':
|
||||
// Compatibility: an empty value is usually understood as NULL or an empty string
|
||||
whereParts.push(`(${col} IS NULL OR ${col} = '')`);
|
||||
return;
|
||||
case 'IS_NOT_EMPTY':
|
||||
whereParts.push(`(${col} IS NOT NULL AND ${col} <> '')`);
|
||||
return;
|
||||
case 'BETWEEN': {
|
||||
const v1 = value.trim();
|
||||
const v2 = value2.trim();
|
||||
if (!v1 || !v2) return;
|
||||
whereParts.push(`${col} BETWEEN '${escapeLiteral(v1)}' AND '${escapeLiteral(v2)}'`);
|
||||
return;
|
||||
}
|
||||
case 'NOT_BETWEEN': {
|
||||
const v1 = value.trim();
|
||||
const v2 = value2.trim();
|
||||
if (!v1 || !v2) return;
|
||||
whereParts.push(`${col} NOT BETWEEN '${escapeLiteral(v1)}' AND '${escapeLiteral(v2)}'`);
|
||||
return;
|
||||
}
|
||||
case 'IN': {
|
||||
const items = parseListValues(value);
|
||||
if (items.length === 0) return;
|
||||
const list = items.map(v => `'${escapeLiteral(v)}'`).join(', ');
|
||||
whereParts.push(`${col} IN (${list})`);
|
||||
return;
|
||||
}
|
||||
case 'NOT_IN': {
|
||||
const items = parseListValues(value);
|
||||
if (items.length === 0) return;
|
||||
const list = items.map(v => `'${escapeLiteral(v)}'`).join(', ');
|
||||
whereParts.push(`${col} NOT IN (${list})`);
|
||||
return;
|
||||
}
|
||||
case 'CONTAINS': {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} LIKE '%${escapeLiteral(v)}%'`);
|
||||
return;
|
||||
}
|
||||
case 'NOT_CONTAINS': {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} NOT LIKE '%${escapeLiteral(v)}%'`);
|
||||
return;
|
||||
}
|
||||
case 'STARTS_WITH': {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} LIKE '${escapeLiteral(v)}%'`);
|
||||
return;
|
||||
}
|
||||
case 'NOT_STARTS_WITH': {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} NOT LIKE '${escapeLiteral(v)}%'`);
|
||||
return;
|
||||
}
|
||||
case 'ENDS_WITH': {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} LIKE '%${escapeLiteral(v)}'`);
|
||||
return;
|
||||
}
|
||||
case 'NOT_ENDS_WITH': {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} NOT LIKE '%${escapeLiteral(v)}'`);
|
||||
return;
|
||||
}
|
||||
case '=':
|
||||
case '!=':
|
||||
case '<':
|
||||
case '<=':
|
||||
case '>':
|
||||
case '>=': {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} ${op} '${escapeLiteral(v)}'`);
|
||||
return;
|
||||
}
|
||||
default: {
|
||||
// Backward compatibility with the legacy op value: LIKE
|
||||
if (op.toUpperCase() === 'LIKE') {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} LIKE '%${escapeLiteral(v)}%'`);
|
||||
return;
|
||||
}
|
||||
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} ${op} '${escapeLiteral(v)}'`);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return whereParts.length > 0 ? `WHERE ${whereParts.join(' AND ')}` : '';
|
||||
};
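A minimal usage sketch of buildWhereSQL, assuming a MySQL connection; the column names, values and relative import path below are illustrative rather than taken from the app code:

import { buildWhereSQL, FilterCondition } from './utils/sql'; // path depends on the caller

const conditions: FilterCondition[] = [
  { column: 'status', op: '=', value: 'active' },
  { column: 'name', op: 'CONTAINS', value: "O'Brien" },
];

// For dbType 'mysql' this yields:
// WHERE `status` = 'active' AND `name` LIKE '%O''Brien%'
const where = buildWhereSQL('mysql', conditions);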
|
||||
|
||||
10
frontend/wailsjs/go/app/App.d.ts
vendored
@@ -29,10 +29,20 @@ export function DBShowCreateTable(arg1:connection.ConnectionConfig,arg2:string,a
|
||||
|
||||
export function DataSync(arg1:sync.SyncConfig):Promise<sync.SyncResult>;
|
||||
|
||||
export function DataSyncAnalyze(arg1:sync.SyncConfig):Promise<connection.QueryResult>;
|
||||
|
||||
export function DataSyncPreview(arg1:sync.SyncConfig,arg2:string,arg3:number):Promise<connection.QueryResult>;
|
||||
|
||||
export function ExportData(arg1:Array<Record<string, any>>,arg2:Array<string>,arg3:string,arg4:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function ExportDatabaseSQL(arg1:connection.ConnectionConfig,arg2:string,arg3:boolean):Promise<connection.QueryResult>;
|
||||
|
||||
export function ExportQuery(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string,arg5:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function ExportTable(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function ExportTablesSQL(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>,arg4:boolean):Promise<connection.QueryResult>;
|
||||
|
||||
export function ImportConfigFile():Promise<connection.QueryResult>;
|
||||
|
||||
export function ImportData(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise<connection.QueryResult>;
|
||||
|
||||
@@ -54,14 +54,34 @@ export function DataSync(arg1) {
|
||||
return window['go']['app']['App']['DataSync'](arg1);
|
||||
}
|
||||
|
||||
export function DataSyncAnalyze(arg1) {
|
||||
return window['go']['app']['App']['DataSyncAnalyze'](arg1);
|
||||
}
|
||||
|
||||
export function DataSyncPreview(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['DataSyncPreview'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function ExportData(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['ExportData'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function ExportDatabaseSQL(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['ExportDatabaseSQL'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function ExportQuery(arg1, arg2, arg3, arg4, arg5) {
|
||||
return window['go']['app']['App']['ExportQuery'](arg1, arg2, arg3, arg4, arg5);
|
||||
}
|
||||
|
||||
export function ExportTable(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['ExportTable'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function ExportTablesSQL(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['ExportTablesSQL'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function ImportConfigFile() {
|
||||
return window['go']['app']['App']['ImportConfigFile']();
|
||||
}
|
||||
|
||||
@@ -142,11 +142,37 @@ export namespace connection {
|
||||
|
||||
export namespace sync {
|
||||
|
||||
export class TableOptions {
|
||||
insert?: boolean;
|
||||
update?: boolean;
|
||||
delete?: boolean;
|
||||
selectedInsertPks?: string[];
|
||||
selectedUpdatePks?: string[];
|
||||
selectedDeletePks?: string[];
|
||||
|
||||
static createFrom(source: any = {}) {
|
||||
return new TableOptions(source);
|
||||
}
|
||||
|
||||
constructor(source: any = {}) {
|
||||
if ('string' === typeof source) source = JSON.parse(source);
|
||||
this.insert = source["insert"];
|
||||
this.update = source["update"];
|
||||
this.delete = source["delete"];
|
||||
this.selectedInsertPks = source["selectedInsertPks"];
|
||||
this.selectedUpdatePks = source["selectedUpdatePks"];
|
||||
this.selectedDeletePks = source["selectedDeletePks"];
|
||||
}
|
||||
}
|
||||
export class SyncConfig {
|
||||
sourceConfig: connection.ConnectionConfig;
|
||||
targetConfig: connection.ConnectionConfig;
|
||||
tables: string[];
|
||||
content?: string;
|
||||
mode: string;
|
||||
jobId?: string;
|
||||
autoAddColumns?: boolean;
|
||||
tableOptions?: Record<string, TableOptions>;
|
||||
|
||||
static createFrom(source: any = {}) {
|
||||
return new SyncConfig(source);
|
||||
@@ -157,7 +183,11 @@ export namespace sync {
|
||||
this.sourceConfig = this.convertValues(source["sourceConfig"], connection.ConnectionConfig);
|
||||
this.targetConfig = this.convertValues(source["targetConfig"], connection.ConnectionConfig);
|
||||
this.tables = source["tables"];
|
||||
this.content = source["content"];
|
||||
this.mode = source["mode"];
|
||||
this.jobId = source["jobId"];
|
||||
this.autoAddColumns = source["autoAddColumns"];
|
||||
this.tableOptions = this.convertValues(source["tableOptions"], TableOptions, true);
|
||||
}
|
||||
|
||||
convertValues(a: any, classs: any, asMap: boolean = false): any {
|
||||
|
||||
@@ -141,9 +141,6 @@ func (a *App) getDatabase(config connection.ConnectionConfig) (db.Database, erro
|
||||
if len(shortKey) > 12 {
|
||||
shortKey = shortKey[:12]
|
||||
}
|
||||
if config.UseSSH && config.Type != "mysql" {
|
||||
logger.Warnf("当前仅 MySQL 支持内置 SSH 直连,其他类型请使用本地端口转发:%s", formatConnSummary(config))
|
||||
}
|
||||
logger.Infof("获取数据库连接:%s 缓存Key=%s", formatConnSummary(config), shortKey)
|
||||
|
||||
a.mu.Lock()
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
)
|
||||
|
||||
// Generic DB Methods
|
||||
@@ -91,16 +94,39 @@ func (a *App) DBQuery(config connection.ConnectionConfig, dbName string, query s
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
query = sanitizeSQLForPgLike(runConfig.Type, query)
|
||||
timeoutSeconds := runConfig.Timeout
|
||||
if timeoutSeconds <= 0 {
|
||||
timeoutSeconds = 30
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(time.Duration(timeoutSeconds) * time.Second)
|
||||
defer cancel()
|
||||
|
||||
lowerQuery := strings.TrimSpace(strings.ToLower(query))
|
||||
if strings.HasPrefix(lowerQuery, "select") || strings.HasPrefix(lowerQuery, "show") || strings.HasPrefix(lowerQuery, "describe") || strings.HasPrefix(lowerQuery, "explain") {
|
||||
data, columns, err := dbInst.Query(query)
|
||||
var data []map[string]interface{}
|
||||
var columns []string
|
||||
if q, ok := dbInst.(interface {
|
||||
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
|
||||
}); ok {
|
||||
data, columns, err = q.QueryContext(ctx, query)
|
||||
} else {
|
||||
data, columns, err = dbInst.Query(query)
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQuery 查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Data: data, Fields: columns}
|
||||
} else {
|
||||
affected, err := dbInst.Exec(query)
|
||||
var affected int64
|
||||
if e, ok := dbInst.(interface {
|
||||
ExecContext(context.Context, string) (int64, error)
|
||||
}); ok {
|
||||
affected, err = e.ExecContext(ctx, query)
|
||||
} else {
|
||||
affected, err = dbInst.Exec(query)
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQuery 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/csv"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/db"
|
||||
@@ -213,12 +218,36 @@ func (a *App) ExportTable(config connection.ConnectionConfig, dbName string, tab
|
||||
}
|
||||
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
|
||||
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
format = strings.ToLower(format)
|
||||
if format == "sql" {
|
||||
f, err := os.Create(filename)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
w := bufio.NewWriterSize(f, 1024*1024)
|
||||
defer w.Flush()
|
||||
|
||||
if err := writeSQLHeader(w, runConfig, dbName); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
if err := dumpTableSQL(w, dbInst, runConfig, dbName, tableName, true); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
if err := writeSQLFooter(w, runConfig); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "Export successful"}
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(runConfig.Type, tableName))
|
||||
|
||||
data, columns, err := dbInst.Query(query)
|
||||
@@ -231,71 +260,129 @@ data, columns, err := dbInst.Query(query)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
defer f.Close()
|
||||
if err := writeRowsToFile(f, data, columns, format); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
|
||||
}
|
||||
|
||||
format = strings.ToLower(format)
|
||||
var csvWriter *csv.Writer
|
||||
var jsonEncoder *json.Encoder
|
||||
var isJsonFirstRow = true
|
||||
return connection.QueryResult{Success: true, Message: "Export successful"}
|
||||
}
|
||||
|
||||
switch format {
|
||||
case "csv", "xlsx":
|
||||
f.Write([]byte{0xEF, 0xBB, 0xBF})
|
||||
csvWriter = csv.NewWriter(f)
|
||||
defer csvWriter.Flush()
|
||||
if err := csvWriter.Write(columns); err != nil {
|
||||
func (a *App) ExportTablesSQL(config connection.ConnectionConfig, dbName string, tableNames []string, includeData bool) connection.QueryResult {
|
||||
safeDbName := strings.TrimSpace(dbName)
|
||||
if safeDbName == "" {
|
||||
safeDbName = "export"
|
||||
}
|
||||
suffix := "schema"
|
||||
if includeData {
|
||||
suffix = "backup"
|
||||
}
|
||||
defaultFilename := fmt.Sprintf("%s_%s_%dtables.sql", safeDbName, suffix, len(tableNames))
|
||||
if len(tableNames) == 1 && strings.TrimSpace(tableNames[0]) != "" {
|
||||
defaultFilename = fmt.Sprintf("%s_%s.sql", strings.TrimSpace(tableNames[0]), suffix)
|
||||
}
|
||||
|
||||
filename, err := runtime.SaveFileDialog(a.ctx, runtime.SaveDialogOptions{
|
||||
Title: "Export Tables (SQL)",
|
||||
DefaultFilename: defaultFilename,
|
||||
})
|
||||
if err != nil || filename == "" {
|
||||
return connection.QueryResult{Success: false, Message: "Cancelled"}
|
||||
}
|
||||
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
tables := make([]string, 0, len(tableNames))
|
||||
seen := make(map[string]struct{}, len(tableNames))
|
||||
for _, t := range tableNames {
|
||||
t = strings.TrimSpace(t)
|
||||
if t == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[t]; ok {
|
||||
continue
|
||||
}
|
||||
seen[t] = struct{}{}
|
||||
tables = append(tables, t)
|
||||
}
|
||||
sort.Strings(tables)
|
||||
|
||||
f, err := os.Create(filename)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
w := bufio.NewWriterSize(f, 1024*1024)
|
||||
defer w.Flush()
|
||||
|
||||
if err := writeSQLHeader(w, runConfig, dbName); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
for _, t := range tables {
|
||||
if err := dumpTableSQL(w, dbInst, runConfig, dbName, t, includeData); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
case "json":
|
||||
f.WriteString("[\n")
|
||||
jsonEncoder = json.NewEncoder(f)
|
||||
jsonEncoder.SetIndent(" ", " ")
|
||||
case "md":
|
||||
fmt.Fprintf(f, "| %s |\n", strings.Join(columns, " | "))
|
||||
seps := make([]string, len(columns))
|
||||
for i := range seps {
|
||||
seps[i] = "---"
|
||||
}
|
||||
fmt.Fprintf(f, "| %s |\n", strings.Join(seps, " | "))
|
||||
default:
|
||||
return connection.QueryResult{Success: false, Message: "Unsupported format: " + format}
|
||||
}
|
||||
if err := writeSQLFooter(w, runConfig); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
for _, rowMap := range data {
|
||||
record := make([]string, len(columns))
|
||||
for i, col := range columns {
|
||||
val := rowMap[col]
|
||||
if val == nil {
|
||||
record[i] = "NULL"
|
||||
} else {
|
||||
s := fmt.Sprintf("%v", val)
|
||||
if format == "md" {
|
||||
s = strings.ReplaceAll(s, "|", "\\|")
|
||||
s = strings.ReplaceAll(s, "\n", "<br>")
|
||||
}
|
||||
record[i] = s
|
||||
}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Message: "Export successful"}
|
||||
}
|
||||
|
||||
switch format {
|
||||
case "csv", "xlsx":
|
||||
if err := csvWriter.Write(record); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
|
||||
}
|
||||
case "json":
|
||||
if !isJsonFirstRow {
|
||||
f.WriteString(",\n")
|
||||
}
|
||||
if err := jsonEncoder.Encode(rowMap); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
|
||||
}
|
||||
isJsonFirstRow = false
|
||||
case "md":
|
||||
fmt.Fprintf(f, "| %s |\n", strings.Join(record, " | "))
|
||||
}
|
||||
func (a *App) ExportDatabaseSQL(config connection.ConnectionConfig, dbName string, includeData bool) connection.QueryResult {
|
||||
safeDbName := strings.TrimSpace(dbName)
|
||||
if safeDbName == "" {
|
||||
return connection.QueryResult{Success: false, Message: "dbName required"}
|
||||
}
|
||||
suffix := "schema"
|
||||
if includeData {
|
||||
suffix = "backup"
|
||||
}
|
||||
|
||||
if format == "json" {
|
||||
f.WriteString("\n]")
|
||||
filename, err := runtime.SaveFileDialog(a.ctx, runtime.SaveDialogOptions{
|
||||
Title: fmt.Sprintf("Export %s (SQL)", safeDbName),
|
||||
DefaultFilename: fmt.Sprintf("%s_%s.sql", safeDbName, suffix),
|
||||
})
|
||||
if err != nil || filename == "" {
|
||||
return connection.QueryResult{Success: false, Message: "Cancelled"}
|
||||
}
|
||||
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
tables, err := dbInst.GetTables(dbName)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
sort.Strings(tables)
|
||||
|
||||
f, err := os.Create(filename)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
w := bufio.NewWriterSize(f, 1024*1024)
|
||||
defer w.Flush()
|
||||
|
||||
if err := writeSQLHeader(w, runConfig, dbName); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
for _, t := range tables {
|
||||
if err := dumpTableSQL(w, dbInst, runConfig, dbName, t, includeData); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
}
|
||||
if err := writeSQLFooter(w, runConfig); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "Export successful"}
|
||||
@@ -340,6 +427,173 @@ func quoteQualifiedIdentByType(dbType string, ident string) string {
|
||||
return strings.Join(quotedParts, ".")
|
||||
}
|
||||
|
||||
func writeSQLHeader(w *bufio.Writer, config connection.ConnectionConfig, dbName string) error {
|
||||
now := time.Now().Format("2006-01-02 15:04:05")
|
||||
if _, err := w.WriteString(fmt.Sprintf("-- GoNavi SQL Export\n-- Time: %s\n", now)); err != nil {
|
||||
return err
|
||||
}
|
||||
if strings.TrimSpace(dbName) != "" {
|
||||
if _, err := w.WriteString(fmt.Sprintf("-- Database: %s\n\n", dbName)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if strings.ToLower(strings.TrimSpace(config.Type)) == "mysql" && strings.TrimSpace(dbName) != "" {
|
||||
if _, err := w.WriteString(fmt.Sprintf("USE %s;\n\n", quoteIdentByType("mysql", dbName))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.WriteString("SET FOREIGN_KEY_CHECKS=0;\n\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeSQLFooter(w *bufio.Writer, config connection.ConnectionConfig) error {
|
||||
if strings.ToLower(strings.TrimSpace(config.Type)) == "mysql" {
|
||||
if _, err := w.WriteString("\nSET FOREIGN_KEY_CHECKS=1;\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func qualifyTable(schemaName, tableName string) string {
|
||||
schemaName = strings.TrimSpace(schemaName)
|
||||
tableName = strings.TrimSpace(tableName)
|
||||
if schemaName == "" {
|
||||
return tableName
|
||||
}
|
||||
return schemaName + "." + tableName
|
||||
}
|
||||
|
||||
func ensureSQLTerminator(sql string) string {
|
||||
trimmed := strings.TrimSpace(sql)
|
||||
if trimmed == "" {
|
||||
return sql
|
||||
}
|
||||
if strings.HasSuffix(trimmed, ";") {
|
||||
return sql
|
||||
}
|
||||
return sql + ";"
|
||||
}
|
||||
|
||||
func isMySQLHexLiteral(s string) bool {
|
||||
if len(s) < 3 || !(strings.HasPrefix(s, "0x") || strings.HasPrefix(s, "0X")) {
|
||||
return false
|
||||
}
|
||||
for i := 2; i < len(s); i++ {
|
||||
c := s[i]
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func formatSQLValue(dbType string, v interface{}) string {
|
||||
if v == nil {
|
||||
return "NULL"
|
||||
}
|
||||
|
||||
switch val := v.(type) {
|
||||
case bool:
|
||||
if val {
|
||||
return "1"
|
||||
}
|
||||
return "0"
|
||||
case int:
|
||||
return strconv.Itoa(val)
|
||||
case int8, int16, int32, int64:
|
||||
return fmt.Sprintf("%d", val)
|
||||
case uint, uint8, uint16, uint32, uint64:
|
||||
return fmt.Sprintf("%d", val)
|
||||
case float32:
|
||||
f := float64(val)
|
||||
if math.IsNaN(f) || math.IsInf(f, 0) {
|
||||
return "NULL"
|
||||
}
|
||||
return strconv.FormatFloat(f, 'f', -1, 32)
|
||||
case float64:
|
||||
if math.IsNaN(val) || math.IsInf(val, 0) {
|
||||
return "NULL"
|
||||
}
|
||||
return strconv.FormatFloat(val, 'f', -1, 64)
|
||||
case time.Time:
|
||||
return "'" + val.Format("2006-01-02 15:04:05") + "'"
|
||||
case string:
|
||||
if strings.ToLower(strings.TrimSpace(dbType)) == "mysql" && isMySQLHexLiteral(val) {
|
||||
return val
|
||||
}
|
||||
escaped := strings.ReplaceAll(val, "'", "''")
|
||||
return "'" + escaped + "'"
|
||||
default:
|
||||
escaped := strings.ReplaceAll(fmt.Sprintf("%v", v), "'", "''")
|
||||
return "'" + escaped + "'"
|
||||
}
|
||||
}
|
||||
|
||||
func dumpTableSQL(w *bufio.Writer, dbInst db.Database, config connection.ConnectionConfig, dbName, tableName string, includeData bool) error {
|
||||
schemaName, pureTableName := normalizeSchemaAndTable(config, dbName, tableName)
|
||||
|
||||
if _, err := w.WriteString("\n-- ----------------------------\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.WriteString(fmt.Sprintf("-- Table: %s\n", qualifyTable(schemaName, pureTableName))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.WriteString("-- ----------------------------\n\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
createSQL, err := dbInst.GetCreateStatement(schemaName, pureTableName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.WriteString(ensureSQLTerminator(createSQL)); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.WriteString("\n\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !includeData {
|
||||
return nil
|
||||
}
|
||||
|
||||
qualified := qualifyTable(schemaName, pureTableName)
|
||||
selectSQL := fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(config.Type, qualified))
|
||||
data, columns, err := dbInst.Query(selectSQL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(data) == 0 {
|
||||
if _, err := w.WriteString("-- (0 rows)\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
quotedCols := make([]string, 0, len(columns))
|
||||
for _, c := range columns {
|
||||
quotedCols = append(quotedCols, quoteIdentByType(config.Type, c))
|
||||
}
|
||||
quotedTable := quoteQualifiedIdentByType(config.Type, qualified)
|
||||
|
||||
for _, row := range data {
|
||||
values := make([]string, 0, len(columns))
|
||||
for _, c := range columns {
|
||||
values = append(values, formatSQLValue(config.Type, row[c]))
|
||||
}
|
||||
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s);\n", quotedTable, strings.Join(quotedCols, ", "), strings.Join(values, ", "))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExportData exports provided data to a file
|
||||
func (a *App) ExportData(data []map[string]interface{}, columns []string, defaultName string, format string) connection.QueryResult {
|
||||
if defaultName == "" {
|
||||
@@ -359,33 +613,101 @@ func (a *App) ExportData(data []map[string]interface{}, columns []string, defaul
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
defer f.Close()
|
||||
if err := writeRowsToFile(f, data, columns, format); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "Export successful"}
|
||||
}
|
||||
|
||||
// ExportQuery exports by executing the provided SELECT query on backend side.
|
||||
// This avoids frontend IPC payload limits when exporting very large/long-text columns (e.g. base64).
|
||||
func (a *App) ExportQuery(config connection.ConnectionConfig, dbName string, query string, defaultName string, format string) connection.QueryResult {
|
||||
query = strings.TrimSpace(query)
|
||||
if query == "" {
|
||||
return connection.QueryResult{Success: false, Message: "query required"}
|
||||
}
|
||||
|
||||
if defaultName == "" {
|
||||
defaultName = "export"
|
||||
}
|
||||
|
||||
filename, err := runtime.SaveFileDialog(a.ctx, runtime.SaveDialogOptions{
|
||||
Title: "Export Query Result",
|
||||
DefaultFilename: fmt.Sprintf("%s.%s", defaultName, strings.ToLower(format)),
|
||||
})
|
||||
if err != nil || filename == "" {
|
||||
return connection.QueryResult{Success: false, Message: "Cancelled"}
|
||||
}
|
||||
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
query = sanitizeSQLForPgLike(runConfig.Type, query)
|
||||
lowerQuery := strings.ToLower(strings.TrimSpace(query))
|
||||
if !(strings.HasPrefix(lowerQuery, "select") || strings.HasPrefix(lowerQuery, "with")) {
|
||||
return connection.QueryResult{Success: false, Message: "Only SELECT/WITH queries are supported"}
|
||||
}
|
||||
|
||||
data, columns, err := dbInst.Query(query)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
f, err := os.Create(filename)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := writeRowsToFile(f, data, columns, format); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "Export successful"}
|
||||
}
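A hedged sketch of calling this from the frontend through the generated binding shown earlier in App.d.ts / App.js; the connection config, database name and query below are illustrative:

// Inside an async handler; `config` is an existing connection.ConnectionConfig object.
const res = await (window as any).go.app.App.ExportQuery(
  config,
  'mydb',                      // database name (illustrative)
  'SELECT * FROM big_table',   // must start with SELECT or WITH
  'big_table',                 // default file name; the extension is appended from the format
  'csv',                       // one of csv / xlsx / json / md
);
if (!res.success && res.message !== 'Cancelled') {
  console.error(res.message);
}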
|
||||
|
||||
func writeRowsToFile(f *os.File, data []map[string]interface{}, columns []string, format string) error {
|
||||
format = strings.ToLower(strings.TrimSpace(format))
|
||||
if f == nil {
|
||||
return fmt.Errorf("file required")
|
||||
}
|
||||
|
||||
format = strings.ToLower(format)
|
||||
var csvWriter *csv.Writer
|
||||
var jsonEncoder *json.Encoder
|
||||
var isJsonFirstRow = true
|
||||
isJsonFirstRow := true
|
||||
|
||||
switch format {
|
||||
case "csv", "xlsx":
|
||||
f.Write([]byte{0xEF, 0xBB, 0xBF})
|
||||
if _, err := f.Write([]byte{0xEF, 0xBB, 0xBF}); err != nil {
|
||||
return err
|
||||
}
|
||||
csvWriter = csv.NewWriter(f)
|
||||
defer csvWriter.Flush()
|
||||
if err := csvWriter.Write(columns); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
return err
|
||||
}
|
||||
case "json":
|
||||
f.WriteString("[\n")
|
||||
if _, err := f.WriteString("[\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
jsonEncoder = json.NewEncoder(f)
|
||||
jsonEncoder.SetIndent(" ", " ")
|
||||
case "md":
|
||||
fmt.Fprintf(f, "| %s |\n", strings.Join(columns, " | "))
|
||||
if _, err := fmt.Fprintf(f, "| %s |\n", strings.Join(columns, " | ")); err != nil {
|
||||
return err
|
||||
}
|
||||
seps := make([]string, len(columns))
|
||||
for i := range seps {
|
||||
seps[i] = "---"
|
||||
}
|
||||
fmt.Fprintf(f, "| %s |\n", strings.Join(seps, " | "))
|
||||
if _, err := fmt.Fprintf(f, "| %s |\n", strings.Join(seps, " | ")); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return connection.QueryResult{Success: false, Message: "Unsupported format: " + format}
|
||||
return fmt.Errorf("unsupported format: %s", format)
|
||||
}
|
||||
|
||||
for _, rowMap := range data {
|
||||
@@ -394,37 +716,51 @@ func (a *App) ExportData(data []map[string]interface{}, columns []string, defaul
|
||||
val := rowMap[col]
|
||||
if val == nil {
|
||||
record[i] = "NULL"
|
||||
} else {
|
||||
s := fmt.Sprintf("%v", val)
|
||||
if format == "md" {
|
||||
s = strings.ReplaceAll(s, "|", "\\|")
|
||||
s = strings.ReplaceAll(s, "\n", "<br>")
|
||||
}
|
||||
record[i] = s
|
||||
continue
|
||||
}
|
||||
|
||||
s := fmt.Sprintf("%v", val)
|
||||
if format == "md" {
|
||||
s = strings.ReplaceAll(s, "|", "\\|")
|
||||
s = strings.ReplaceAll(s, "\n", "<br>")
|
||||
}
|
||||
record[i] = s
|
||||
}
|
||||
|
||||
switch format {
|
||||
case "csv", "xlsx":
|
||||
if err := csvWriter.Write(record); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
|
||||
return err
|
||||
}
|
||||
case "json":
|
||||
if !isJsonFirstRow {
|
||||
f.WriteString(",\n")
|
||||
if _, err := f.WriteString(",\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := jsonEncoder.Encode(rowMap); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
|
||||
return err
|
||||
}
|
||||
isJsonFirstRow = false
|
||||
case "md":
|
||||
fmt.Fprintf(f, "| %s |\n", strings.Join(record, " | "))
|
||||
if _, err := fmt.Fprintf(f, "| %s |\n", strings.Join(record, " | ")); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if format == "csv" || format == "xlsx" {
|
||||
csvWriter.Flush()
|
||||
if err := csvWriter.Error(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if format == "json" {
|
||||
f.WriteString("\n]")
|
||||
if _, err := f.WriteString("\n]"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "Export successful"}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,11 +1,99 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/sync"
|
||||
|
||||
"github.com/wailsapp/wails/v2/pkg/runtime"
|
||||
)
|
||||
|
||||
// DataSync executes a data synchronization task
|
||||
func (a *App) DataSync(config sync.SyncConfig) sync.SyncResult {
|
||||
engine := sync.NewSyncEngine()
|
||||
return engine.RunSync(config)
|
||||
jobID := strings.TrimSpace(config.JobID)
|
||||
if jobID == "" {
|
||||
jobID = fmt.Sprintf("sync-%d", time.Now().UnixNano())
|
||||
config.JobID = jobID
|
||||
}
|
||||
|
||||
reporter := sync.Reporter{
|
||||
OnLog: func(event sync.SyncLogEvent) {
|
||||
runtime.EventsEmit(a.ctx, sync.EventSyncLog, event)
|
||||
},
|
||||
OnProgress: func(event sync.SyncProgressEvent) {
|
||||
runtime.EventsEmit(a.ctx, sync.EventSyncProgress, event)
|
||||
},
|
||||
}
|
||||
|
||||
runtime.EventsEmit(a.ctx, sync.EventSyncStart, map[string]any{
|
||||
"jobId": jobID,
|
||||
"total": len(config.Tables),
|
||||
})
|
||||
|
||||
engine := sync.NewSyncEngine(reporter)
|
||||
res := engine.RunSync(config)
|
||||
|
||||
runtime.EventsEmit(a.ctx, sync.EventSyncDone, map[string]any{
|
||||
"jobId": jobID,
|
||||
"result": res,
|
||||
})
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// DataSyncAnalyze analyzes differences between source and target for the given tables (dry-run).
|
||||
func (a *App) DataSyncAnalyze(config sync.SyncConfig) connection.QueryResult {
|
||||
jobID := strings.TrimSpace(config.JobID)
|
||||
if jobID == "" {
|
||||
jobID = fmt.Sprintf("analyze-%d", time.Now().UnixNano())
|
||||
config.JobID = jobID
|
||||
}
|
||||
|
||||
reporter := sync.Reporter{
|
||||
OnLog: func(event sync.SyncLogEvent) {
|
||||
runtime.EventsEmit(a.ctx, sync.EventSyncLog, event)
|
||||
},
|
||||
OnProgress: func(event sync.SyncProgressEvent) {
|
||||
runtime.EventsEmit(a.ctx, sync.EventSyncProgress, event)
|
||||
},
|
||||
}
|
||||
|
||||
runtime.EventsEmit(a.ctx, sync.EventSyncStart, map[string]any{
|
||||
"jobId": jobID,
|
||||
"total": len(config.Tables),
|
||||
"type": "analyze",
|
||||
})
|
||||
|
||||
engine := sync.NewSyncEngine(reporter)
|
||||
res := engine.Analyze(config)
|
||||
|
||||
runtime.EventsEmit(a.ctx, sync.EventSyncDone, map[string]any{
|
||||
"jobId": jobID,
|
||||
"result": res,
|
||||
"type": "analyze",
|
||||
})
|
||||
|
||||
if !res.Success {
|
||||
return connection.QueryResult{Success: false, Message: res.Message, Data: res}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Message: res.Message, Data: res}
|
||||
}
|
||||
|
||||
// DataSyncPreview returns a limited preview of diff rows for one table.
|
||||
func (a *App) DataSyncPreview(config sync.SyncConfig, tableName string, limit int) connection.QueryResult {
|
||||
jobID := strings.TrimSpace(config.JobID)
|
||||
if jobID == "" {
|
||||
jobID = fmt.Sprintf("preview-%d", time.Now().UnixNano())
|
||||
config.JobID = jobID
|
||||
}
|
||||
|
||||
engine := sync.NewSyncEngine(sync.Reporter{})
|
||||
preview, err := engine.Preview(config, tableName, limit)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Message: "OK", Data: preview}
|
||||
}
|
||||
|
||||
236
internal/app/sql_sanitize.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
func sanitizeSQLForPgLike(dbType string, query string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(dbType)) {
|
||||
case "postgres", "kingbase":
|
||||
// In some cases the quoting gets duplicated several levels deep (e.g. """"schema"""" or ""schema"""), so a single pass does not necessarily converge.
// Run a bounded number of iterations until the output no longer changes.
|
||||
out := query
|
||||
for i := 0; i < 3; i++ {
|
||||
fixed := fixBrokenDoubleDoubleQuotedIdent(out)
|
||||
if fixed == out {
|
||||
break
|
||||
}
|
||||
out = fixed
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return query
|
||||
}
|
||||
}
|
||||
|
||||
// fixBrokenDoubleDoubleQuotedIdent fixes accidental identifiers like:
|
||||
//
|
||||
// SELECT * FROM ""schema"".""table""
|
||||
//
|
||||
// which can be produced when a quoted identifier gets wrapped by quotes again.
|
||||
//
|
||||
// It is intentionally conservative:
|
||||
// - only runs outside strings/comments/dollar-quoted blocks
|
||||
// - does not touch valid escaped-quote sequences inside quoted identifiers (e.g. "a""b")
|
||||
func fixBrokenDoubleDoubleQuotedIdent(query string) string {
|
||||
if !strings.Contains(query, `""`) {
|
||||
return query
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(query))
|
||||
|
||||
inSingle := false
|
||||
inDoubleIdent := false
|
||||
inLineComment := false
|
||||
inBlockComment := false
|
||||
dollarTag := ""
|
||||
|
||||
for i := 0; i < len(query); i++ {
|
||||
ch := query[i]
|
||||
next := byte(0)
|
||||
if i+1 < len(query) {
|
||||
next = query[i+1]
|
||||
}
|
||||
|
||||
if inLineComment {
|
||||
b.WriteByte(ch)
|
||||
if ch == '\n' {
|
||||
inLineComment = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
if inBlockComment {
|
||||
b.WriteByte(ch)
|
||||
if ch == '*' && next == '/' {
|
||||
b.WriteByte('/')
|
||||
i++
|
||||
inBlockComment = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
if dollarTag != "" {
|
||||
if strings.HasPrefix(query[i:], dollarTag) {
|
||||
b.WriteString(dollarTag)
|
||||
i += len(dollarTag) - 1
|
||||
dollarTag = ""
|
||||
continue
|
||||
}
|
||||
b.WriteByte(ch)
|
||||
continue
|
||||
}
|
||||
if inSingle {
|
||||
b.WriteByte(ch)
|
||||
if ch == '\'' {
|
||||
// escaped single quote
|
||||
if next == '\'' {
|
||||
b.WriteByte('\'')
|
||||
i++
|
||||
continue
|
||||
}
|
||||
inSingle = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
if inDoubleIdent {
|
||||
b.WriteByte(ch)
|
||||
if ch == '"' {
|
||||
// escaped quote inside identifier
|
||||
if next == '"' {
|
||||
b.WriteByte('"')
|
||||
i++
|
||||
continue
|
||||
}
|
||||
inDoubleIdent = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// --- Outside of all string/comment blocks ---
|
||||
if ch == '-' && next == '-' {
|
||||
b.WriteByte(ch)
|
||||
b.WriteByte('-')
|
||||
i++
|
||||
inLineComment = true
|
||||
continue
|
||||
}
|
||||
if ch == '/' && next == '*' {
|
||||
b.WriteByte(ch)
|
||||
b.WriteByte('*')
|
||||
i++
|
||||
inBlockComment = true
|
||||
continue
|
||||
}
|
||||
if ch == '\'' {
|
||||
b.WriteByte(ch)
|
||||
inSingle = true
|
||||
continue
|
||||
}
|
||||
if ch == '$' {
|
||||
if tag := parseDollarTag(query[i:]); tag != "" {
|
||||
b.WriteString(tag)
|
||||
i += len(tag) - 1
|
||||
dollarTag = tag
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if ch == '"' {
|
||||
// Fix: ""ident"" -> "ident" (only when it looks like a plain identifier)
|
||||
// Also handle variants like ""ident""" / """"ident"""" (extra quotes at either side).
|
||||
if next == '"' {
|
||||
if replacement, advance, ok := tryFixDoubleDoubleQuotedIdent(query, i); ok {
|
||||
b.WriteString(replacement)
|
||||
i = advance - 1
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
b.WriteByte(ch)
|
||||
inDoubleIdent = true
|
||||
continue
|
||||
}
|
||||
|
||||
b.WriteByte(ch)
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func tryFixDoubleDoubleQuotedIdent(query string, start int) (replacement string, advance int, ok bool) {
|
||||
// start points at the first quote of a broken identifier, usually like:
|
||||
// ""ident"" / ""ident""" / """"ident""""
|
||||
if start < 0 || start+1 >= len(query) {
|
||||
return "", 0, false
|
||||
}
|
||||
if query[start] != '"' || query[start+1] != '"' {
|
||||
return "", 0, false
|
||||
}
|
||||
if start > 0 && query[start-1] == '"' {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
runLen := 0
|
||||
for start+runLen < len(query) && query[start+runLen] == '"' {
|
||||
runLen++
|
||||
}
|
||||
if runLen < 2 || runLen%2 == 1 {
|
||||
// Odd run (e.g. """...) can be a valid quoted identifier with escaped quotes.
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
contentStart := start + runLen
|
||||
j := contentStart
|
||||
for j < len(query) {
|
||||
if query[j] == '"' {
|
||||
endRunLen := 0
|
||||
for j+endRunLen < len(query) && query[j+endRunLen] == '"' {
|
||||
endRunLen++
|
||||
}
|
||||
if endRunLen >= 2 {
|
||||
content := strings.TrimSpace(query[contentStart:j])
|
||||
if looksLikeIdentifierContent(content) {
|
||||
return `"` + content + `"`, j + endRunLen, true
|
||||
}
|
||||
return "", 0, false
|
||||
}
|
||||
}
|
||||
// Fast abort: identifier-like content should not span lines.
|
||||
if query[j] == '\n' || query[j] == '\r' {
|
||||
break
|
||||
}
|
||||
j++
|
||||
}
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func looksLikeIdentifierContent(s string) bool {
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return false
|
||||
}
|
||||
for _, r := range s {
|
||||
if r == '_' || r == '$' || r == '-' || unicode.IsLetter(r) || unicode.IsDigit(r) {
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func parseDollarTag(s string) string {
|
||||
// Match: $tag$ where tag is [A-Za-z0-9_]* (can be empty => $$)
|
||||
if len(s) < 2 || s[0] != '$' {
|
||||
return ""
|
||||
}
|
||||
for i := 1; i < len(s); i++ {
|
||||
c := s[i]
|
||||
if c == '$' {
|
||||
return s[:i+1]
|
||||
}
|
||||
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
55
internal/app/sql_sanitize_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package app
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSanitizeSQLForPgLike_FixesBrokenDoubleDoubleQuotes(t *testing.T) {
|
||||
in := `SELECT * FROM ""ldf_server"".""t_user"" LIMIT 1`
|
||||
out := sanitizeSQLForPgLike("kingbase", in)
|
||||
want := `SELECT * FROM "ldf_server"."t_user" LIMIT 1`
|
||||
if out != want {
|
||||
t.Fatalf("unexpected sanitize output:\nIN: %s\nOUT: %s\nWANT: %s", in, out, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeSQLForPgLike_FixesBrokenDoubleDoubleQuotes_WithExtraQuotes(t *testing.T) {
|
||||
in := `SELECT * FROM ""ldf_server""".""t_user"" LIMIT 1`
|
||||
out := sanitizeSQLForPgLike("kingbase", in)
|
||||
want := `SELECT * FROM "ldf_server"."t_user" LIMIT 1`
|
||||
if out != want {
|
||||
t.Fatalf("unexpected sanitize output:\nIN: %s\nOUT: %s\nWANT: %s", in, out, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeSQLForPgLike_FixesBrokenDoubleDoubleQuotes_WithQuadQuotes(t *testing.T) {
|
||||
in := `SELECT * FROM """"ldf_server"""".""t_user"" LIMIT 1`
|
||||
out := sanitizeSQLForPgLike("kingbase", in)
|
||||
want := `SELECT * FROM "ldf_server"."t_user" LIMIT 1`
|
||||
if out != want {
|
||||
t.Fatalf("unexpected sanitize output:\nIN: %s\nOUT: %s\nWANT: %s", in, out, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeSQLForPgLike_DoesNotTouchEscapedQuotesInsideIdentifier(t *testing.T) {
|
||||
in := `SELECT "a""b" FROM "t""x"`
|
||||
out := sanitizeSQLForPgLike("postgres", in)
|
||||
if out != in {
|
||||
t.Fatalf("should keep valid escaped quotes inside identifier:\nIN: %s\nOUT: %s", in, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeSQLForPgLike_DoesNotTouchDollarQuotedStrings(t *testing.T) {
|
||||
in := "SELECT $$\"\"ldf_server\"\"$$, \"\"ldf_server\"\""
|
||||
out := sanitizeSQLForPgLike("postgres", in)
|
||||
want := "SELECT $$\"\"ldf_server\"\"$$, \"ldf_server\""
|
||||
if out != want {
|
||||
t.Fatalf("unexpected sanitize output for dollar quoted string:\nIN: %s\nOUT: %s\nWANT: %s", in, out, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeSQLForPgLike_DoesNotModifyOtherDBTypes(t *testing.T) {
|
||||
in := `SELECT * FROM ""ldf_server""`
|
||||
out := sanitizeSQLForPgLike("mysql", in)
|
||||
if out != in {
|
||||
t.Fatalf("non-PG-like db should not be sanitized:\nIN: %s\nOUT: %s", in, out)
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
@@ -57,6 +58,20 @@ func (c *CustomDB) Ping() error {
|
||||
return c.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (c *CustomDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if c.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := c.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (c *CustomDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if c.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
@@ -67,45 +82,18 @@ func (c *CustomDB) Query(query string) ([]map[string]interface{}, []string, erro
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
columns, err := rows.Columns()
|
||||
func (c *CustomDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if c.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
}
|
||||
res, err := c.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var resultData []map[string]interface{}
|
||||
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range columns {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
entry := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
var v interface{}
|
||||
val := values[i]
|
||||
b, ok := val.([]byte)
|
||||
if ok {
|
||||
if b == nil {
|
||||
v = nil
|
||||
} else {
|
||||
v = string(b)
|
||||
}
|
||||
} else {
|
||||
v = val
|
||||
}
|
||||
entry[col] = v
|
||||
}
|
||||
resultData = append(resultData, entry)
|
||||
}
|
||||
|
||||
return resultData, columns, nil
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (c *CustomDB) Exec(query string) (int64, error) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
|
||||
@@ -19,6 +21,7 @@ import (
|
||||
type DamengDB struct {
|
||||
conn *sql.DB
|
||||
pingTimeout time.Duration
|
||||
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
|
||||
}
|
||||
|
||||
func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
|
||||
@@ -26,16 +29,6 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
|
||||
// or dm://user:password@host:port
|
||||
|
||||
address := net.JoinHostPort(config.Host, strconv.Itoa(config.Port))
|
||||
if config.UseSSH {
|
||||
// SSH logic similar to others, assumes port forwarding
|
||||
_, err := ssh.RegisterSSHNetwork(config.SSH)
|
||||
if err == nil {
|
||||
// DM driver likely uses standard net.Dial, so we might need a local listener
|
||||
// or assume port forwarding is handled externally or implicitly via "tcp" override if driver allows.
|
||||
// Similar to Oracle, we skip complex custom dialer injection for now.
|
||||
}
|
||||
}
|
||||
|
||||
escapedPassword := url.PathEscape(config.Password)
|
||||
q := url.Values{}
|
||||
if config.Database != "" {
|
||||
@@ -55,7 +48,42 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
|
||||
func (d *DamengDB) Connect(config connection.ConnectionConfig) error {
|
||||
dsn := d.getDSN(config)
|
||||
var dsn string
|
||||
var err error
|
||||
|
||||
if config.UseSSH {
|
||||
// Create SSH tunnel with local port forwarding
|
||||
logger.Infof("达梦数据库使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
|
||||
|
||||
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||||
}
|
||||
d.forwarder = forwarder
|
||||
|
||||
// Parse local address
|
||||
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地转发地址失败:%w", err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地端口失败:%w", err)
|
||||
}
|
||||
|
||||
// Create a modified config pointing to local forwarder
|
||||
localConfig := config
|
||||
localConfig.Host = host
|
||||
localConfig.Port = port
|
||||
localConfig.UseSSH = false
|
||||
|
||||
dsn = d.getDSN(localConfig)
|
||||
logger.Infof("达梦数据库通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||||
} else {
|
||||
dsn = d.getDSN(config)
|
||||
}
|
||||
|
||||
db, err := sql.Open("dm", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
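The Connect paths above drop the old custom-network registration and instead open a local listener that is forwarded to the remote database through SSH, then point the DSN at that local address. Roughly what a helper like ssh.GetOrCreateLocalForwarder has to do under the hood; this is a sketch using golang.org/x/crypto/ssh, not the project's actual implementation — only the LocalAddr field and Close method are visible in the diff, everything else here is an assumption:

package ssh

import (
	"io"
	"net"

	gossh "golang.org/x/crypto/ssh"
)

// LocalForwarder listens on 127.0.0.1:0 and pipes each accepted
// connection to remoteAddr through an established SSH client.
type LocalForwarder struct {
	LocalAddr string
	listener  net.Listener
}

// newLocalForwarder is a hypothetical constructor; the real package reuses
// forwarders per SSH config, which is omitted here.
func newLocalForwarder(client *gossh.Client, remoteAddr string) (*LocalForwarder, error) {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return nil, err
	}
	f := &LocalForwarder{LocalAddr: ln.Addr().String(), listener: ln}
	go func() {
		for {
			local, err := ln.Accept()
			if err != nil {
				return // listener closed
			}
			go func() {
				remote, err := client.Dial("tcp", remoteAddr)
				if err != nil {
					local.Close()
					return
				}
				// Copy both directions until either side closes.
				go io.Copy(remote, local)
				io.Copy(local, remote)
				remote.Close()
				local.Close()
			}()
		}
	}()
	return f, nil
}

func (f *LocalForwarder) Close() error { return f.listener.Close() }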
@@ -69,6 +97,15 @@ func (d *DamengDB) Connect(config connection.ConnectionConfig) error {
|
||||
}
|
||||
|
||||
func (d *DamengDB) Close() error {
|
||||
// Close SSH forwarder first if exists
|
||||
if d.forwarder != nil {
|
||||
if err := d.forwarder.Close(); err != nil {
|
||||
logger.Warnf("关闭达梦数据库 SSH 端口转发失败:%v", err)
|
||||
}
|
||||
d.forwarder = nil
|
||||
}
|
||||
|
||||
// Then close database connection
|
||||
if d.conn != nil {
|
||||
return d.conn.Close()
|
||||
}
|
||||
@@ -88,6 +125,20 @@ func (d *DamengDB) Ping() error {
|
||||
return d.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (d *DamengDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if d.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := d.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (d *DamengDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if d.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
@@ -98,45 +149,18 @@ func (d *DamengDB) Query(query string) ([]map[string]interface{}, []string, erro
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
return scanRows(rows)
}

func (d *DamengDB) ExecContext(ctx context.Context, query string) (int64, error) {
if d.conn == nil {
return 0, fmt.Errorf("connection not open")
}
res, err := d.conn.ExecContext(ctx, query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}

func (d *DamengDB) Exec(query string) (int64, error) {
@@ -1,12 +1,16 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
|
||||
@@ -16,6 +20,7 @@ import (
|
||||
type KingbaseDB struct {
|
||||
conn *sql.DB
|
||||
pingTimeout time.Duration
|
||||
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
|
||||
}
|
||||
|
||||
func quoteConnValue(v string) string {
|
||||
@@ -57,20 +62,6 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
|
||||
address := config.Host
|
||||
port := config.Port
|
||||
|
||||
if config.UseSSH {
|
||||
netName, err := ssh.RegisterSSHNetwork(config.SSH)
|
||||
if err == nil {
|
||||
// Kingbase/Postgres lib/pq allows custom dialer via "host" if using unix socket,
|
||||
// but for custom network it's harder.
|
||||
// Ideally we use a local forwarder.
|
||||
// For now, we assume standard TCP or handle SSH externally.
|
||||
// If we implement the net.Dial override for "kingbase" driver (which might use lib/pq internally),
|
||||
// we might need to check if it supports "cloudsql" style or similar custom dialers.
|
||||
// Similar to others, skipping SSH deep integration here for now.
|
||||
_ = netName
|
||||
}
|
||||
}
|
||||
|
||||
// Construct DSN
|
||||
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable connect_timeout=%d",
|
||||
quoteConnValue(address),
|
||||
@@ -85,7 +76,42 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
|
||||
var dsn string
|
||||
var err error
|
||||
|
||||
if config.UseSSH {
|
||||
// Create SSH tunnel with local port forwarding
|
||||
logger.Infof("人大金仓使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
|
||||
|
||||
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||||
}
|
||||
k.forwarder = forwarder
|
||||
|
||||
// Parse local address
|
||||
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地转发地址失败:%w", err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地端口失败:%w", err)
|
||||
}
|
||||
|
||||
// Create a modified config pointing to local forwarder
|
||||
localConfig := config
|
||||
localConfig.Host = host
|
||||
localConfig.Port = port
|
||||
localConfig.UseSSH = false
|
||||
|
||||
dsn = k.getDSN(localConfig)
|
||||
logger.Infof("人大金仓通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||||
} else {
|
||||
dsn = k.getDSN(config)
|
||||
}
|
||||
|
||||
// Open using "kingbase" driver
|
||||
db, err := sql.Open("kingbase", dsn)
|
||||
if err != nil {
|
||||
@@ -100,6 +126,15 @@ func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) Close() error {
|
||||
// Close SSH forwarder first if exists
|
||||
if k.forwarder != nil {
|
||||
if err := k.forwarder.Close(); err != nil {
|
||||
logger.Warnf("关闭人大金仓 SSH 端口转发失败:%v", err)
|
||||
}
|
||||
k.forwarder = nil
|
||||
}
|
||||
|
||||
// Then close database connection
|
||||
if k.conn != nil {
|
||||
return k.conn.Close()
|
||||
}
|
||||
@@ -119,6 +154,20 @@ func (k *KingbaseDB) Ping() error {
|
||||
return k.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if k.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := k.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if k.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
@@ -129,45 +178,18 @@ func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, er
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
return scanRows(rows)
}

func (k *KingbaseDB) ExecContext(ctx context.Context, query string) (int64, error) {
if k.conn == nil {
return 0, fmt.Errorf("connection not open")
}
res, err := k.conn.ExecContext(ctx, query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}

func (k *KingbaseDB) Exec(query string) (int64, error) {
@@ -235,15 +257,84 @@ func (k *KingbaseDB) GetCreateStatement(dbName, tableName string) (string, error
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
schema := "public"
|
||||
if dbName != "" {
|
||||
schema = dbName
|
||||
// 解析 schema.table 格式
|
||||
schema := strings.TrimSpace(dbName)
|
||||
table := strings.TrimSpace(tableName)
|
||||
|
||||
// 如果 tableName 包含 schema (格式: schema.table)
|
||||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||||
parsedSchema := strings.TrimSpace(parts[0])
|
||||
parsedTable := strings.TrimSpace(parts[1])
|
||||
if parsedSchema != "" && parsedTable != "" {
|
||||
schema = parsedSchema
|
||||
table = parsedTable
|
||||
}
|
||||
}
|
||||
|
||||
// 如果仍然没有 schema,使用 current_schema()
|
||||
// 这样可以自动匹配当前连接的 search_path
|
||||
if schema == "" {
|
||||
return k.getColumnsWithCurrentSchema(table)
|
||||
}
|
||||
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
// 转义函数:处理单引号,移除双引号
|
||||
esc := func(s string) string {
|
||||
// 移除前后的双引号(如果存在)
|
||||
s = strings.Trim(s, "\"")
|
||||
// 转义单引号
|
||||
return strings.ReplaceAll(s, "'", "''")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = '%s' AND table_name = '%s'
|
||||
ORDER BY ordinal_position`, esc(schema), esc(table))
|
||||
|
||||
data, _, err := k.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var columns []connection.ColumnDefinition
|
||||
for _, row := range data {
|
||||
col := connection.ColumnDefinition{
|
||||
Name: fmt.Sprintf("%v", row["column_name"]),
|
||||
Type: fmt.Sprintf("%v", row["data_type"]),
|
||||
Nullable: fmt.Sprintf("%v", row["is_nullable"]),
|
||||
}
|
||||
|
||||
if row["column_default"] != nil {
|
||||
def := fmt.Sprintf("%v", row["column_default"])
|
||||
col.Default = &def
|
||||
}
|
||||
|
||||
columns = append(columns, col)
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
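GetColumns, GetIndexes, GetForeignKeys and GetTriggers now all repeat the same two steps: split an optional schema prefix off the table name, and escape both parts before interpolating them into information_schema queries. A condensed sketch of that shared pattern; splitSchemaTable and escapeLiteral are illustrative names, the diff inlines this logic in each method:

package db

import "strings"

// splitSchemaTable splits "schema.table" into its parts; either may be empty.
func splitSchemaTable(dbName, tableName string) (schema, table string) {
	schema = strings.TrimSpace(dbName)
	table = strings.TrimSpace(tableName)
	if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
		s, t := strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
		if s != "" && t != "" {
			schema, table = s, t
		}
	}
	return schema, table
}

// escapeLiteral strips surrounding double quotes and doubles single quotes,
// matching the esc helpers used inside the Kingbase methods.
func escapeLiteral(s string) string {
	s = strings.Trim(s, "\"")
	return strings.ReplaceAll(s, "'", "''")
}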
// getColumnsWithCurrentSchema 使用 current_schema() 查询当前schema的表
|
||||
func (k *KingbaseDB) getColumnsWithCurrentSchema(tableName string) ([]connection.ColumnDefinition, error) {
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
// 转义函数
|
||||
esc := func(s string) string {
|
||||
s = strings.Trim(s, "\"")
|
||||
return strings.ReplaceAll(s, "'", "''")
|
||||
}
|
||||
|
||||
// 使用 current_schema() 获取当前schema
|
||||
query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = current_schema() AND table_name = '%s'
|
||||
ORDER BY ordinal_position`, esc(table))
|
||||
|
||||
data, _, err := k.Query(query)
|
||||
if err != nil {
|
||||
@@ -269,32 +360,76 @@ func (k *KingbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDe
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||||
// 解析 schema.table 格式
|
||||
schema := strings.TrimSpace(dbName)
|
||||
table := strings.TrimSpace(tableName)
|
||||
|
||||
if dbName != "" {
|
||||
// Update query to use dbName as schema
|
||||
query = strings.Replace(query, "'public'", fmt.Sprintf("'%s'", dbName), 1)
|
||||
// 如果 tableName 包含 schema (格式: schema.table)
|
||||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||||
parsedSchema := strings.TrimSpace(parts[0])
|
||||
parsedTable := strings.TrimSpace(parts[1])
|
||||
if parsedSchema != "" && parsedTable != "" {
|
||||
schema = parsedSchema
|
||||
table = parsedTable
|
||||
}
|
||||
}
|
||||
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
// 转义函数:处理单引号,移除双引号
|
||||
esc := func(s string) string {
|
||||
s = strings.Trim(s, "\"")
|
||||
return strings.ReplaceAll(s, "'", "''")
|
||||
}
|
||||
|
||||
// 构建查询:如果没有指定schema,使用current_schema()
|
||||
var query string
|
||||
if schema != "" {
|
||||
query = fmt.Sprintf(`
|
||||
SELECT
|
||||
i.relname as index_name,
|
||||
a.attname as column_name,
|
||||
ix.indisunique as is_unique
|
||||
FROM
|
||||
pg_class t,
|
||||
pg_class i,
|
||||
pg_index ix,
|
||||
pg_attribute a,
|
||||
pg_namespace n
|
||||
WHERE
|
||||
t.oid = ix.indrelid
|
||||
AND i.oid = ix.indexrelid
|
||||
AND a.attrelid = t.oid
|
||||
AND a.attnum = ANY(ix.indkey)
|
||||
AND t.relkind = 'r'
|
||||
AND t.relname = '%s'
|
||||
AND n.oid = t.relnamespace
|
||||
AND n.nspname = '%s'
|
||||
`, esc(table), esc(schema))
|
||||
} else {
|
||||
query = fmt.Sprintf(`
|
||||
SELECT
|
||||
i.relname as index_name,
|
||||
a.attname as column_name,
|
||||
ix.indisunique as is_unique
|
||||
FROM
|
||||
pg_class t,
|
||||
pg_class i,
|
||||
pg_index ix,
|
||||
pg_attribute a,
|
||||
pg_namespace n
|
||||
WHERE
|
||||
t.oid = ix.indrelid
|
||||
AND i.oid = ix.indexrelid
|
||||
AND a.attrelid = t.oid
|
||||
AND a.attnum = ANY(ix.indkey)
|
||||
AND t.relkind = 'r'
|
||||
AND t.relname = '%s'
|
||||
AND n.oid = t.relnamespace
|
||||
AND n.nspname = current_schema()
|
||||
`, esc(table))
|
||||
}
|
||||
|
||||
data, _, err := k.Query(query)
|
||||
@@ -323,27 +458,67 @@ func (k *KingbaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDef
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
|
||||
schema := "public"
|
||||
if dbName != "" {
|
||||
schema = dbName
|
||||
// 解析 schema.table 格式
|
||||
schema := strings.TrimSpace(dbName)
|
||||
table := strings.TrimSpace(tableName)
|
||||
|
||||
// 如果 tableName 包含 schema (格式: schema.table)
|
||||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||||
parsedSchema := strings.TrimSpace(parts[0])
|
||||
parsedTable := strings.TrimSpace(parts[1])
|
||||
if parsedSchema != "" && parsedTable != "" {
|
||||
schema = parsedSchema
|
||||
table = parsedTable
|
||||
}
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
tc.constraint_name,
|
||||
kcu.column_name,
|
||||
ccu.table_name AS foreign_table_name,
|
||||
ccu.column_name AS foreign_column_name
|
||||
FROM
|
||||
information_schema.table_constraints AS tc
|
||||
JOIN information_schema.key_column_usage AS kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
JOIN information_schema.constraint_column_usage AS ccu
|
||||
ON ccu.constraint_name = tc.constraint_name
|
||||
AND ccu.table_schema = tc.table_schema
|
||||
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema='%s'`,
|
||||
tableName, schema)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
// 转义函数:处理单引号,移除双引号
|
||||
esc := func(s string) string {
|
||||
s = strings.Trim(s, "\"")
|
||||
return strings.ReplaceAll(s, "'", "''")
|
||||
}
|
||||
|
||||
// 构建查询:如果没有指定schema,使用current_schema()
|
||||
var query string
|
||||
if schema != "" {
|
||||
query = fmt.Sprintf(`
|
||||
SELECT
|
||||
tc.constraint_name,
|
||||
kcu.column_name,
|
||||
ccu.table_name AS foreign_table_name,
|
||||
ccu.column_name AS foreign_column_name
|
||||
FROM
|
||||
information_schema.table_constraints AS tc
|
||||
JOIN information_schema.key_column_usage AS kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
JOIN information_schema.constraint_column_usage AS ccu
|
||||
ON ccu.constraint_name = tc.constraint_name
|
||||
AND ccu.table_schema = tc.table_schema
|
||||
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema='%s'`,
|
||||
esc(table), esc(schema))
|
||||
} else {
|
||||
query = fmt.Sprintf(`
|
||||
SELECT
|
||||
tc.constraint_name,
|
||||
kcu.column_name,
|
||||
ccu.table_name AS foreign_table_name,
|
||||
ccu.column_name AS foreign_column_name
|
||||
FROM
|
||||
information_schema.table_constraints AS tc
|
||||
JOIN information_schema.key_column_usage AS kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
JOIN information_schema.constraint_column_usage AS ccu
|
||||
ON ccu.constraint_name = tc.constraint_name
|
||||
AND ccu.table_schema = tc.table_schema
|
||||
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema=current_schema()`,
|
||||
esc(table))
|
||||
}
|
||||
|
||||
data, _, err := k.Query(query)
|
||||
if err != nil {
|
||||
@@ -365,9 +540,43 @@ func (k *KingbaseDB) GetForeignKeys(dbName, tableName string) ([]connection.Fore
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||||
// 解析 schema.table 格式
|
||||
schema := strings.TrimSpace(dbName)
|
||||
table := strings.TrimSpace(tableName)
|
||||
|
||||
// 如果 tableName 包含 schema (格式: schema.table)
|
||||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||||
parsedSchema := strings.TrimSpace(parts[0])
|
||||
parsedTable := strings.TrimSpace(parts[1])
|
||||
if parsedSchema != "" && parsedTable != "" {
|
||||
schema = parsedSchema
|
||||
table = parsedTable
|
||||
}
|
||||
}
|
||||
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
// 转义函数:处理单引号,移除双引号
|
||||
esc := func(s string) string {
|
||||
s = strings.Trim(s, "\"")
|
||||
return strings.ReplaceAll(s, "'", "''")
|
||||
}
|
||||
|
||||
// 构建查询:如果指定了schema,也加上schema条件
|
||||
var query string
|
||||
if schema != "" {
|
||||
query = fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation
|
||||
FROM information_schema.triggers
|
||||
WHERE event_object_table = '%s' AND event_object_schema = '%s'`,
|
||||
esc(table), esc(schema))
|
||||
} else {
|
||||
query = fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation
|
||||
FROM information_schema.triggers
|
||||
WHERE event_object_table = '%s' AND event_object_schema = current_schema()`,
|
||||
esc(table))
|
||||
}
|
||||
|
||||
data, _, err := k.Query(query)
|
||||
if err != nil {
|
||||
@@ -392,14 +601,13 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||||
schema := "public"
|
||||
if dbName != "" {
|
||||
schema = dbName
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`SELECT table_name, column_name, data_type
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = '%s'`, schema)
|
||||
// dbName 在本项目语义里是“数据库”,schema 由 table_schema 决定;这里返回全部用户 schema 的列用于查询提示。
|
||||
query := `
|
||||
SELECT table_schema, table_name, column_name, data_type
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
|
||||
AND table_schema NOT LIKE 'pg_%'
|
||||
ORDER BY table_schema, table_name, ordinal_position`
|
||||
|
||||
data, _, err := k.Query(query)
|
||||
if err != nil {
|
||||
@@ -408,8 +616,14 @@ func (k *KingbaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinition
|
||||
|
||||
var cols []connection.ColumnDefinitionWithTable
|
||||
for _, row := range data {
|
||||
schema := fmt.Sprintf("%v", row["table_schema"])
|
||||
table := fmt.Sprintf("%v", row["table_name"])
|
||||
tableName := table
|
||||
if strings.TrimSpace(schema) != "" {
|
||||
tableName = fmt.Sprintf("%s.%s", schema, table)
|
||||
}
|
||||
col := connection.ColumnDefinitionWithTable{
|
||||
TableName: tableName,
|
||||
Name: fmt.Sprintf("%v", row["column_name"]),
|
||||
Type: fmt.Sprintf("%v", row["data_type"]),
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
@@ -48,7 +49,7 @@ func (m *MySQLDB) Connect(config connection.ConnectionConfig) error {
|
||||
}
|
||||
m.conn = db
|
||||
m.pingTimeout = getConnectTimeout(config)
|
||||
|
||||
|
||||
// Force verification
|
||||
if err := m.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
@@ -76,6 +77,20 @@ func (m *MySQLDB) Ping() error {
|
||||
return m.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (m *MySQLDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if m.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := m.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (m *MySQLDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if m.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
@@ -86,45 +101,18 @@ func (m *MySQLDB) Query(query string) ([]map[string]interface{}, []string, error
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
return scanRows(rows)
}

func (m *MySQLDB) ExecContext(ctx context.Context, query string) (int64, error) {
if m.conn == nil {
return 0, fmt.Errorf("connection not open")
}
res, err := m.conn.ExecContext(ctx, query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}

func (m *MySQLDB) Exec(query string) (int64, error) {
@@ -159,12 +147,12 @@ func (m *MySQLDB) GetTables(dbName string) ([]string, error) {
|
||||
if dbName != "" {
|
||||
query = fmt.Sprintf("SHOW TABLES FROM `%s`", dbName)
|
||||
}
|
||||
|
||||
|
||||
data, _, err := m.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
var tables []string
|
||||
for _, row := range data {
|
||||
for _, v := range row {
|
||||
@@ -185,7 +173,7 @@ func (m *MySQLDB) GetCreateStatement(dbName, tableName string) (string, error) {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
|
||||
if len(data) > 0 {
|
||||
if val, ok := data[0]["Create Table"]; ok {
|
||||
return fmt.Sprintf("%v", val), nil
|
||||
@@ -215,12 +203,12 @@ func (m *MySQLDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefin
|
||||
Extra: fmt.Sprintf("%v", row["Extra"]),
|
||||
Comment: fmt.Sprintf("%v", row["Comment"]),
|
||||
}
|
||||
|
||||
|
||||
if row["Default"] != nil {
|
||||
d := fmt.Sprintf("%v", row["Default"])
|
||||
col.Default = &d
|
||||
}
|
||||
|
||||
|
||||
columns = append(columns, col)
|
||||
}
|
||||
return columns, nil
|
||||
@@ -248,14 +236,14 @@ func (m *MySQLDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefini
|
||||
}
|
||||
}
|
||||
|
||||
seq := 0
if val, ok := row["Seq_in_index"]; ok {
if f, ok := val.(float64); ok {
seq = int(f)
} else if i, ok := val.(int64); ok {
seq = int(i)
}
}
|
||||
|
||||
idx := connection.IndexDefinition{
|
||||
Name: fmt.Sprintf("%v", row["Key_name"]),
|
||||
@@ -345,12 +333,12 @@ func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
|
||||
for _, update := range changes.Updates {
|
||||
var sets []string
|
||||
var args []interface{}
|
||||
|
||||
|
||||
for k, v := range update.Values {
|
||||
sets = append(sets, fmt.Sprintf("`%s` = ?", k))
|
||||
args = append(args, v)
|
||||
}
|
||||
|
||||
|
||||
if len(sets) == 0 {
|
||||
continue
|
||||
}
|
||||
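ApplyChanges builds parameterized UPDATE statements from the change set, backtick-quoting identifiers and pushing every value through ? placeholders. A self-contained sketch of the same construction; buildUpdate and its arguments are illustrative, the real method reads update.Values plus the row's key columns:

package db

import (
	"fmt"
	"strings"
)

// buildUpdate assembles "UPDATE `t` SET `a` = ?, ... WHERE `id` = ? ..." plus its args.
func buildUpdate(table string, values, keys map[string]interface{}) (string, []interface{}, error) {
	var sets, wheres []string
	var args []interface{}
	for k, v := range values {
		sets = append(sets, fmt.Sprintf("`%s` = ?", k))
		args = append(args, v)
	}
	if len(sets) == 0 {
		return "", nil, fmt.Errorf("update requires values")
	}
	for k, v := range keys {
		wheres = append(wheres, fmt.Sprintf("`%s` = ?", k))
		args = append(args, v)
	}
	if len(wheres) == 0 {
		return "", nil, fmt.Errorf("update requires keys")
	}
	query := fmt.Sprintf("UPDATE `%s` SET %s WHERE %s",
		table, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
	return query, args, nil
}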
@@ -360,7 +348,7 @@ func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
|
||||
wheres = append(wheres, fmt.Sprintf("`%s` = ?", k))
|
||||
args = append(args, v)
|
||||
}
|
||||
|
||||
|
||||
if len(wheres) == 0 {
|
||||
return fmt.Errorf("update requires keys")
|
||||
}
|
||||
@@ -376,13 +364,13 @@ func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) e
|
||||
var cols []string
|
||||
var placeholders []string
|
||||
var args []interface{}
|
||||
|
||||
|
||||
for k, v := range row {
|
||||
cols = append(cols, fmt.Sprintf("`%s`", k))
|
||||
placeholders = append(placeholders, "?")
|
||||
args = append(args, v)
|
||||
}
|
||||
|
||||
|
||||
if len(cols) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
|
||||
@@ -19,6 +21,7 @@ import (
|
||||
type OracleDB struct {
|
||||
conn *sql.DB
|
||||
pingTimeout time.Duration
|
||||
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
|
||||
}
|
||||
|
||||
func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
|
||||
@@ -28,28 +31,6 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
|
||||
database = config.User // Default to user service/schema if empty?
|
||||
}
|
||||
|
||||
if config.UseSSH {
|
||||
_, err := ssh.RegisterSSHNetwork(config.SSH)
|
||||
if err == nil {
|
||||
// Oracle driver might not support custom dialer via DSN easily without extra config
|
||||
// But go-ora v2 supports some advanced options.
|
||||
// For simplicity, we assume standard TCP or we might need a workaround for SSH.
|
||||
// go-ora v2 is pure Go, so we can potentially use a custom dialer if we manually open.
|
||||
// But for now, let's just use the address.
|
||||
// SSH tunneling via net.Dialer override is complex in sql.Open("oracle", ...).
|
||||
// We might need to forward a local port if using SSH.
|
||||
// Since ssh.RegisterSSHNetwork creates a custom network "ssh-via-...",
|
||||
// we need to see if go-ora supports custom networks.
|
||||
// Checking go-ora docs (simulated): It supports "unix" and "tcp".
|
||||
// We might need to map the custom network to a local proxy.
|
||||
// For now, we will assume direct connection or handle SSH separately later.
|
||||
// We'll leave the protocol implementation as is in MySQL for now, hoping go-ora uses standard net.Dial.
|
||||
// Note: go-ora connection string: oracle://user:pass@host:port/service
|
||||
// It parses host/port. It doesn't easily take a custom "network" parameter in URL.
|
||||
// We will proceed with standard TCP string.
|
||||
}
|
||||
}
|
||||
|
||||
u := &url.URL{
|
||||
Scheme: "oracle",
|
||||
Host: net.JoinHostPort(config.Host, strconv.Itoa(config.Port)),
|
||||
@@ -61,7 +42,42 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
|
||||
func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
|
||||
var dsn string
|
||||
var err error
|
||||
|
||||
if config.UseSSH {
|
||||
// Create SSH tunnel with local port forwarding
|
||||
logger.Infof("Oracle 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
|
||||
|
||||
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||||
}
|
||||
o.forwarder = forwarder
|
||||
|
||||
// Parse local address
|
||||
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地转发地址失败:%w", err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地端口失败:%w", err)
|
||||
}
|
||||
|
||||
// Create a modified config pointing to local forwarder
|
||||
localConfig := config
|
||||
localConfig.Host = host
|
||||
localConfig.Port = port
|
||||
localConfig.UseSSH = false
|
||||
|
||||
dsn = o.getDSN(localConfig)
|
||||
logger.Infof("Oracle 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||||
} else {
|
||||
dsn = o.getDSN(config)
|
||||
}
|
||||
|
||||
db, err := sql.Open("oracle", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
@@ -75,6 +91,15 @@ func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
|
||||
}
|
||||
|
||||
func (o *OracleDB) Close() error {
|
||||
// Close SSH forwarder first if exists
|
||||
if o.forwarder != nil {
|
||||
if err := o.forwarder.Close(); err != nil {
|
||||
logger.Warnf("关闭 Oracle SSH 端口转发失败:%v", err)
|
||||
}
|
||||
o.forwarder = nil
|
||||
}
|
||||
|
||||
// Then close database connection
|
||||
if o.conn != nil {
|
||||
return o.conn.Close()
|
||||
}
|
||||
@@ -94,6 +119,20 @@ func (o *OracleDB) Ping() error {
|
||||
return o.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (o *OracleDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if o.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := o.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (o *OracleDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if o.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
@@ -104,45 +143,18 @@ func (o *OracleDB) Query(query string) ([]map[string]interface{}, []string, erro
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
return scanRows(rows)
}

func (o *OracleDB) ExecContext(ctx context.Context, query string) (int64, error) {
if o.conn == nil {
return 0, fmt.Errorf("connection not open")
}
res, err := o.conn.ExecContext(ctx, query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}

func (o *OracleDB) Exec(query string) (int64, error) {
@@ -1,24 +1,31 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
|
||||
type PostgresDB struct {
|
||||
conn *sql.DB
|
||||
pingTimeout time.Duration
|
||||
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
|
||||
}
|
||||
|
||||
|
||||
func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
|
||||
// postgres://user:password@host:port/dbname?sslmode=disable
|
||||
dbname := config.Database
|
||||
@@ -41,14 +48,49 @@ func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
|
||||
func (p *PostgresDB) Connect(config connection.ConnectionConfig) error {
|
||||
var dsn string
|
||||
var err error
|
||||
|
||||
if config.UseSSH {
|
||||
// Create SSH tunnel with local port forwarding
|
||||
logger.Infof("PostgreSQL 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
|
||||
|
||||
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||||
}
|
||||
p.forwarder = forwarder
|
||||
|
||||
// Parse local address
|
||||
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地转发地址失败:%w", err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析本地端口失败:%w", err)
|
||||
}
|
||||
|
||||
// Create a modified config pointing to local forwarder
|
||||
localConfig := config
|
||||
localConfig.Host = host
|
||||
localConfig.Port = port
|
||||
localConfig.UseSSH = false // Disable SSH flag for DSN generation
|
||||
|
||||
dsn = p.getDSN(localConfig)
|
||||
logger.Infof("PostgreSQL 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||||
} else {
|
||||
dsn = p.getDSN(config)
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
}
|
||||
p.conn = db
|
||||
p.pingTimeout = getConnectTimeout(config)
|
||||
|
||||
|
||||
// Force verification
|
||||
if err := p.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
@@ -56,7 +98,17 @@ func (p *PostgresDB) Connect(config connection.ConnectionConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
func (p *PostgresDB) Close() error {
|
||||
// Close SSH forwarder first if exists
|
||||
if p.forwarder != nil {
|
||||
if err := p.forwarder.Close(); err != nil {
|
||||
logger.Warnf("关闭 PostgreSQL SSH 端口转发失败:%v", err)
|
||||
}
|
||||
p.forwarder = nil
|
||||
}
|
||||
|
||||
// Then close database connection
|
||||
if p.conn != nil {
|
||||
return p.conn.Close()
|
||||
}
|
||||
@@ -76,56 +128,42 @@ func (p *PostgresDB) Ping() error {
|
||||
return p.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (p *PostgresDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if p.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
}

rows, err := p.conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()

return scanRows(rows)
}

func (p *PostgresDB) Query(query string) ([]map[string]interface{}, []string, error) {
if p.conn == nil {
return nil, nil, fmt.Errorf("connection not open")
}

rows, err := p.conn.Query(query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}

func (p *PostgresDB) ExecContext(ctx context.Context, query string) (int64, error) {
if p.conn == nil {
return 0, fmt.Errorf("connection not open")
}

res, err := p.conn.ExecContext(ctx, query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}

func (p *PostgresDB) Exec(query string) (int64, error) {
@@ -159,7 +197,7 @@ func (p *PostgresDB) GetTables(dbName string) ([]string, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
var tables []string
|
||||
for _, row := range data {
|
||||
schema, okSchema := row["schemaname"]
|
||||
@@ -180,21 +218,306 @@ func (p *PostgresDB) GetCreateStatement(dbName, tableName string) (string, error
|
||||
}
|
||||
|
||||
func (p *PostgresDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
schema := strings.TrimSpace(dbName)
|
||||
if schema == "" {
|
||||
schema = "public"
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
a.attname AS column_name,
|
||||
pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
|
||||
CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
|
||||
pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
|
||||
col_description(a.attrelid, a.attnum) AS comment,
|
||||
CASE WHEN pk.attname IS NOT NULL THEN 'PRI' ELSE '' END AS column_key
|
||||
FROM pg_class c
|
||||
JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||
JOIN pg_attribute a ON a.attrelid = c.oid
|
||||
LEFT JOIN pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum
|
||||
LEFT JOIN (
|
||||
SELECT i.indrelid, a3.attname
|
||||
FROM pg_index i
|
||||
JOIN pg_attribute a3 ON a3.attrelid = i.indrelid AND a3.attnum = ANY(i.indkey)
|
||||
WHERE i.indisprimary
|
||||
) pk ON pk.indrelid = c.oid AND pk.attname = a.attname
|
||||
WHERE c.relkind IN ('r', 'p')
|
||||
AND n.nspname = '%s'
|
||||
AND c.relname = '%s'
|
||||
AND a.attnum > 0
|
||||
AND NOT a.attisdropped
|
||||
ORDER BY a.attnum`, esc(schema), esc(table))
|
||||
|
||||
data, _, err := p.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var columns []connection.ColumnDefinition
|
||||
for _, row := range data {
|
||||
col := connection.ColumnDefinition{
|
||||
Name: fmt.Sprintf("%v", row["column_name"]),
|
||||
Type: fmt.Sprintf("%v", row["data_type"]),
|
||||
Nullable: fmt.Sprintf("%v", row["is_nullable"]),
|
||||
Key: fmt.Sprintf("%v", row["column_key"]),
|
||||
Extra: "",
|
||||
Comment: "",
|
||||
}
|
||||
|
||||
if v, ok := row["comment"]; ok && v != nil {
|
||||
col.Comment = fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
if v, ok := row["column_default"]; ok && v != nil {
|
||||
def := fmt.Sprintf("%v", v)
|
||||
col.Default = &def
|
||||
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(def)), "nextval(") {
|
||||
col.Extra = "auto_increment"
|
||||
}
|
||||
}
|
||||
|
||||
columns = append(columns, col)
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
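The catalog query above marks a column as PRI when it appears in the table's primary-key index, and flags it as auto_increment when its default expression is a nextval(...) sequence call. A small sketch of that second check in isolation; hasSerialDefault is an illustrative name, not part of the diff:

package db

import "strings"

// hasSerialDefault reports whether a pg_get_expr() default such as
// "nextval('users_id_seq'::regclass)" indicates a serial/identity-style column.
func hasSerialDefault(columnDefault string) bool {
	def := strings.ToLower(strings.TrimSpace(columnDefault))
	return strings.HasPrefix(def, "nextval(")
}

// Example: hasSerialDefault("nextval('users_id_seq'::regclass)") returns true,
// while hasSerialDefault("0") returns false.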
func (p *PostgresDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||||
schema := strings.TrimSpace(dbName)
|
||||
if schema == "" {
|
||||
schema = "public"
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
i.relname AS index_name,
|
||||
a.attname AS column_name,
|
||||
ix.indisunique AS is_unique,
|
||||
x.ordinality AS seq_in_index,
|
||||
am.amname AS index_type
|
||||
FROM pg_class t
|
||||
JOIN pg_namespace n ON n.oid = t.relnamespace
|
||||
JOIN pg_index ix ON t.oid = ix.indrelid
|
||||
JOIN pg_class i ON i.oid = ix.indexrelid
|
||||
JOIN pg_am am ON i.relam = am.oid
|
||||
JOIN unnest(ix.indkey) WITH ORDINALITY AS x(attnum, ordinality) ON TRUE
|
||||
JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = x.attnum
|
||||
WHERE t.relkind IN ('r', 'p')
|
||||
AND t.relname = '%s'
|
||||
AND n.nspname = '%s'
|
||||
ORDER BY i.relname, x.ordinality`, esc(table), esc(schema))
|
||||
|
||||
data, _, err := p.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parseBool := func(v interface{}) bool {
|
||||
switch val := v.(type) {
|
||||
case bool:
|
||||
return val
|
||||
case string:
|
||||
s := strings.ToLower(strings.TrimSpace(val))
|
||||
return s == "t" || s == "true" || s == "1" || s == "y" || s == "yes"
|
||||
default:
|
||||
s := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
|
||||
return s == "t" || s == "true" || s == "1" || s == "y" || s == "yes"
|
||||
}
|
||||
}
|
||||
|
||||
parseInt := func(v interface{}) int {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val
|
||||
case int64:
|
||||
return int(val)
|
||||
case float64:
|
||||
return int(val)
|
||||
case string:
|
||||
// best effort
|
||||
var n int
|
||||
_, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n)
|
||||
return n
|
||||
default:
|
||||
var n int
|
||||
_, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n)
|
||||
return n
|
||||
}
|
||||
}
|
||||
|
||||
var indexes []connection.IndexDefinition
|
||||
for _, row := range data {
|
||||
isUnique := false
|
||||
if v, ok := row["is_unique"]; ok && v != nil {
|
||||
isUnique = parseBool(v)
|
||||
}
|
||||
|
||||
nonUnique := 1
|
||||
if isUnique {
|
||||
nonUnique = 0
|
||||
}
|
||||
|
||||
seq := 0
|
||||
if v, ok := row["seq_in_index"]; ok && v != nil {
|
||||
seq = parseInt(v)
|
||||
}
|
||||
|
||||
indexType := ""
|
||||
if v, ok := row["index_type"]; ok && v != nil {
|
||||
indexType = strings.ToUpper(fmt.Sprintf("%v", v))
|
||||
}
|
||||
if indexType == "" {
|
||||
indexType = "BTREE"
|
||||
}
|
||||
|
||||
idx := connection.IndexDefinition{
|
||||
Name: fmt.Sprintf("%v", row["index_name"]),
|
||||
ColumnName: fmt.Sprintf("%v", row["column_name"]),
|
||||
NonUnique: nonUnique,
|
||||
SeqInIndex: seq,
|
||||
IndexType: indexType,
|
||||
}
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
return indexes, nil
|
||||
}
|
||||
|
||||
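The unnest(ix.indkey) WITH ORDINALITY join emits one row per index column together with its position, so a multi-column index comes back as several flat rows ordered by seq_in_index. A sketch of regrouping them per index on the Go side; groupIndexColumns is an illustrative helper, not part of the diff:

package db

import "GoNavi-Wails/internal/connection"

// groupIndexColumns collapses the flat (index, column, seq) rows into
// one ordered column list per index name, relying on the ORDER BY in the query.
func groupIndexColumns(defs []connection.IndexDefinition) map[string][]string {
	grouped := make(map[string][]string)
	for _, d := range defs {
		grouped[d.Name] = append(grouped[d.Name], d.ColumnName)
	}
	return grouped
}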
func (p *PostgresDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
|
||||
schema := strings.TrimSpace(dbName)
|
||||
if schema == "" {
|
||||
schema = "public"
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
tc.constraint_name AS constraint_name,
|
||||
kcu.column_name AS column_name,
|
||||
ccu.table_schema AS foreign_table_schema,
|
||||
ccu.table_name AS foreign_table_name,
|
||||
ccu.column_name AS foreign_column_name
|
||||
FROM information_schema.table_constraints AS tc
|
||||
JOIN information_schema.key_column_usage AS kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
JOIN information_schema.constraint_column_usage AS ccu
|
||||
ON ccu.constraint_name = tc.constraint_name
|
||||
AND ccu.table_schema = tc.table_schema
|
||||
WHERE tc.constraint_type = 'FOREIGN KEY'
|
||||
AND tc.table_name = '%s'
|
||||
AND tc.table_schema = '%s'
|
||||
ORDER BY tc.constraint_name, kcu.ordinal_position`, esc(table), esc(schema))
|
||||
|
||||
data, _, err := p.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var fks []connection.ForeignKeyDefinition
|
||||
for _, row := range data {
|
||||
refSchema := ""
|
||||
if v, ok := row["foreign_table_schema"]; ok && v != nil {
|
||||
refSchema = fmt.Sprintf("%v", v)
|
||||
}
|
||||
refTable := fmt.Sprintf("%v", row["foreign_table_name"])
|
||||
refTableName := refTable
|
||||
if strings.TrimSpace(refSchema) != "" {
|
||||
refTableName = fmt.Sprintf("%s.%s", refSchema, refTable)
|
||||
}
|
||||
|
||||
fk := connection.ForeignKeyDefinition{
|
||||
Name: fmt.Sprintf("%v", row["constraint_name"]),
|
||||
ColumnName: fmt.Sprintf("%v", row["column_name"]),
|
||||
RefTableName: refTableName,
|
||||
RefColumnName: fmt.Sprintf("%v", row["foreign_column_name"]),
|
||||
ConstraintName: fmt.Sprintf("%v", row["constraint_name"]),
|
||||
}
|
||||
fks = append(fks, fk)
|
||||
}
|
||||
return fks, nil
|
||||
}
|
||||
|
||||
func (p *PostgresDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||||
schema := strings.TrimSpace(dbName)
|
||||
if schema == "" {
|
||||
schema = "public"
|
||||
}
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT trigger_name, action_timing, event_manipulation, action_statement
|
||||
FROM information_schema.triggers
|
||||
WHERE event_object_table = '%s'
|
||||
AND event_object_schema = '%s'
|
||||
ORDER BY trigger_name, event_manipulation`, esc(table), esc(schema))
|
||||
|
||||
data, _, err := p.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var triggers []connection.TriggerDefinition
|
||||
for _, row := range data {
|
||||
trig := connection.TriggerDefinition{
|
||||
Name: fmt.Sprintf("%v", row["trigger_name"]),
|
||||
Timing: fmt.Sprintf("%v", row["action_timing"]),
|
||||
Event: fmt.Sprintf("%v", row["event_manipulation"]),
|
||||
Statement: fmt.Sprintf("%v", row["action_statement"]),
|
||||
}
|
||||
triggers = append(triggers, trig)
|
||||
}
|
||||
return triggers, nil
|
||||
}
|
||||
|
||||
func (p *PostgresDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||||
query := `
|
||||
SELECT table_schema, table_name, column_name, data_type
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
|
||||
AND table_schema NOT LIKE 'pg_%'
|
||||
ORDER BY table_schema, table_name, ordinal_position`
|
||||
|
||||
data, _, err := p.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cols []connection.ColumnDefinitionWithTable
|
||||
for _, row := range data {
|
||||
schema := fmt.Sprintf("%v", row["table_schema"])
|
||||
table := fmt.Sprintf("%v", row["table_name"])
|
||||
tableName := table
|
||||
if strings.TrimSpace(schema) != "" {
|
||||
tableName = fmt.Sprintf("%s.%s", schema, table)
|
||||
}
|
||||
|
||||
col := connection.ColumnDefinitionWithTable{
|
||||
TableName: tableName,
|
||||
Name: fmt.Sprintf("%v", row["column_name"]),
|
||||
Type: fmt.Sprintf("%v", row["data_type"]),
|
||||
}
|
||||
cols = append(cols, col)
|
||||
}
|
||||
return cols, nil
|
||||
}
|
||||
|
||||
114
internal/db/query_value.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// normalizeQueryValue normalizes driver-returned values for UI/JSON transport.
|
||||
// 当前主要处理 []byte:如果是可读文本则转为 string,否则转为十六进制字符串,避免前端出现“空白值”。
|
||||
func normalizeQueryValue(v interface{}) interface{} {
|
||||
return normalizeQueryValueWithDBType(v, "")
|
||||
}
|
||||
|
||||
func normalizeQueryValueWithDBType(v interface{}, databaseTypeName string) interface{} {
|
||||
if b, ok := v.([]byte); ok {
|
||||
return bytesToDisplayValue(b, databaseTypeName)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func bytesToDisplayValue(b []byte, databaseTypeName string) interface{} {
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
if len(b) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
dbType := strings.ToUpper(strings.TrimSpace(databaseTypeName))
|
||||
if isBitLikeDBType(dbType) {
|
||||
if u, ok := bytesToUint64(b); ok {
|
||||
// JS number precision is limited; keep large bitmasks as string.
|
||||
const maxSafeInteger = 9007199254740991 // 2^53 - 1
|
||||
if u <= maxSafeInteger {
|
||||
return int64(u)
|
||||
}
|
||||
return fmt.Sprintf("%d", u)
|
||||
}
|
||||
}
|
||||
|
||||
if utf8.Valid(b) {
|
||||
s := string(b)
|
||||
if isMostlyPrintable(s) {
|
||||
return s
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: some drivers return BIT(1) as []byte{0} / []byte{1} without type info.
|
||||
if dbType == "" && len(b) == 1 && (b[0] == 0 || b[0] == 1) {
|
||||
return int64(b[0])
|
||||
}
|
||||
|
||||
return bytesToReadableString(b)
|
||||
}
|
||||
|
||||
func bytesToReadableString(b []byte) interface{} {
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
if len(b) == 0 {
|
||||
return ""
|
||||
}
|
||||
return "0x" + hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func isBitLikeDBType(typeName string) bool {
|
||||
if typeName == "" {
|
||||
return false
|
||||
}
|
||||
switch typeName {
|
||||
case "BIT", "VARBIT":
|
||||
return true
|
||||
default:
|
||||
}
|
||||
return strings.HasPrefix(typeName, "BIT")
|
||||
}
|
||||
|
||||
func bytesToUint64(b []byte) (uint64, bool) {
|
||||
if len(b) == 0 || len(b) > 8 {
|
||||
return 0, false
|
||||
}
|
||||
var u uint64
|
||||
for _, v := range b {
|
||||
u = (u << 8) | uint64(v)
|
||||
}
|
||||
return u, true
|
||||
}
|
||||
|
||||
func isMostlyPrintable(s string) bool {
|
||||
if s == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
total := 0
|
||||
printable := 0
|
||||
for _, r := range s {
|
||||
total++
|
||||
switch r {
|
||||
case '\n', '\r', '\t':
|
||||
printable++
|
||||
continue
|
||||
default:
|
||||
}
|
||||
if unicode.IsPrint(r) {
|
||||
printable++
|
||||
}
|
||||
}
|
||||
|
||||
// 允许少量不可见字符,避免把正常文本误判为二进制。
|
||||
return printable*100 >= total*90
|
||||
}
|
||||
44
internal/db/query_value_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package db
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeQueryValueWithDBType_BitBytes(t *testing.T) {
|
||||
v := normalizeQueryValueWithDBType([]byte{0x00}, "BIT")
|
||||
if v != int64(0) {
|
||||
t.Fatalf("BIT 0x00 期望为 0,实际=%v(%T)", v, v)
|
||||
}
|
||||
|
||||
v = normalizeQueryValueWithDBType([]byte{0x01}, "bit")
|
||||
if v != int64(1) {
|
||||
t.Fatalf("BIT 0x01 期望为 1,实际=%v(%T)", v, v)
|
||||
}
|
||||
|
||||
v = normalizeQueryValueWithDBType([]byte{0x01, 0x02}, "BIT VARYING")
|
||||
if v != int64(258) {
|
||||
t.Fatalf("BIT 0x0102 期望为 258,实际=%v(%T)", v, v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBType_BitLargeAsString(t *testing.T) {
|
||||
v := normalizeQueryValueWithDBType([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, "BIT")
|
||||
if s, ok := v.(string); !ok || s != "18446744073709551615" {
|
||||
t.Fatalf("BIT 0xffffffffffffffff 期望为 string(18446744073709551615),实际=%v(%T)", v, v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeQueryValueWithDBType_ByteFallbacks(t *testing.T) {
|
||||
v := normalizeQueryValueWithDBType([]byte("abc"), "")
|
||||
if v != "abc" {
|
||||
t.Fatalf("文本 []byte 期望返回 string,实际=%v(%T)", v, v)
|
||||
}
|
||||
|
||||
v = normalizeQueryValueWithDBType([]byte{0x00}, "")
|
||||
if v != int64(0) {
|
||||
t.Fatalf("未知类型 0x00 期望返回 0,实际=%v(%T)", v, v)
|
||||
}
|
||||
|
||||
v = normalizeQueryValueWithDBType([]byte{0xff}, "")
|
||||
if v != "0xff" {
|
||||
t.Fatalf("未知类型 0xff 期望返回 0xff,实际=%v(%T)", v, v)
|
||||
}
|
||||
}
|
||||
46
internal/db/scan_rows.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) {
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
colTypes, err := rows.ColumnTypes()
|
||||
if err != nil || len(colTypes) != len(columns) {
|
||||
colTypes = nil
|
||||
}
|
||||
|
||||
resultData := make([]map[string]interface{}, 0)
|
||||
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range columns {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
entry := make(map[string]interface{}, len(columns))
|
||||
for i, col := range columns {
|
||||
dbTypeName := ""
|
||||
if colTypes != nil && i < len(colTypes) && colTypes[i] != nil {
|
||||
dbTypeName = colTypes[i].DatabaseTypeName()
|
||||
}
|
||||
entry[col] = normalizeQueryValueWithDBType(values[i], dbTypeName)
|
||||
}
|
||||
resultData = append(resultData, entry)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return resultData, columns, err
|
||||
}
|
||||
return resultData, columns, nil
|
||||
}
|
||||
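scanRows returns both the row maps and the ordered column list because Go maps do not preserve column order. A small sketch of how a caller might keep the driver's original ordering when rendering results; printRows is an illustrative helper, not part of the diff:

package db

import "fmt"

// printRows walks every row in the column order reported by the driver.
func printRows(rows []map[string]interface{}, columns []string) {
	for _, row := range rows {
		for _, col := range columns {
			fmt.Printf("%s=%v\t", col, row[col])
		}
		fmt.Println()
	}
}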
@@ -1,8 +1,10 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
@@ -17,14 +19,14 @@ type SQLiteDB struct {
|
||||
}
|
||||
|
||||
func (s *SQLiteDB) Connect(config connection.ConnectionConfig) error {
|
||||
dsn := config.Host
|
||||
db, err := sql.Open("sqlite", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
}
|
||||
s.conn = db
|
||||
s.pingTimeout = getConnectTimeout(config)
|
||||
|
||||
|
||||
// Force verification
|
||||
if err := s.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
@@ -52,6 +54,20 @@ func (s *SQLiteDB) Ping() error {
|
||||
return s.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (s *SQLiteDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if s.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
}
|
||||
|
||||
rows, err := s.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (s *SQLiteDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if s.conn == nil {
|
||||
return nil, nil, fmt.Errorf("connection not open")
|
||||
@@ -62,45 +78,18 @@ func (s *SQLiteDB) Query(query string) ([]map[string]interface{}, []string, erro
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
columns, err := rows.Columns()
|
||||
func (s *SQLiteDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if s.conn == nil {
|
||||
return 0, fmt.Errorf("connection not open")
|
||||
}
|
||||
res, err := s.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var resultData []map[string]interface{}
|
||||
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range columns {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
entry := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
var v interface{}
|
||||
val := values[i]
|
||||
b, ok := val.([]byte)
|
||||
if ok {
|
||||
if b == nil {
|
||||
v = nil
|
||||
} else {
|
||||
v = string(b)
|
||||
}
|
||||
} else {
|
||||
v = val
|
||||
}
|
||||
entry[col] = v
|
||||
}
|
||||
resultData = append(resultData, entry)
|
||||
}
|
||||
|
||||
return resultData, columns, nil
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (s *SQLiteDB) Exec(query string) (int64, error) {
|
||||
@@ -124,7 +113,7 @@ func (s *SQLiteDB) GetTables(dbName string) ([]string, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
var tables []string
|
||||
for _, row := range data {
|
||||
if val, ok := row["name"]; ok {
|
||||
@@ -149,21 +138,336 @@ func (s *SQLiteDB) GetCreateStatement(dbName, tableName string) (string, error)
|
||||
}
|
||||
|
||||
func (s *SQLiteDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
return []connection.ColumnDefinition{}, nil
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") }
|
||||
|
||||
// cid, name, type, notnull, dflt_value, pk
|
||||
data, _, err := s.Query(fmt.Sprintf("PRAGMA table_info('%s')", esc(table)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parseInt := func(v interface{}) int {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val
|
||||
case int64:
|
||||
return int(val)
|
||||
case float64:
|
||||
return int(val)
|
||||
case string:
|
||||
var n int
|
||||
_, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n)
|
||||
return n
|
||||
default:
|
||||
var n int
|
||||
_, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n)
|
||||
return n
|
||||
}
|
||||
}
|
||||
|
||||
getStr := func(row map[string]interface{}, key string) string {
|
||||
if v, ok := row[key]; ok && v != nil {
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
if v, ok := row[strings.ToUpper(key)]; ok && v != nil {
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var columns []connection.ColumnDefinition
|
||||
for _, row := range data {
|
||||
notnull := 0
|
||||
if v, ok := row["notnull"]; ok && v != nil {
|
||||
notnull = parseInt(v)
|
||||
} else if v, ok := row["NOTNULL"]; ok && v != nil {
|
||||
notnull = parseInt(v)
|
||||
}
|
||||
|
||||
pk := 0
|
||||
if v, ok := row["pk"]; ok && v != nil {
|
||||
pk = parseInt(v)
|
||||
} else if v, ok := row["PK"]; ok && v != nil {
|
||||
pk = parseInt(v)
|
||||
}
|
||||
|
||||
nullable := "YES"
|
||||
if notnull == 1 {
|
||||
nullable = "NO"
|
||||
}
|
||||
|
||||
key := ""
|
||||
if pk == 1 {
|
||||
key = "PRI"
|
||||
}
|
||||
|
||||
col := connection.ColumnDefinition{
|
||||
Name: getStr(row, "name"),
|
||||
Type: getStr(row, "type"),
|
||||
Nullable: nullable,
|
||||
Key: key,
|
||||
Extra: "",
|
||||
Comment: "",
|
||||
}
|
||||
|
||||
if v, ok := row["dflt_value"]; ok && v != nil {
|
||||
def := fmt.Sprintf("%v", v)
|
||||
col.Default = &def
|
||||
} else if v, ok := row["DFLT_VALUE"]; ok && v != nil {
|
||||
def := fmt.Sprintf("%v", v)
|
||||
col.Default = &def
|
||||
}
|
||||
|
||||
columns = append(columns, col)
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||||
return []connection.IndexDefinition{}, nil
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") }
|
||||
parseInt := func(v interface{}) int {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val
|
||||
case int64:
|
||||
return int(val)
|
||||
case float64:
|
||||
return int(val)
|
||||
case string:
|
||||
var n int
|
||||
_, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n)
|
||||
return n
|
||||
default:
|
||||
var n int
|
||||
_, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n)
|
||||
return n
|
||||
}
|
||||
}
|
||||
|
||||
data, _, err := s.Query(fmt.Sprintf("PRAGMA index_list('%s')", esc(table)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var indexes []connection.IndexDefinition
|
||||
for _, row := range data {
|
||||
indexName := ""
|
||||
if v, ok := row["name"]; ok && v != nil {
|
||||
indexName = fmt.Sprintf("%v", v)
|
||||
} else if v, ok := row["NAME"]; ok && v != nil {
|
||||
indexName = fmt.Sprintf("%v", v)
|
||||
}
|
||||
if strings.TrimSpace(indexName) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
unique := 0
|
||||
if v, ok := row["unique"]; ok && v != nil {
|
||||
unique = parseInt(v)
|
||||
} else if v, ok := row["UNIQUE"]; ok && v != nil {
|
||||
unique = parseInt(v)
|
||||
}
|
||||
nonUnique := 1
|
||||
if unique == 1 {
|
||||
nonUnique = 0
|
||||
}
|
||||
|
||||
cols, _, err := s.Query(fmt.Sprintf("PRAGMA index_info('%s')", esc(indexName)))
|
||||
if err != nil {
|
||||
// skip broken index
|
||||
continue
|
||||
}
|
||||
|
||||
for _, c := range cols {
|
||||
colName := ""
|
||||
if v, ok := c["name"]; ok && v != nil {
|
||||
colName = fmt.Sprintf("%v", v)
|
||||
} else if v, ok := c["NAME"]; ok && v != nil {
|
||||
colName = fmt.Sprintf("%v", v)
|
||||
}
|
||||
if strings.TrimSpace(colName) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
seq := 0
|
||||
if v, ok := c["seqno"]; ok && v != nil {
|
||||
seq = parseInt(v) + 1
|
||||
} else if v, ok := c["SEQNO"]; ok && v != nil {
|
||||
seq = parseInt(v) + 1
|
||||
}
|
||||
|
||||
indexes = append(indexes, connection.IndexDefinition{
|
||||
Name: indexName,
|
||||
ColumnName: colName,
|
||||
NonUnique: nonUnique,
|
||||
SeqInIndex: seq,
|
||||
IndexType: "BTREE",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return indexes, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
|
||||
return []connection.ForeignKeyDefinition{}, nil
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") }
|
||||
|
||||
data, _, err := s.Query(fmt.Sprintf("PRAGMA foreign_key_list('%s')", esc(table)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parseInt := func(v interface{}) int {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val
|
||||
case int64:
|
||||
return int(val)
|
||||
case float64:
|
||||
return int(val)
|
||||
case string:
|
||||
var n int
|
||||
_, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n)
|
||||
return n
|
||||
default:
|
||||
var n int
|
||||
_, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n)
|
||||
return n
|
||||
}
|
||||
}
|
||||
|
||||
var fks []connection.ForeignKeyDefinition
|
||||
for _, row := range data {
|
||||
id := 0
|
||||
if v, ok := row["id"]; ok && v != nil {
|
||||
id = parseInt(v)
|
||||
} else if v, ok := row["ID"]; ok && v != nil {
|
||||
id = parseInt(v)
|
||||
}
|
||||
|
||||
refTable := ""
|
||||
if v, ok := row["table"]; ok && v != nil {
|
||||
refTable = fmt.Sprintf("%v", v)
|
||||
} else if v, ok := row["TABLE"]; ok && v != nil {
|
||||
refTable = fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
fromCol := ""
|
||||
if v, ok := row["from"]; ok && v != nil {
|
||||
fromCol = fmt.Sprintf("%v", v)
|
||||
} else if v, ok := row["FROM"]; ok && v != nil {
|
||||
fromCol = fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
toCol := ""
|
||||
if v, ok := row["to"]; ok && v != nil {
|
||||
toCol = fmt.Sprintf("%v", v)
|
||||
} else if v, ok := row["TO"]; ok && v != nil {
|
||||
toCol = fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("fk_%s_%d", table, id)
|
||||
fks = append(fks, connection.ForeignKeyDefinition{
|
||||
Name: name,
|
||||
ColumnName: fromCol,
|
||||
RefTableName: refTable,
|
||||
RefColumnName: toCol,
|
||||
ConstraintName: name,
|
||||
})
|
||||
}
|
||||
return fks, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||||
return []connection.TriggerDefinition{}, nil
|
||||
table := strings.TrimSpace(tableName)
|
||||
if table == "" {
|
||||
return nil, fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
esc := func(v string) string { return strings.ReplaceAll(v, "'", "''") }
|
||||
|
||||
data, _, err := s.Query(fmt.Sprintf("SELECT name AS trigger_name, sql AS statement FROM sqlite_master WHERE type='trigger' AND tbl_name='%s' ORDER BY name", esc(table)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var triggers []connection.TriggerDefinition
|
||||
for _, row := range data {
|
||||
name := fmt.Sprintf("%v", row["trigger_name"])
|
||||
stmt := ""
|
||||
if v, ok := row["statement"]; ok && v != nil {
|
||||
stmt = fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
upper := strings.ToUpper(stmt)
|
||||
timing := ""
|
||||
switch {
|
||||
case strings.Contains(upper, " BEFORE "):
|
||||
timing = "BEFORE"
|
||||
case strings.Contains(upper, " AFTER "):
|
||||
timing = "AFTER"
|
||||
case strings.Contains(upper, " INSTEAD OF "):
|
||||
timing = "INSTEAD OF"
|
||||
}
|
||||
|
||||
event := ""
|
||||
switch {
|
||||
case strings.Contains(upper, " INSERT "):
|
||||
event = "INSERT"
|
||||
case strings.Contains(upper, " UPDATE "):
|
||||
event = "UPDATE"
|
||||
case strings.Contains(upper, " DELETE "):
|
||||
event = "DELETE"
|
||||
}
|
||||
|
||||
triggers = append(triggers, connection.TriggerDefinition{
|
||||
Name: name,
|
||||
Timing: timing,
|
||||
Event: event,
|
||||
Statement: stmt,
|
||||
})
|
||||
}
|
||||
return triggers, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||||
return []connection.ColumnDefinitionWithTable{}, nil
|
||||
tables, err := s.GetTables(dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cols []connection.ColumnDefinitionWithTable
|
||||
for _, table := range tables {
|
||||
// Skip internal tables
|
||||
if strings.HasPrefix(strings.ToLower(table), "sqlite_") {
|
||||
continue
|
||||
}
|
||||
columns, err := s.GetColumns("", table)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, col := range columns {
|
||||
cols = append(cols, connection.ColumnDefinitionWithTable{
|
||||
TableName: table,
|
||||
Name: col.Name,
|
||||
Type: col.Type,
|
||||
})
|
||||
}
|
||||
}
|
||||
return cols, nil
|
||||
}
|
||||
|
||||
@@ -3,8 +3,10 @@ package ssh
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
@@ -110,3 +112,264 @@ func RegisterSSHNetwork(sshConfig connection.SSHConfig) (string, error) {
|
||||
|
||||
return netName, nil
|
||||
}
|
||||
|
||||
// sshClientCache stores SSH clients to avoid creating multiple connections
|
||||
var (
|
||||
sshClientCache = make(map[string]*ssh.Client)
|
||||
sshClientCacheMu sync.RWMutex
|
||||
localForwarders = make(map[string]*LocalForwarder)
|
||||
forwarderMu sync.RWMutex
|
||||
)
|
||||
|
||||
// LocalForwarder represents a local port forwarder through SSH
|
||||
type LocalForwarder struct {
|
||||
LocalAddr string
|
||||
RemoteAddr string
|
||||
SSHClient *ssh.Client
|
||||
listener net.Listener
|
||||
closeChan chan struct{}
|
||||
closeOnce sync.Once // 防止重复关闭
|
||||
closed bool // 关闭状态标记
|
||||
closedMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewLocalForwarder creates a new local port forwarder
|
||||
// It listens on a random local port and forwards all connections through SSH tunnel
|
||||
func NewLocalForwarder(sshConfig connection.SSHConfig, remoteHost string, remotePort int) (*LocalForwarder, error) {
|
||||
client, err := GetOrCreateSSHClient(sshConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("建立 SSH 连接失败:%w", err)
|
||||
}
|
||||
|
||||
// Listen on localhost with a random port
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建本地监听器失败:%w", err)
|
||||
}
|
||||
|
||||
localAddr := listener.Addr().String()
|
||||
remoteAddr := fmt.Sprintf("%s:%d", remoteHost, remotePort)
|
||||
|
||||
forwarder := &LocalForwarder{
|
||||
LocalAddr: localAddr,
|
||||
RemoteAddr: remoteAddr,
|
||||
SSHClient: client,
|
||||
listener: listener,
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start forwarding in background
|
||||
go forwarder.forward()
|
||||
|
||||
logger.Infof("已创建 SSH 端口转发:本地 %s -> 远程 %s", localAddr, remoteAddr)
|
||||
return forwarder, nil
|
||||
}
|
||||
|
||||
// forward handles the port forwarding
|
||||
func (f *LocalForwarder) forward() {
|
||||
for {
|
||||
localConn, err := f.listener.Accept()
|
||||
if err != nil {
|
||||
// Check if we're shutting down
|
||||
select {
|
||||
case <-f.closeChan:
|
||||
return
|
||||
default:
|
||||
logger.Warnf("接受本地连接失败:%v", err)
|
||||
// listener可能已关闭,退出循环
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
go f.handleConnection(localConn)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnection handles a single connection
|
||||
func (f *LocalForwarder) handleConnection(localConn net.Conn) {
|
||||
defer localConn.Close()
|
||||
|
||||
// Connect to remote through SSH with timeout
|
||||
remoteConn, err := f.SSHClient.Dial("tcp", f.RemoteAddr)
|
||||
if err != nil {
|
||||
logger.Warnf("通过 SSH 连接到远程 %s 失败:%v", f.RemoteAddr, err)
|
||||
return
|
||||
}
|
||||
defer remoteConn.Close()
|
||||
|
||||
// Bidirectional copy with error channel
|
||||
errc := make(chan error, 2)
|
||||
|
||||
// Copy from local to remote
|
||||
go func() {
|
||||
_, err := io.Copy(remoteConn, localConn)
|
||||
if err != nil {
|
||||
logger.Warnf("本地->远程数据复制错误:%v", err)
|
||||
}
|
||||
errc <- err
|
||||
}()
|
||||
|
||||
// Copy from remote to local
|
||||
go func() {
|
||||
_, err := io.Copy(localConn, remoteConn)
|
||||
if err != nil {
|
||||
logger.Warnf("远程->本地数据复制错误:%v", err)
|
||||
}
|
||||
errc <- err
|
||||
}()
|
||||
|
||||
// Wait for BOTH goroutines to complete
|
||||
<-errc
|
||||
<-errc
|
||||
}
|
||||
|
||||
// Close closes the forwarder (thread-safe, can be called multiple times)
|
||||
func (f *LocalForwarder) Close() error {
|
||||
var err error
|
||||
f.closeOnce.Do(func() {
|
||||
f.closedMu.Lock()
|
||||
f.closed = true
|
||||
f.closedMu.Unlock()
|
||||
|
||||
close(f.closeChan)
|
||||
err = f.listener.Close()
|
||||
if err != nil {
|
||||
logger.Warnf("关闭端口转发监听器失败:%v", err)
|
||||
}
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// IsClosed returns whether the forwarder is closed
|
||||
func (f *LocalForwarder) IsClosed() bool {
|
||||
f.closedMu.RLock()
|
||||
defer f.closedMu.RUnlock()
|
||||
return f.closed
|
||||
}
|
||||
|
||||
// GetOrCreateLocalForwarder returns a cached forwarder or creates a new one
|
||||
func GetOrCreateLocalForwarder(sshConfig connection.SSHConfig, remoteHost string, remotePort int) (*LocalForwarder, error) {
|
||||
key := fmt.Sprintf("%s:%d:%s->%s:%d",
|
||||
sshConfig.Host, sshConfig.Port, sshConfig.User,
|
||||
remoteHost, remotePort)
|
||||
|
||||
forwarderMu.RLock()
|
||||
forwarder, exists := localForwarders[key]
|
||||
forwarderMu.RUnlock()
|
||||
|
||||
// Check if exists and is still valid
|
||||
if exists && forwarder != nil && !forwarder.IsClosed() {
|
||||
logger.Infof("复用已有端口转发:%s", key)
|
||||
return forwarder, nil
|
||||
}
|
||||
|
||||
// Remove stale forwarder from cache
|
||||
if exists {
|
||||
forwarderMu.Lock()
|
||||
delete(localForwarders, key)
|
||||
forwarderMu.Unlock()
|
||||
}
|
||||
|
||||
forwarder, err := NewLocalForwarder(sshConfig, remoteHost, remotePort)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
forwarderMu.Lock()
|
||||
localForwarders[key] = forwarder
|
||||
forwarderMu.Unlock()
|
||||
|
||||
return forwarder, nil
|
||||
}
|
||||
|
||||
// CloseAllForwarders closes all local forwarders
|
||||
func CloseAllForwarders() {
|
||||
forwarderMu.Lock()
|
||||
defer forwarderMu.Unlock()
|
||||
|
||||
for key, forwarder := range localForwarders {
|
||||
if forwarder != nil {
|
||||
_ = forwarder.Close()
|
||||
logger.Infof("已关闭端口转发:%s", key)
|
||||
}
|
||||
}
|
||||
localForwarders = make(map[string]*LocalForwarder)
|
||||
}
|
||||
|
||||
|
||||
// getSSHClientCacheKey generates a unique cache key for SSH config
|
||||
func getSSHClientCacheKey(config connection.SSHConfig) string {
|
||||
return fmt.Sprintf("%s:%d:%s", config.Host, config.Port, config.User)
|
||||
}
|
||||
|
||||
// GetOrCreateSSHClient returns a cached SSH client or creates a new one
|
||||
func GetOrCreateSSHClient(config connection.SSHConfig) (*ssh.Client, error) {
|
||||
key := getSSHClientCacheKey(config)
|
||||
|
||||
sshClientCacheMu.RLock()
|
||||
client, exists := sshClientCache[key]
|
||||
sshClientCacheMu.RUnlock()
|
||||
|
||||
if exists && client != nil {
|
||||
// Test if connection is still alive by creating a test session
|
||||
session, err := client.NewSession()
|
||||
if err == nil {
|
||||
session.Close()
|
||||
logger.Infof("复用已有 SSH 连接:%s", key)
|
||||
return client, nil
|
||||
}
|
||||
// Connection is dead, remove from cache
|
||||
logger.Warnf("SSH 连接已断开,重新建立:%s (错误: %v)", key, err)
|
||||
sshClientCacheMu.Lock()
|
||||
delete(sshClientCache, key)
|
||||
sshClientCacheMu.Unlock()
|
||||
// Try to close the dead client
|
||||
_ = client.Close()
|
||||
}
|
||||
|
||||
// Create new SSH client
|
||||
client, err := connectSSH(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Cache the client
|
||||
sshClientCacheMu.Lock()
|
||||
sshClientCache[key] = client
|
||||
sshClientCacheMu.Unlock()
|
||||
|
||||
logger.Infof("已缓存 SSH 连接:%s", key)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// DialThroughSSH creates a connection through SSH tunnel
|
||||
// This is a generic dialer that can be used by any database driver
|
||||
func DialThroughSSH(config connection.SSHConfig, network, address string) (net.Conn, error) {
|
||||
client, err := GetOrCreateSSHClient(config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("建立 SSH 连接失败:%w", err)
|
||||
}
|
||||
|
||||
conn, err := client.Dial(network, address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("通过 SSH 隧道连接到 %s 失败:%w", address, err)
|
||||
}
|
||||
|
||||
logger.Infof("已通过 SSH 隧道连接到:%s", address)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// CloseAllSSHClients closes all cached SSH clients
|
||||
func CloseAllSSHClients() {
|
||||
sshClientCacheMu.Lock()
|
||||
defer sshClientCacheMu.Unlock()
|
||||
|
||||
for key, client := range sshClientCache {
|
||||
if client != nil {
|
||||
_ = client.Close()
|
||||
logger.Infof("已关闭 SSH 连接:%s", key)
|
||||
}
|
||||
}
|
||||
sshClientCache = make(map[string]*ssh.Client)
|
||||
}
|
||||
|
||||
|
||||
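The forwarder and client caches above are keyed by SSH host, port, and user, so repeated connections to the same database reuse one SSH session and one local listener. The sketch below shows the intended call pattern; the MySQL DSN format and the host/port literals are assumptions for illustration, only GetOrCreateLocalForwarder and Close come from this diff.

// Hypothetical usage sketch (assumes the same ssh package and its imports).
func mysqlDSNThroughSSH(sshConfig connection.SSHConfig) (string, func() error, error) {
    // Reuses a cached forwarder when one already exists for this SSH + remote pair.
    fwd, err := GetOrCreateLocalForwarder(sshConfig, "127.0.0.1", 3306)
    if err != nil {
        return "", nil, err
    }
    // The database driver connects to the local listener; traffic is relayed
    // to the remote address through the cached *ssh.Client.
    dsn := fmt.Sprintf("user:password@tcp(%s)/appdb", fwd.LocalAddr)
    return dsn, fwd.Close, nil
}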
198
internal/sync/analyze.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package sync
|
||||
|
||||
import (
|
||||
"GoNavi-Wails/internal/db"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type TableDiffSummary struct {
|
||||
Table string `json:"table"`
|
||||
PKColumn string `json:"pkColumn,omitempty"`
|
||||
CanSync bool `json:"canSync"`
|
||||
Inserts int `json:"inserts"`
|
||||
Updates int `json:"updates"`
|
||||
Deletes int `json:"deletes"`
|
||||
Same int `json:"same"`
|
||||
Message string `json:"message,omitempty"`
|
||||
HasSchema bool `json:"hasSchema,omitempty"`
|
||||
}
|
||||
|
||||
type SyncAnalyzeResult struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Tables []TableDiffSummary `json:"tables"`
|
||||
}
|
||||
|
||||
func (s *SyncEngine) Analyze(config SyncConfig) SyncAnalyzeResult {
|
||||
result := SyncAnalyzeResult{Success: true, Tables: []TableDiffSummary{}}
|
||||
|
||||
contentRaw := strings.ToLower(strings.TrimSpace(config.Content))
|
||||
syncSchema := false
|
||||
syncData := true
|
||||
switch contentRaw {
|
||||
case "", "data":
|
||||
syncData = true
|
||||
case "schema":
|
||||
syncSchema = true
|
||||
syncData = false
|
||||
case "both":
|
||||
syncSchema = true
|
||||
syncData = true
|
||||
default:
|
||||
s.appendLog(config.JobID, nil, "warn", fmt.Sprintf("未知同步内容 %q,已自动使用仅同步数据", config.Content))
|
||||
syncData = true
|
||||
}
|
||||
|
||||
totalTables := len(config.Tables)
|
||||
s.progress(config.JobID, 0, totalTables, "", "差异分析开始")
|
||||
|
||||
sourceDB, err := db.NewDatabase(config.SourceConfig.Type)
|
||||
if err != nil {
|
||||
logger.Error(err, "初始化源数据库驱动失败:类型=%s", config.SourceConfig.Type)
|
||||
return SyncAnalyzeResult{Success: false, Message: "初始化源数据库驱动失败: " + err.Error()}
|
||||
}
|
||||
targetDB, err := db.NewDatabase(config.TargetConfig.Type)
|
||||
if err != nil {
|
||||
logger.Error(err, "初始化目标数据库驱动失败:类型=%s", config.TargetConfig.Type)
|
||||
return SyncAnalyzeResult{Success: false, Message: "初始化目标数据库驱动失败: " + err.Error()}
|
||||
}
|
||||
|
||||
// Connect Source
|
||||
if err := sourceDB.Connect(config.SourceConfig); err != nil {
|
||||
logger.Error(err, "源数据库连接失败:%s", formatConnSummaryForSync(config.SourceConfig))
|
||||
return SyncAnalyzeResult{Success: false, Message: "源数据库连接失败: " + err.Error()}
|
||||
}
|
||||
defer sourceDB.Close()
|
||||
|
||||
// Connect Target
|
||||
if err := targetDB.Connect(config.TargetConfig); err != nil {
|
||||
logger.Error(err, "目标数据库连接失败:%s", formatConnSummaryForSync(config.TargetConfig))
|
||||
return SyncAnalyzeResult{Success: false, Message: "目标数据库连接失败: " + err.Error()}
|
||||
}
|
||||
defer targetDB.Close()
|
||||
|
||||
for i, tableName := range config.Tables {
|
||||
func() {
|
||||
s.progress(config.JobID, i, totalTables, tableName, fmt.Sprintf("分析表(%d/%d)", i+1, totalTables))
|
||||
|
||||
summary := TableDiffSummary{
|
||||
Table: tableName,
|
||||
CanSync: false,
|
||||
Inserts: 0,
|
||||
Updates: 0,
|
||||
Deletes: 0,
|
||||
Same: 0,
|
||||
Message: "",
|
||||
HasSchema: syncSchema,
|
||||
}
|
||||
|
||||
sourceSchema, sourceTable := normalizeSchemaAndTable(config.SourceConfig.Type, config.SourceConfig.Database, tableName)
|
||||
targetSchema, targetTable := normalizeSchemaAndTable(config.TargetConfig.Type, config.TargetConfig.Database, tableName)
|
||||
sourceQueryTable := qualifiedNameForQuery(config.SourceConfig.Type, sourceSchema, sourceTable, tableName)
|
||||
targetQueryTable := qualifiedNameForQuery(config.TargetConfig.Type, targetSchema, targetTable, tableName)
|
||||
|
||||
cols, err := sourceDB.GetColumns(sourceSchema, sourceTable)
|
||||
if err != nil {
|
||||
summary.Message = "获取源表字段失败: " + err.Error()
|
||||
result.Tables = append(result.Tables, summary)
|
||||
return
|
||||
}
|
||||
|
||||
if !syncData {
|
||||
summary.CanSync = true
|
||||
summary.Message = "仅同步结构,未执行数据差异分析"
|
||||
result.Tables = append(result.Tables, summary)
|
||||
return
|
||||
}
|
||||
|
||||
pkCols := make([]string, 0, 2)
|
||||
for _, c := range cols {
|
||||
if c.Key == "PRI" || c.Key == "PK" {
|
||||
pkCols = append(pkCols, c.Name)
|
||||
}
|
||||
}
|
||||
if len(pkCols) == 0 {
|
||||
summary.Message = "无主键,不支持数据对比/同步"
|
||||
result.Tables = append(result.Tables, summary)
|
||||
return
|
||||
}
|
||||
if len(pkCols) > 1 {
|
||||
summary.Message = fmt.Sprintf("复合主键(%s),暂不支持数据对比/同步", strings.Join(pkCols, ","))
|
||||
result.Tables = append(result.Tables, summary)
|
||||
return
|
||||
}
|
||||
summary.PKColumn = pkCols[0]
|
||||
|
||||
// Query data for diff
|
||||
sourceRows, _, err := sourceDB.Query(fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(config.SourceConfig.Type, sourceQueryTable)))
|
||||
if err != nil {
|
||||
summary.Message = "读取源表失败: " + err.Error()
|
||||
result.Tables = append(result.Tables, summary)
|
||||
return
|
||||
}
|
||||
targetRows, _, err := targetDB.Query(fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(config.TargetConfig.Type, targetQueryTable)))
|
||||
if err != nil {
|
||||
summary.Message = "读取目标表失败: " + err.Error()
|
||||
result.Tables = append(result.Tables, summary)
|
||||
return
|
||||
}
|
||||
|
||||
pkCol := summary.PKColumn
|
||||
targetMap := make(map[string]map[string]interface{}, len(targetRows))
|
||||
for _, row := range targetRows {
|
||||
if row[pkCol] == nil {
|
||||
continue
|
||||
}
|
||||
pkVal := strings.TrimSpace(fmt.Sprintf("%v", row[pkCol]))
|
||||
if pkVal == "" || pkVal == "<nil>" {
|
||||
continue
|
||||
}
|
||||
targetMap[pkVal] = row
|
||||
}
|
||||
|
||||
sourcePKSet := make(map[string]struct{}, len(sourceRows))
|
||||
for _, sRow := range sourceRows {
|
||||
if sRow[pkCol] == nil {
|
||||
continue
|
||||
}
|
||||
pkVal := strings.TrimSpace(fmt.Sprintf("%v", sRow[pkCol]))
|
||||
if pkVal == "" || pkVal == "<nil>" {
|
||||
continue
|
||||
}
|
||||
sourcePKSet[pkVal] = struct{}{}
|
||||
|
||||
if tRow, exists := targetMap[pkVal]; exists {
|
||||
changed := false
|
||||
for k, v := range sRow {
|
||||
if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", tRow[k]) {
|
||||
changed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if changed {
|
||||
summary.Updates++
|
||||
} else {
|
||||
summary.Same++
|
||||
}
|
||||
} else {
|
||||
summary.Inserts++
|
||||
}
|
||||
}
|
||||
|
||||
for pkVal := range targetMap {
|
||||
if _, ok := sourcePKSet[pkVal]; !ok {
|
||||
summary.Deletes++
|
||||
}
|
||||
}
|
||||
|
||||
summary.CanSync = true
|
||||
result.Tables = append(result.Tables, summary)
|
||||
}()
|
||||
}
|
||||
|
||||
s.progress(config.JobID, totalTables, totalTables, "", "差异分析完成")
|
||||
result.Message = fmt.Sprintf("已完成 %d 张表的差异分析", len(result.Tables))
|
||||
return result
|
||||
}
|
||||
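Analyze is read-only: it connects to both sides, counts pending inserts, updates and deletes per table, and never writes to the target. A minimal call site might look like the sketch below; the reporter, connection configs and table names are placeholders, not values from this diff.

// Hypothetical call site; all values are illustrative.
engine := NewSyncEngine(reporter) // reporter implements the package's Reporter interface
res := engine.Analyze(SyncConfig{
    SourceConfig: srcCfg,
    TargetConfig: dstCfg,
    Tables:       []string{"users", "orders"},
    Content:      "data", // "", "data", "schema" or "both"
})
for _, t := range res.Tables {
    fmt.Printf("%s: +%d ~%d -%d (same %d) canSync=%v %s\n",
        t.Table, t.Inserts, t.Updates, t.Deletes, t.Same, t.CanSync, t.Message)
}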
164
internal/sync/preview.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package sync
|
||||
|
||||
import (
|
||||
"GoNavi-Wails/internal/db"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type PreviewRow struct {
|
||||
PK string `json:"pk"`
|
||||
Row map[string]interface{} `json:"row"`
|
||||
}
|
||||
|
||||
type PreviewUpdateRow struct {
|
||||
PK string `json:"pk"`
|
||||
ChangedColumns []string `json:"changedColumns"`
|
||||
Source map[string]interface{} `json:"source"`
|
||||
Target map[string]interface{} `json:"target"`
|
||||
}
|
||||
|
||||
type TableDiffPreview struct {
|
||||
Table string `json:"table"`
|
||||
PKColumn string `json:"pkColumn"`
|
||||
TotalInserts int `json:"totalInserts"`
|
||||
TotalUpdates int `json:"totalUpdates"`
|
||||
TotalDeletes int `json:"totalDeletes"`
|
||||
Inserts []PreviewRow `json:"inserts"`
|
||||
Updates []PreviewUpdateRow `json:"updates"`
|
||||
Deletes []PreviewRow `json:"deletes"`
|
||||
}
|
||||
|
||||
func (s *SyncEngine) Preview(config SyncConfig, tableName string, limit int) (TableDiffPreview, error) {
|
||||
if limit <= 0 {
|
||||
limit = 200
|
||||
}
|
||||
if limit > 500 {
|
||||
limit = 500
|
||||
}
|
||||
|
||||
sourceDB, err := db.NewDatabase(config.SourceConfig.Type)
|
||||
if err != nil {
|
||||
return TableDiffPreview{}, fmt.Errorf("初始化源数据库驱动失败: %w", err)
|
||||
}
|
||||
targetDB, err := db.NewDatabase(config.TargetConfig.Type)
|
||||
if err != nil {
|
||||
return TableDiffPreview{}, fmt.Errorf("初始化目标数据库驱动失败: %w", err)
|
||||
}
|
||||
|
||||
if err := sourceDB.Connect(config.SourceConfig); err != nil {
|
||||
return TableDiffPreview{}, fmt.Errorf("源数据库连接失败: %w", err)
|
||||
}
|
||||
defer sourceDB.Close()
|
||||
|
||||
if err := targetDB.Connect(config.TargetConfig); err != nil {
|
||||
return TableDiffPreview{}, fmt.Errorf("目标数据库连接失败: %w", err)
|
||||
}
|
||||
defer targetDB.Close()
|
||||
|
||||
sourceSchema, sourceTable := normalizeSchemaAndTable(config.SourceConfig.Type, config.SourceConfig.Database, tableName)
|
||||
targetSchema, targetTable := normalizeSchemaAndTable(config.TargetConfig.Type, config.TargetConfig.Database, tableName)
|
||||
sourceQueryTable := qualifiedNameForQuery(config.SourceConfig.Type, sourceSchema, sourceTable, tableName)
|
||||
targetQueryTable := qualifiedNameForQuery(config.TargetConfig.Type, targetSchema, targetTable, tableName)
|
||||
|
||||
cols, err := sourceDB.GetColumns(sourceSchema, sourceTable)
|
||||
if err != nil {
|
||||
return TableDiffPreview{}, fmt.Errorf("获取源表字段失败: %w", err)
|
||||
}
|
||||
|
||||
pkCols := make([]string, 0, 2)
|
||||
for _, c := range cols {
|
||||
if c.Key == "PRI" || c.Key == "PK" {
|
||||
pkCols = append(pkCols, c.Name)
|
||||
}
|
||||
}
|
||||
if len(pkCols) == 0 {
|
||||
return TableDiffPreview{}, fmt.Errorf("无主键,不支持数据预览")
|
||||
}
|
||||
if len(pkCols) > 1 {
|
||||
return TableDiffPreview{}, fmt.Errorf("复合主键(%s),暂不支持数据预览", strings.Join(pkCols, ","))
|
||||
}
|
||||
pkCol := pkCols[0]
|
||||
|
||||
sourceRows, _, err := sourceDB.Query(fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(config.SourceConfig.Type, sourceQueryTable)))
|
||||
if err != nil {
|
||||
return TableDiffPreview{}, fmt.Errorf("读取源表失败: %w", err)
|
||||
}
|
||||
targetRows, _, err := targetDB.Query(fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(config.TargetConfig.Type, targetQueryTable)))
|
||||
if err != nil {
|
||||
return TableDiffPreview{}, fmt.Errorf("读取目标表失败: %w", err)
|
||||
}
|
||||
|
||||
targetMap := make(map[string]map[string]interface{}, len(targetRows))
|
||||
for _, row := range targetRows {
|
||||
if row[pkCol] == nil {
|
||||
continue
|
||||
}
|
||||
pkVal := strings.TrimSpace(fmt.Sprintf("%v", row[pkCol]))
|
||||
if pkVal == "" || pkVal == "<nil>" {
|
||||
continue
|
||||
}
|
||||
targetMap[pkVal] = row
|
||||
}
|
||||
|
||||
out := TableDiffPreview{
|
||||
Table: tableName,
|
||||
PKColumn: pkCol,
|
||||
TotalInserts: 0,
|
||||
TotalUpdates: 0,
|
||||
TotalDeletes: 0,
|
||||
Inserts: make([]PreviewRow, 0),
|
||||
Updates: make([]PreviewUpdateRow, 0),
|
||||
Deletes: make([]PreviewRow, 0),
|
||||
}
|
||||
|
||||
sourcePKSet := make(map[string]struct{}, len(sourceRows))
|
||||
for _, sRow := range sourceRows {
|
||||
if sRow[pkCol] == nil {
|
||||
continue
|
||||
}
|
||||
pkVal := strings.TrimSpace(fmt.Sprintf("%v", sRow[pkCol]))
|
||||
if pkVal == "" || pkVal == "<nil>" {
|
||||
continue
|
||||
}
|
||||
sourcePKSet[pkVal] = struct{}{}
|
||||
|
||||
if tRow, exists := targetMap[pkVal]; exists {
|
||||
changedColumns := make([]string, 0)
|
||||
for k, v := range sRow {
|
||||
if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", tRow[k]) {
|
||||
changedColumns = append(changedColumns, k)
|
||||
}
|
||||
}
|
||||
if len(changedColumns) > 0 {
|
||||
out.TotalUpdates++
|
||||
if len(out.Updates) < limit {
|
||||
out.Updates = append(out.Updates, PreviewUpdateRow{
|
||||
PK: pkVal,
|
||||
ChangedColumns: changedColumns,
|
||||
Source: sRow,
|
||||
Target: tRow,
|
||||
})
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
out.TotalInserts++
|
||||
if len(out.Inserts) < limit {
|
||||
out.Inserts = append(out.Inserts, PreviewRow{PK: pkVal, Row: sRow})
|
||||
}
|
||||
}
|
||||
|
||||
for pkVal, row := range targetMap {
|
||||
if _, ok := sourcePKSet[pkVal]; ok {
|
||||
continue
|
||||
}
|
||||
out.TotalDeletes++
|
||||
if len(out.Deletes) < limit {
|
||||
out.Deletes = append(out.Deletes, PreviewRow{PK: pkVal, Row: row})
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
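Preview performs the same primary-key diff as Analyze but also returns the rows themselves: non-positive limits default to 200, the cap is 500, and at most limit rows are listed per change type while the totals still reflect the full counts. A hedged call sketch, with an illustrative table name and limit:

// Hypothetical call site; "users" and 100 are illustrative.
p, err := engine.Preview(cfg, "users", 100)
if err != nil {
    return err
}
fmt.Printf("inserts=%d updates=%d deletes=%d (at most 100 rows listed per kind)\n",
    p.TotalInserts, p.TotalUpdates, p.TotalDeletes)
for _, u := range p.Updates {
    fmt.Printf("pk=%s changed=%v\n", u.PK, u.ChangedColumns)
}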
58
internal/sync/row_selection.go
Normal file
@@ -0,0 +1,58 @@
package sync

import (
    "GoNavi-Wails/internal/connection"
    "fmt"
)

func filterRowsByPKSelection(pkCol string, rows []map[string]interface{}, enabled bool, selectedPKs []string) []map[string]interface{} {
    if !enabled {
        return nil
    }
    if len(rows) == 0 {
        return rows
    }
    if len(selectedPKs) == 0 {
        return rows
    }

    set := make(map[string]struct{}, len(selectedPKs))
    for _, pk := range selectedPKs {
        set[pk] = struct{}{}
    }

    out := make([]map[string]interface{}, 0, len(rows))
    for _, row := range rows {
        pkStr := fmt.Sprintf("%v", row[pkCol])
        if _, ok := set[pkStr]; ok {
            out = append(out, row)
        }
    }
    return out
}

func filterUpdatesByPKSelection(pkCol string, updates []connection.UpdateRow, enabled bool, selectedPKs []string) []connection.UpdateRow {
    if !enabled {
        return nil
    }
    if len(updates) == 0 {
        return updates
    }
    if len(selectedPKs) == 0 {
        return updates
    }

    set := make(map[string]struct{}, len(selectedPKs))
    for _, pk := range selectedPKs {
        set[pk] = struct{}{}
    }

    out := make([]connection.UpdateRow, 0, len(updates))
    for _, u := range updates {
        pkStr := fmt.Sprintf("%v", u.Keys[pkCol])
        if _, ok := set[pkStr]; ok {
            out = append(out, u)
        }
    }
    return out
}
|
||||
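Both filters share the same semantics: a disabled operation drops everything (nil), an enabled operation with an empty selection keeps all rows, and otherwise only rows whose primary key (rendered with %v) appears in the selection survive. A small illustrative sketch:

// Illustrative values only; behaviour follows the helpers above.
rows := []map[string]interface{}{
    {"id": 1, "name": "a"},
    {"id": 2, "name": "b"},
    {"id": 3, "name": "c"},
}

keepAll := filterRowsByPKSelection("id", rows, true, nil)            // 3 rows: empty selection keeps everything
onlyTwo := filterRowsByPKSelection("id", rows, true, []string{"2"})  // 1 row: int 2 matches "2" via %v
dropped := filterRowsByPKSelection("id", rows, false, []string{"2"}) // nil: operation disabled
_, _, _ = keepAll, onlyTwo, dropped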
97
internal/sync/schema_align.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package sync
|
||||
|
||||
import (
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func collectRequiredColumns(inserts []map[string]interface{}, updates []connection.UpdateRow) map[string]string {
|
||||
// key: lower(columnName), value: original columnName
|
||||
required := make(map[string]string)
|
||||
for _, row := range inserts {
|
||||
for k := range row {
|
||||
key := strings.ToLower(strings.TrimSpace(k))
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := required[key]; !exists {
|
||||
required[key] = k
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, u := range updates {
|
||||
for k := range u.Values {
|
||||
key := strings.ToLower(strings.TrimSpace(k))
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := required[key]; !exists {
|
||||
required[key] = k
|
||||
}
|
||||
}
|
||||
}
|
||||
return required
|
||||
}
|
||||
|
||||
func filterInsertRows(inserts []map[string]interface{}, allowedLower map[string]struct{}) []map[string]interface{} {
|
||||
if len(inserts) == 0 || len(allowedLower) == 0 {
|
||||
return inserts
|
||||
}
|
||||
|
||||
out := make([]map[string]interface{}, 0, len(inserts))
|
||||
for _, row := range inserts {
|
||||
if len(row) == 0 {
|
||||
out = append(out, row)
|
||||
continue
|
||||
}
|
||||
n := make(map[string]interface{}, len(row))
|
||||
for k, v := range row {
|
||||
if _, ok := allowedLower[strings.ToLower(strings.TrimSpace(k))]; ok {
|
||||
n[k] = v
|
||||
}
|
||||
}
|
||||
out = append(out, n)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func filterUpdateRows(updates []connection.UpdateRow, allowedLower map[string]struct{}) []connection.UpdateRow {
|
||||
if len(updates) == 0 || len(allowedLower) == 0 {
|
||||
return updates
|
||||
}
|
||||
|
||||
out := make([]connection.UpdateRow, 0, len(updates))
|
||||
for _, u := range updates {
|
||||
if len(u.Values) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
values := make(map[string]interface{}, len(u.Values))
|
||||
for k, v := range u.Values {
|
||||
if _, ok := allowedLower[strings.ToLower(strings.TrimSpace(k))]; ok {
|
||||
values[k] = v
|
||||
}
|
||||
}
|
||||
if len(values) == 0 {
|
||||
continue
|
||||
}
|
||||
out = append(out, connection.UpdateRow{
|
||||
Keys: u.Keys,
|
||||
Values: values,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func sanitizeMySQLColumnType(t string) string {
|
||||
tt := strings.TrimSpace(t)
|
||||
if tt == "" {
|
||||
return "TEXT"
|
||||
}
|
||||
|
||||
// 基础防护:避免把元数据中异常内容拼进 SQL。
|
||||
if strings.ContainsAny(tt, "`;\n\r") {
|
||||
return "TEXT"
|
||||
}
|
||||
return tt
|
||||
}
|
||||
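These helpers keep the change set aligned with what the target table can actually accept: collectRequiredColumns gathers every column referenced by the pending inserts and updates, and filterInsertRows/filterUpdateRows drop values for columns the target does not have. A small illustrative sketch with made-up data:

// Illustrative values only.
inserts := []map[string]interface{}{
    {"id": 1, "name": "a", "legacy_flag": true},
}
required := collectRequiredColumns(inserts, nil)
// required == {"id": "id", "name": "name", "legacy_flag": "legacy_flag"} (keys lower-cased)

allowed := map[string]struct{}{"id": {}, "name": {}} // lower-cased target columns
filtered := filterInsertRows(inserts, allowed)
// filtered[0] keeps only "id" and "name"; "legacy_flag" is dropped
_ = filtered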
101
internal/sync/schema_sync.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package sync
|
||||
|
||||
import (
|
||||
"GoNavi-Wails/internal/db"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (s *SyncEngine) syncTableSchema(config SyncConfig, res *SyncResult, sourceDB db.Database, targetDB db.Database, tableName string) error {
|
||||
targetType := strings.ToLower(strings.TrimSpace(config.TargetConfig.Type))
|
||||
if targetType != "mysql" {
|
||||
s.appendLog(config.JobID, res, "warn", fmt.Sprintf("目标数据库类型=%s 暂不支持结构同步,已跳过表 %s", config.TargetConfig.Type, tableName))
|
||||
return nil
|
||||
}
|
||||
|
||||
sourceSchema, sourceTable := normalizeSchemaAndTable(config.SourceConfig.Type, config.SourceConfig.Database, tableName)
|
||||
targetSchema, targetTable := normalizeSchemaAndTable(config.TargetConfig.Type, config.TargetConfig.Database, tableName)
|
||||
targetQueryTable := qualifiedNameForQuery(config.TargetConfig.Type, targetSchema, targetTable, tableName)
|
||||
|
||||
// 1) 获取源表字段
|
||||
sourceCols, err := sourceDB.GetColumns(sourceSchema, sourceTable)
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取源表字段失败: %w", err)
|
||||
}
|
||||
|
||||
// 2) 确保目标表存在
|
||||
targetCols, err := targetDB.GetColumns(targetSchema, targetTable)
|
||||
if err != nil {
|
||||
sourceType := strings.ToLower(strings.TrimSpace(config.SourceConfig.Type))
|
||||
if sourceType != "mysql" {
|
||||
return fmt.Errorf("目标表不存在且源类型=%s 暂不支持自动建表: %w", config.SourceConfig.Type, err)
|
||||
}
|
||||
|
||||
s.appendLog(config.JobID, res, "warn", fmt.Sprintf("目标表 %s 不存在,开始尝试创建表结构", tableName))
|
||||
createSQL, errCreate := sourceDB.GetCreateStatement(sourceSchema, sourceTable)
|
||||
if errCreate != nil || strings.TrimSpace(createSQL) == "" {
|
||||
if errCreate == nil {
|
||||
errCreate = fmt.Errorf("建表语句为空")
|
||||
}
|
||||
return fmt.Errorf("获取源表建表语句失败: %w", errCreate)
|
||||
}
|
||||
|
||||
if _, errExec := targetDB.Exec(createSQL); errExec != nil {
|
||||
return fmt.Errorf("创建目标表失败: %w", errExec)
|
||||
}
|
||||
s.appendLog(config.JobID, res, "info", fmt.Sprintf("目标表创建成功:%s", tableName))
|
||||
|
||||
targetCols, err = targetDB.GetColumns(targetSchema, targetTable)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建目标表后获取字段失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
targetColSet := make(map[string]struct{}, len(targetCols))
|
||||
for _, c := range targetCols {
|
||||
name := strings.ToLower(strings.TrimSpace(c.Name))
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
targetColSet[name] = struct{}{}
|
||||
}
|
||||
|
||||
// 3) 补齐目标缺失字段(安全策略:新增字段统一允许 NULL)
|
||||
missing := make([]string, 0)
|
||||
sourceType := strings.ToLower(strings.TrimSpace(config.SourceConfig.Type))
|
||||
for _, c := range sourceCols {
|
||||
colName := strings.TrimSpace(c.Name)
|
||||
if colName == "" {
|
||||
continue
|
||||
}
|
||||
lower := strings.ToLower(colName)
|
||||
if _, ok := targetColSet[lower]; ok {
|
||||
continue
|
||||
}
|
||||
missing = append(missing, colName)
|
||||
|
||||
colType := "TEXT"
|
||||
if sourceType == "mysql" {
|
||||
colType = sanitizeMySQLColumnType(c.Type)
|
||||
}
|
||||
|
||||
alterSQL := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s NULL",
|
||||
quoteQualifiedIdentByType("mysql", targetQueryTable),
|
||||
quoteIdentByType("mysql", colName),
|
||||
colType,
|
||||
)
|
||||
if _, err := targetDB.Exec(alterSQL); err != nil {
|
||||
s.appendLog(config.JobID, res, "error", fmt.Sprintf(" -> 补字段失败:表=%s 字段=%s 错误=%v", tableName, colName, err))
|
||||
continue
|
||||
}
|
||||
s.appendLog(config.JobID, res, "info", fmt.Sprintf(" -> 已补齐字段:表=%s 字段=%s 类型=%s", tableName, colName, colType))
|
||||
}
|
||||
|
||||
if len(missing) == 0 {
|
||||
s.appendLog(config.JobID, res, "info", fmt.Sprintf("表结构一致:%s", tableName))
|
||||
} else {
|
||||
s.appendLog(config.JobID, res, "info", fmt.Sprintf("表结构同步完成:%s(新增字段 %d 个)", tableName, len(missing)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
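For a MySQL target, every missing column is added with its own ALTER TABLE statement and is always declared NULL so existing rows stay valid. With assumed names, the statement built by the loop above looks like this sketch:

// Illustrative names; "appdb.users" and "nickname" are not from this diff.
alterSQL := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s NULL",
    quoteQualifiedIdentByType("mysql", "appdb.users"), // `appdb`.`users`
    quoteIdentByType("mysql", "nickname"),             // `nickname`
    sanitizeMySQLColumnType("varchar(64)"),            // varchar(64)
)
// alterSQL == "ALTER TABLE `appdb`.`users` ADD COLUMN `nickname` varchar(64) NULL"
_ = alterSQL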
109
internal/sync/sql_helpers.go
Normal file
@@ -0,0 +1,109 @@
package sync

import "strings"

func normalizeSyncMode(mode string) string {
    m := strings.ToLower(strings.TrimSpace(mode))
    switch m {
    case "", "insert_update":
        return "insert_update"
    case "insert_only":
        return "insert_only"
    case "full_overwrite":
        return "full_overwrite"
    default:
        return "insert_update"
    }
}

func quoteIdentByType(dbType string, ident string) string {
    if ident == "" {
        return ident
    }

    switch dbType {
    case "mysql":
        return "`" + strings.ReplaceAll(ident, "`", "``") + "`"
    default:
        return `"` + strings.ReplaceAll(ident, `"`, `""`) + `"`
    }
}

func quoteQualifiedIdentByType(dbType string, ident string) string {
    raw := strings.TrimSpace(ident)
    if raw == "" {
        return raw
    }

    parts := strings.Split(raw, ".")
    if len(parts) <= 1 {
        return quoteIdentByType(dbType, raw)
    }

    quotedParts := make([]string, 0, len(parts))
    for _, part := range parts {
        part = strings.TrimSpace(part)
        if part == "" {
            continue
        }
        quotedParts = append(quotedParts, quoteIdentByType(dbType, part))
    }

    if len(quotedParts) == 0 {
        return quoteIdentByType(dbType, raw)
    }
    return strings.Join(quotedParts, ".")
}

func normalizeSchemaAndTable(dbType string, dbName string, tableName string) (string, string) {
    rawTable := strings.TrimSpace(tableName)
    rawDB := strings.TrimSpace(dbName)
    if rawTable == "" {
        return rawDB, rawTable
    }

    if parts := strings.SplitN(rawTable, ".", 2); len(parts) == 2 {
        schema := strings.TrimSpace(parts[0])
        table := strings.TrimSpace(parts[1])
        if schema != "" && table != "" {
            return schema, table
        }
    }

    switch strings.ToLower(strings.TrimSpace(dbType)) {
    case "postgres", "kingbase":
        return "public", rawTable
    default:
        return rawDB, rawTable
    }
}

func qualifiedNameForQuery(dbType string, schema string, table string, original string) string {
    raw := strings.TrimSpace(original)
    if raw == "" {
        return raw
    }
    if strings.Contains(raw, ".") {
        return raw
    }

    switch strings.ToLower(strings.TrimSpace(dbType)) {
    case "postgres", "kingbase":
        s := strings.TrimSpace(schema)
        if s == "" {
            s = "public"
        }
        if table == "" {
            return raw
        }
        return s + "." + table
    case "mysql":
        s := strings.TrimSpace(schema)
        if s == "" || table == "" {
            return table
        }
        return s + "." + table
    default:
        return table
    }
}
|
||||
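The helpers above are what the sync engine uses to build portable identifiers: MySQL identifiers get backticks, everything else gets double quotes, and schema/table names are normalized per database type. A few illustrative calls and their results (values are made up for the sketch):

// Illustrative calls; results shown in the trailing comments.
_ = quoteIdentByType("mysql", "order")             // `order`
_ = quoteIdentByType("postgres", "order")          // "order"
_ = quoteQualifiedIdentByType("mysql", "db.users") // `db`.`users`

schema, table := normalizeSchemaAndTable("postgres", "appdb", "users") // "public", "users"
fmt.Println(schema, table)
schema, table = normalizeSchemaAndTable("mysql", "appdb", "other.users") // "other", "users"
fmt.Println(schema, table)

_ = qualifiedNameForQuery("postgres", "public", "users", "users") // "public.users"
_ = qualifiedNameForQuery("sqlite", "", "users", "users")         // "users"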
@@ -5,15 +5,21 @@ import (
|
||||
"GoNavi-Wails/internal/db"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SyncConfig defines the parameters for a synchronization task
|
||||
type SyncConfig struct {
|
||||
SourceConfig connection.ConnectionConfig `json:"sourceConfig"`
|
||||
TargetConfig connection.ConnectionConfig `json:"targetConfig"`
|
||||
Tables []string `json:"tables"` // Tables to sync
|
||||
Mode string `json:"mode"` // "insert_update", "full_overwrite"
|
||||
SourceConfig connection.ConnectionConfig `json:"sourceConfig"`
|
||||
TargetConfig connection.ConnectionConfig `json:"targetConfig"`
|
||||
Tables []string `json:"tables"` // Tables to sync
|
||||
Content string `json:"content,omitempty"` // "data", "schema", "both"
|
||||
Mode string `json:"mode"` // "insert_update", "insert_only", "full_overwrite"
|
||||
JobID string `json:"jobId,omitempty"`
|
||||
AutoAddColumns bool `json:"autoAddColumns,omitempty"` // 自动补齐缺失字段(当前仅 MySQL 目标支持)
|
||||
TableOptions map[string]TableOptions `json:"tableOptions,omitempty"`
|
||||
}
|
||||
|
||||
// SyncResult holds the result of the sync operation
|
||||
@@ -28,21 +34,55 @@ type SyncResult struct {
|
||||
}
|
||||
|
||||
type SyncEngine struct {
|
||||
reporter Reporter
|
||||
}
|
||||
|
||||
func NewSyncEngine() *SyncEngine {
|
||||
return &SyncEngine{}
|
||||
func NewSyncEngine(reporter Reporter) *SyncEngine {
|
||||
return &SyncEngine{reporter: reporter}
|
||||
}
|
||||
|
||||
// CompareAndSync performs the synchronization
|
||||
func (s *SyncEngine) RunSync(config SyncConfig) SyncResult {
|
||||
result := SyncResult{Success: true, Logs: []string{}}
|
||||
logger.Infof("开始数据同步:源=%s 目标=%s 表数量=%d", formatConnSummaryForSync(config.SourceConfig), formatConnSummaryForSync(config.TargetConfig), len(config.Tables))
|
||||
totalTables := len(config.Tables)
|
||||
s.progress(config.JobID, 0, totalTables, "", "开始同步")
|
||||
|
||||
contentRaw := strings.ToLower(strings.TrimSpace(config.Content))
|
||||
syncSchema := false
|
||||
syncData := true
|
||||
switch contentRaw {
|
||||
case "", "data":
|
||||
syncData = true
|
||||
case "schema":
|
||||
syncSchema = true
|
||||
syncData = false
|
||||
case "both":
|
||||
syncSchema = true
|
||||
syncData = true
|
||||
default:
|
||||
s.appendLog(config.JobID, &result, "warn", fmt.Sprintf("未知同步内容 %q,已自动使用仅同步数据", config.Content))
|
||||
syncData = true
|
||||
}
|
||||
|
||||
modeRaw := strings.ToLower(strings.TrimSpace(config.Mode))
|
||||
if modeRaw != "" && modeRaw != "insert_update" && modeRaw != "insert_only" && modeRaw != "full_overwrite" {
|
||||
s.appendLog(config.JobID, &result, "warn", fmt.Sprintf("未知同步模式 %q,已自动使用 insert_update", config.Mode))
|
||||
}
|
||||
defaultMode := normalizeSyncMode(config.Mode)
|
||||
|
||||
contentLabel := "仅同步数据"
|
||||
if syncSchema && syncData {
|
||||
contentLabel = "同步结构+数据"
|
||||
} else if syncSchema {
|
||||
contentLabel = "仅同步结构"
|
||||
}
|
||||
s.appendLog(config.JobID, &result, "info", fmt.Sprintf("同步内容:%s;模式:%s;自动补字段:%v", contentLabel, defaultMode, config.AutoAddColumns))
|
||||
|
||||
sourceDB, err := db.NewDatabase(config.SourceConfig.Type)
|
||||
if err != nil {
|
||||
logger.Error(err, "初始化源数据库驱动失败:类型=%s", config.SourceConfig.Type)
|
||||
return s.fail(result, "初始化源数据库驱动失败: "+err.Error())
|
||||
return s.fail(config.JobID, totalTables, result, "初始化源数据库驱动失败: "+err.Error())
|
||||
}
|
||||
if config.SourceConfig.Type == "custom" {
|
||||
// Custom DB setup would go here if needed
|
||||
@@ -51,133 +91,402 @@ func (s *SyncEngine) RunSync(config SyncConfig) SyncResult {
|
||||
targetDB, err := db.NewDatabase(config.TargetConfig.Type)
|
||||
if err != nil {
|
||||
logger.Error(err, "初始化目标数据库驱动失败:类型=%s", config.TargetConfig.Type)
|
||||
return s.fail(result, "初始化目标数据库驱动失败: "+err.Error())
|
||||
return s.fail(config.JobID, totalTables, result, "初始化目标数据库驱动失败: "+err.Error())
|
||||
}
|
||||
|
||||
// Connect Source
|
||||
result.Logs = append(result.Logs, fmt.Sprintf("正在连接源数据库: %s...", config.SourceConfig.Host))
|
||||
s.appendLog(config.JobID, &result, "info", fmt.Sprintf("正在连接源数据库: %s...", config.SourceConfig.Host))
|
||||
s.progress(config.JobID, 0, totalTables, "", "连接源数据库")
|
||||
if err := sourceDB.Connect(config.SourceConfig); err != nil {
|
||||
logger.Error(err, "源数据库连接失败:%s", formatConnSummaryForSync(config.SourceConfig))
|
||||
return s.fail(result, "源数据库连接失败: "+err.Error())
|
||||
return s.fail(config.JobID, totalTables, result, "源数据库连接失败: "+err.Error())
|
||||
}
|
||||
defer sourceDB.Close()
|
||||
|
||||
// Connect Target
|
||||
result.Logs = append(result.Logs, fmt.Sprintf("正在连接目标数据库: %s...", config.TargetConfig.Host))
|
||||
s.appendLog(config.JobID, &result, "info", fmt.Sprintf("正在连接目标数据库: %s...", config.TargetConfig.Host))
|
||||
s.progress(config.JobID, 0, totalTables, "", "连接目标数据库")
|
||||
if err := targetDB.Connect(config.TargetConfig); err != nil {
|
||||
logger.Error(err, "目标数据库连接失败:%s", formatConnSummaryForSync(config.TargetConfig))
|
||||
return s.fail(result, "目标数据库连接失败: "+err.Error())
|
||||
return s.fail(config.JobID, totalTables, result, "目标数据库连接失败: "+err.Error())
|
||||
}
|
||||
defer targetDB.Close()
|
||||
|
||||
// Iterate Tables
|
||||
for _, tableName := range config.Tables {
|
||||
result.Logs = append(result.Logs, fmt.Sprintf("正在同步表: %s", tableName))
|
||||
for i, tableName := range config.Tables {
|
||||
func() {
|
||||
tableMode := defaultMode
|
||||
s.appendLog(config.JobID, &result, "info", fmt.Sprintf("正在同步表: %s", tableName))
|
||||
s.progress(config.JobID, i, totalTables, tableName, fmt.Sprintf("同步表(%d/%d)", i+1, totalTables))
|
||||
defer s.progress(config.JobID, i+1, totalTables, tableName, "表处理完成")
|
||||
|
||||
// 1. Get Columns & PKs (Naive approach: assume same schema)
|
||||
cols, err := sourceDB.GetColumns(config.SourceConfig.Database, tableName)
|
||||
if err != nil {
|
||||
logger.Error(err, "获取源表列信息失败:表=%s", tableName)
|
||||
result.Logs = append(result.Logs, fmt.Sprintf("获取表 %s 的列信息失败: %v", tableName, err))
|
||||
continue
|
||||
}
|
||||
|
||||
pkCol := ""
|
||||
for _, col := range cols {
|
||||
if col.Key == "PRI" || col.Key == "PK" {
|
||||
pkCol = col.Name
|
||||
break
|
||||
if syncSchema {
|
||||
s.progress(config.JobID, i, totalTables, tableName, "同步表结构")
|
||||
if err := s.syncTableSchema(config, &result, sourceDB, targetDB, tableName); err != nil {
|
||||
s.appendLog(config.JobID, &result, "error", fmt.Sprintf("表结构同步失败:表=%s 错误=%v", tableName, err))
|
||||
return
|
||||
}
|
||||
}
|
||||
if !syncData {
|
||||
result.TablesSynced++
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if pkCol == "" {
|
||||
result.Logs = append(result.Logs, fmt.Sprintf("跳过表 %s: 未找到主键 (同步需要主键)", tableName))
|
||||
continue
|
||||
}
|
||||
sourceSchema, sourceTable := normalizeSchemaAndTable(config.SourceConfig.Type, config.SourceConfig.Database, tableName)
|
||||
targetSchema, targetTable := normalizeSchemaAndTable(config.TargetConfig.Type, config.TargetConfig.Database, tableName)
|
||||
sourceQueryTable := qualifiedNameForQuery(config.SourceConfig.Type, sourceSchema, sourceTable, tableName)
|
||||
targetQueryTable := qualifiedNameForQuery(config.TargetConfig.Type, targetSchema, targetTable, tableName)
|
||||
|
||||
// 2. Fetch Data (MEMORY INTENSIVE - PROTOTYPE ONLY)
|
||||
// TODO: Implement paging/streaming
|
||||
sourceRows, _, err := sourceDB.Query(fmt.Sprintf("SELECT * FROM %s", tableName))
|
||||
if err != nil {
|
||||
logger.Error(err, "读取源表失败:表=%s", tableName)
|
||||
result.Logs = append(result.Logs, fmt.Sprintf("读取源表 %s 失败: %v", tableName, err))
|
||||
continue
|
||||
}
|
||||
// 1. Get Columns & PKs
|
||||
cols, err := sourceDB.GetColumns(sourceSchema, sourceTable)
|
||||
if err != nil {
|
||||
logger.Error(err, "获取源表列信息失败:表=%s", tableName)
|
||||
s.appendLog(config.JobID, &result, "error", fmt.Sprintf("获取表 %s 的列信息失败: %v", tableName, err))
|
||||
return
|
||||
}
|
||||
sourceColsByLower := make(map[string]connection.ColumnDefinition, len(cols))
|
||||
for _, col := range cols {
|
||||
if strings.TrimSpace(col.Name) == "" {
|
||||
continue
|
||||
}
|
||||
sourceColsByLower[strings.ToLower(strings.TrimSpace(col.Name))] = col
|
||||
}
|
||||
|
||||
targetRows, _, err := targetDB.Query(fmt.Sprintf("SELECT * FROM %s", tableName))
|
||||
if err != nil {
|
||||
logger.Error(err, "读取目标表失败:表=%s", tableName)
|
||||
// Table might not exist in target?
|
||||
// Check if error is "table not found" -> Try to Create?
|
||||
// For now, assume table exists.
|
||||
result.Logs = append(result.Logs, fmt.Sprintf("读取目标表 %s 失败: %v", tableName, err))
|
||||
continue
|
||||
}
|
||||
pkCols := make([]string, 0, 2)
|
||||
for _, col := range cols {
|
||||
if col.Key == "PRI" || col.Key == "PK" {
|
||||
pkCols = append(pkCols, col.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Compare (In-Memory Hash Map)
|
||||
targetMap := make(map[string]map[string]interface{})
|
||||
for _, row := range targetRows {
|
||||
pkVal := fmt.Sprintf("%v", row[pkCol])
|
||||
targetMap[pkVal] = row
|
||||
}
|
||||
if len(pkCols) == 0 {
|
||||
s.appendLog(config.JobID, &result, "warn", fmt.Sprintf("表 %s 未找到主键,已跳过数据同步(避免产生重复数据)", tableName))
|
||||
return
|
||||
}
|
||||
if len(pkCols) > 1 {
|
||||
s.appendLog(config.JobID, &result, "warn", fmt.Sprintf("表 %s 为复合主键(%s),当前暂不支持数据同步", tableName, strings.Join(pkCols, ",")))
|
||||
return
|
||||
}
|
||||
pkCol := pkCols[0]
|
||||
|
||||
var inserts []map[string]interface{}
|
||||
var updates []connection.UpdateRow
|
||||
// var deletes []map[string]interface{} // Not implemented in "insert_update" mode usually
|
||||
|
||||
for _, sRow := range sourceRows {
|
||||
pkVal := fmt.Sprintf("%v", sRow[pkCol])
|
||||
|
||||
if tRow, exists := targetMap[pkVal]; exists {
|
||||
// Update? Compare values
|
||||
// Simplified: Compare string representations or iterate keys
|
||||
// For prototype: assume update if exists
|
||||
// Optimization: Check diff
|
||||
changes := make(map[string]interface{})
|
||||
for k, v := range sRow {
|
||||
if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", tRow[k]) {
|
||||
changes[k] = v
|
||||
opts := TableOptions{Insert: true, Update: true, Delete: false}
|
||||
if config.TableOptions != nil {
|
||||
if t, ok := config.TableOptions[tableName]; ok {
|
||||
opts = t
|
||||
// 默认防护:如用户未设置任意一个字段,保持 insert/update 默认 true、delete 默认 false
|
||||
if !t.Insert && !t.Update && !t.Delete {
|
||||
opts = t
|
||||
}
|
||||
}
|
||||
if len(changes) > 0 {
|
||||
updates = append(updates, connection.UpdateRow{
|
||||
Keys: map[string]interface{}{pkCol: pkVal},
|
||||
Values: changes,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// Insert
|
||||
inserts = append(inserts, sRow)
|
||||
}
|
||||
}
|
||||
if !opts.Insert && !opts.Update && !opts.Delete {
|
||||
s.appendLog(config.JobID, &result, "info", fmt.Sprintf("表 %s 未勾选任何操作,已跳过", tableName))
|
||||
return
|
||||
}
|
||||
|
||||
// 4. Apply Changes
|
||||
changeSet := connection.ChangeSet{
|
||||
Inserts: inserts,
|
||||
Updates: updates,
|
||||
}
|
||||
// 2. Fetch Data (MEMORY INTENSIVE - PROTOTYPE ONLY)
|
||||
// TODO: Implement paging/streaming
|
||||
s.progress(config.JobID, i, totalTables, tableName, "读取源表数据")
|
||||
sourceRows, _, err := sourceDB.Query(fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(config.SourceConfig.Type, sourceQueryTable)))
|
||||
if err != nil {
|
||||
logger.Error(err, "读取源表失败:表=%s", tableName)
|
||||
s.appendLog(config.JobID, &result, "error", fmt.Sprintf("读取源表 %s 失败: %v", tableName, err))
|
||||
return
|
||||
}
|
||||
|
||||
if len(inserts) > 0 || len(updates) > 0 {
|
||||
result.Logs = append(result.Logs, fmt.Sprintf(" -> 需插入: %d 行, 需更新: %d 行", len(inserts), len(updates)))
|
||||
var inserts []map[string]interface{}
|
||||
var updates []connection.UpdateRow
|
||||
|
||||
// We need a BatchApplier interface or assume Database implements ApplyChanges
|
||||
if applier, ok := targetDB.(db.BatchApplier); ok {
|
||||
if err := applier.ApplyChanges(tableName, changeSet); err != nil {
|
||||
result.Logs = append(result.Logs, fmt.Sprintf(" -> 应用变更失败: %v", err))
|
||||
if tableMode == "insert_update" {
    s.progress(config.JobID, i, totalTables, tableName, "读取目标表数据")
    targetRows, _, err := targetDB.Query(fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(config.TargetConfig.Type, targetQueryTable)))
    if err != nil {
        logger.Error(err, "读取目标表失败:表=%s", tableName)
        s.appendLog(config.JobID, &result, "error", fmt.Sprintf("读取目标表 %s 失败: %v", tableName, err))
        return
    }

    // 3. Compare (In-Memory Hash Map)
    s.progress(config.JobID, i, totalTables, tableName, "对比差异")
    targetMap := make(map[string]map[string]interface{})
    for _, row := range targetRows {
        if row[pkCol] == nil {
            continue
        }
        pkVal := fmt.Sprintf("%v", row[pkCol])
        if strings.TrimSpace(pkVal) == "" || pkVal == "<nil>" {
            continue
        }
        targetMap[pkVal] = row
    }
    sourcePKSet := make(map[string]struct{}, len(sourceRows))

    for _, sRow := range sourceRows {
        if sRow[pkCol] == nil {
            continue
        }
        pkVal := fmt.Sprintf("%v", sRow[pkCol])
        if strings.TrimSpace(pkVal) == "" || pkVal == "<nil>" {
            continue
        }
        sourcePKSet[pkVal] = struct{}{}

        if tRow, exists := targetMap[pkVal]; exists {
            changes := make(map[string]interface{})
            for k, v := range sRow {
                if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", tRow[k]) {
                    changes[k] = v
                }
            }
            if len(changes) > 0 {
                updates = append(updates, connection.UpdateRow{
                    Keys:   map[string]interface{}{pkCol: sRow[pkCol]},
                    Values: changes,
                })
            }
        } else {
            inserts = append(inserts, sRow)
        }
    }

    var deletes []map[string]interface{}
    if opts.Delete {
        for pkStr, row := range targetMap {
            if _, ok := sourcePKSet[pkStr]; ok {
                continue
            }
            deletes = append(deletes, map[string]interface{}{pkCol: row[pkCol]})
        }
    }

    // apply operation selection
    inserts = filterRowsByPKSelection(pkCol, inserts, opts.Insert, opts.SelectedInsertPKs)
    updates = filterUpdatesByPKSelection(pkCol, updates, opts.Update, opts.SelectedUpdatePKs)
    deletes = filterRowsByPKSelection(pkCol, deletes, opts.Delete, opts.SelectedDeletePKs)

    changeSet := connection.ChangeSet{
        Inserts: inserts,
        Updates: updates,
        Deletes: deletes,
    }

    // 4. Align schema (target missing columns)
    s.progress(config.JobID, i, totalTables, tableName, "检查字段一致性")
    requiredCols := collectRequiredColumns(changeSet.Inserts, changeSet.Updates)
    targetCols, err := targetDB.GetColumns(targetSchema, targetTable)
    if err != nil {
        s.appendLog(config.JobID, &result, "warn", fmt.Sprintf(" -> 获取目标表字段失败,已跳过字段一致性检查: %v", err))
    } else {
        result.RowsInserted += len(inserts)
        result.RowsUpdated += len(updates)
        targetColSet := make(map[string]struct{}, len(targetCols))
        for _, c := range targetCols {
            name := strings.ToLower(strings.TrimSpace(c.Name))
            if name == "" {
                continue
            }
            targetColSet[name] = struct{}{}
        }

        missing := make([]string, 0)
        for lower, original := range requiredCols {
            if _, ok := targetColSet[lower]; !ok {
                missing = append(missing, original)
            }
        }
        sort.Strings(missing)

        if len(missing) > 0 {
            if config.AutoAddColumns && strings.ToLower(strings.TrimSpace(config.TargetConfig.Type)) == "mysql" {
                s.appendLog(config.JobID, &result, "warn", fmt.Sprintf(" -> 目标表缺少字段 %d 个,开始自动补齐: %s", len(missing), strings.Join(missing, ", ")))
                added := 0
                for _, colName := range missing {
                    colLower := strings.ToLower(strings.TrimSpace(colName))
                    colType := "TEXT"
                    if strings.ToLower(strings.TrimSpace(config.SourceConfig.Type)) == "mysql" {
                        if srcCol, ok := sourceColsByLower[colLower]; ok {
                            colType = sanitizeMySQLColumnType(srcCol.Type)
                        }
                    }

                    alterSQL := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s NULL",
                        quoteQualifiedIdentByType("mysql", targetQueryTable),
                        quoteIdentByType("mysql", colName),
                        colType,
                    )
                    if _, err := targetDB.Exec(alterSQL); err != nil {
                        s.appendLog(config.JobID, &result, "error", fmt.Sprintf(" -> 自动补字段失败:字段=%s 错误=%v", colName, err))
                        continue
                    }
                    added++
                }
                s.appendLog(config.JobID, &result, "info", fmt.Sprintf(" -> 自动补字段完成:成功=%d 失败=%d", added, len(missing)-added))

                // refresh columns
                targetCols, err = targetDB.GetColumns(targetSchema, targetTable)
                if err == nil {
                    targetColSet = make(map[string]struct{}, len(targetCols))
                    for _, c := range targetCols {
                        name := strings.ToLower(strings.TrimSpace(c.Name))
                        if name == "" {
                            continue
                        }
                        targetColSet[name] = struct{}{}
                    }
                }
            } else {
                s.appendLog(config.JobID, &result, "warn", fmt.Sprintf(" -> 目标表缺少字段 %d 个(未开启自动补齐),将自动忽略:%s", len(missing), strings.Join(missing, ", ")))
            }

            // filter out still-missing columns to avoid apply failure
            changeSet.Inserts = filterInsertRows(changeSet.Inserts, targetColSet)
            changeSet.Updates = filterUpdateRows(changeSet.Updates, targetColSet)
        }
    }

    // 5. Apply Changes
    s.progress(config.JobID, i, totalTables, tableName, "应用变更")

    if len(changeSet.Inserts) > 0 || len(changeSet.Updates) > 0 || len(changeSet.Deletes) > 0 {
        s.appendLog(config.JobID, &result, "info", fmt.Sprintf(" -> 需插入: %d 行, 需更新: %d 行, 需删除: %d 行", len(changeSet.Inserts), len(changeSet.Updates), len(changeSet.Deletes)))

        if applier, ok := targetDB.(db.BatchApplier); ok {
            if err := applier.ApplyChanges(targetTable, changeSet); err != nil {
                s.appendLog(config.JobID, &result, "error", fmt.Sprintf(" -> 应用变更失败: %v", err))
            } else {
                result.RowsInserted += len(changeSet.Inserts)
                result.RowsUpdated += len(changeSet.Updates)
                result.RowsDeleted += len(changeSet.Deletes)
            }
        } else {
            s.appendLog(config.JobID, &result, "warn", " -> 目标驱动不支持应用数据变更 (ApplyChanges).")
        }
    } else {
        s.appendLog(config.JobID, &result, "info", " -> 数据一致,无需变更.")
    }

    result.TablesSynced++
    return
} else {
    // insert_only / full_overwrite: do not compare target, just insert source rows
    inserts = sourceRows
}

// full_overwrite: clear target table first
if tableMode == "full_overwrite" {
    s.appendLog(config.JobID, &result, "warn", fmt.Sprintf(" -> 全量覆盖模式:即将清空目标表 %s", tableName))
    s.progress(config.JobID, i, totalTables, tableName, "清空目标表")
    clearSQL := ""
    if strings.ToLower(strings.TrimSpace(config.TargetConfig.Type)) == "mysql" {
        clearSQL = fmt.Sprintf("TRUNCATE TABLE %s", quoteQualifiedIdentByType(config.TargetConfig.Type, targetQueryTable))
    } else {
        clearSQL = fmt.Sprintf("DELETE FROM %s", quoteQualifiedIdentByType(config.TargetConfig.Type, targetQueryTable))
    }
    if _, err := targetDB.Exec(clearSQL); err != nil {
        s.appendLog(config.JobID, &result, "error", fmt.Sprintf(" -> 清空目标表失败: %v", err))
        return
    }
}

// 4. Align schema (target missing columns)
s.progress(config.JobID, i, totalTables, tableName, "检查字段一致性")
requiredCols := collectRequiredColumns(inserts, updates)
targetCols, err := targetDB.GetColumns(targetSchema, targetTable)
if err != nil {
    s.appendLog(config.JobID, &result, "warn", fmt.Sprintf(" -> 获取目标表字段失败,已跳过字段一致性检查: %v", err))
} else {
    targetColSet := make(map[string]struct{}, len(targetCols))
    for _, c := range targetCols {
        name := strings.ToLower(strings.TrimSpace(c.Name))
        if name == "" {
            continue
        }
        targetColSet[name] = struct{}{}
    }

    missing := make([]string, 0)
    for lower, original := range requiredCols {
        if _, ok := targetColSet[lower]; !ok {
            missing = append(missing, original)
        }
    }
    sort.Strings(missing)

    if len(missing) > 0 {
        if config.AutoAddColumns && strings.ToLower(strings.TrimSpace(config.TargetConfig.Type)) == "mysql" {
            s.appendLog(config.JobID, &result, "warn", fmt.Sprintf(" -> 目标表缺少字段 %d 个,开始自动补齐: %s", len(missing), strings.Join(missing, ", ")))
            added := 0
            for _, colName := range missing {
                colLower := strings.ToLower(strings.TrimSpace(colName))
                colType := "TEXT"
                if strings.ToLower(strings.TrimSpace(config.SourceConfig.Type)) == "mysql" {
                    if srcCol, ok := sourceColsByLower[colLower]; ok {
                        colType = sanitizeMySQLColumnType(srcCol.Type)
                    }
                }

                alterSQL := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s NULL",
                    quoteQualifiedIdentByType("mysql", targetQueryTable),
                    quoteIdentByType("mysql", colName),
                    colType,
                )
                if _, err := targetDB.Exec(alterSQL); err != nil {
                    s.appendLog(config.JobID, &result, "error", fmt.Sprintf(" -> 自动补字段失败:字段=%s 错误=%v", colName, err))
                    continue
                }
                added++
            }
            s.appendLog(config.JobID, &result, "info", fmt.Sprintf(" -> 自动补字段完成:成功=%d 失败=%d", added, len(missing)-added))

            // refresh columns
            targetCols, err = targetDB.GetColumns(targetSchema, targetTable)
            if err == nil {
                targetColSet = make(map[string]struct{}, len(targetCols))
                for _, c := range targetCols {
                    name := strings.ToLower(strings.TrimSpace(c.Name))
                    if name == "" {
                        continue
                    }
                    targetColSet[name] = struct{}{}
                }
            }
        } else {
            s.appendLog(config.JobID, &result, "warn", fmt.Sprintf(" -> 目标表缺少字段 %d 个(未开启自动补齐),将自动忽略:%s", len(missing), strings.Join(missing, ", ")))
        }

        // filter out still-missing columns to avoid apply failure
        inserts = filterInsertRows(inserts, targetColSet)
        updates = filterUpdateRows(updates, targetColSet)
    }
}

// 5. Apply Changes
s.progress(config.JobID, i, totalTables, tableName, "应用变更")
changeSet := connection.ChangeSet{
    Inserts: inserts,
    Updates: updates,
}

if len(changeSet.Inserts) > 0 || len(changeSet.Updates) > 0 {
    s.appendLog(config.JobID, &result, "info", fmt.Sprintf(" -> 需插入: %d 行, 需更新: %d 行", len(changeSet.Inserts), len(changeSet.Updates)))

    if applier, ok := targetDB.(db.BatchApplier); ok {
        if err := applier.ApplyChanges(targetTable, changeSet); err != nil {
            s.appendLog(config.JobID, &result, "error", fmt.Sprintf(" -> 应用变更失败: %v", err))
        } else {
            result.RowsInserted += len(changeSet.Inserts)
            result.RowsUpdated += len(changeSet.Updates)
        }
    } else {
        s.appendLog(config.JobID, &result, "warn", " -> 目标驱动不支持应用数据变更 (ApplyChanges).")
    }
} else {
    result.Logs = append(result.Logs, " -> 目标驱动不支持应用数据变更 (ApplyChanges).")
    s.appendLog(config.JobID, &result, "info", " -> 数据一致,无需变更.")
}
} else {
    result.Logs = append(result.Logs, " -> 数据一致,无需变更.")
}

result.TablesSynced++
}()
}

s.progress(config.JobID, totalTables, totalTables, "", "同步完成")
return result
}

@@ -196,9 +505,52 @@ func formatConnSummaryForSync(config connection.ConnectionConfig) string {
    config.Type, config.Host, config.Port, dbName, config.User, timeoutSeconds)
}

func (s *SyncEngine) fail(res SyncResult, msg string) SyncResult {
func (s *SyncEngine) appendLog(jobID string, res *SyncResult, level string, msg string) {
    if res != nil {
        res.Logs = append(res.Logs, msg)
    }
    if s.reporter.OnLog != nil && strings.TrimSpace(jobID) != "" {
        s.reporter.OnLog(SyncLogEvent{
            JobID:   jobID,
            Level:   level,
            Message: msg,
            Ts:      time.Now().UnixMilli(),
        })
    }
}

func (s *SyncEngine) progress(jobID string, current, total int, table string, stage string) {
    if s.reporter.OnProgress == nil || strings.TrimSpace(jobID) == "" {
        return
    }
    percent := 0
    if total <= 0 {
        if current > 0 {
            percent = 100
        }
    } else {
        if current < 0 {
            current = 0
        }
        if current > total {
            current = total
        }
        percent = (current * 100) / total
    }
    s.reporter.OnProgress(SyncProgressEvent{
        JobID:   jobID,
        Percent: percent,
        Current: current,
        Total:   total,
        Table:   table,
        Stage:   stage,
    })
}

func (s *SyncEngine) fail(jobID string, totalTables int, res SyncResult, msg string) SyncResult {
    res.Success = false
    res.Message = msg
    res.Logs = append(res.Logs, "致命错误: "+msg)
    s.appendLog(jobID, &res, "error", "致命错误: "+msg)
    s.progress(jobID, res.TablesSynced, totalTables, "", "同步失败")
    return res
}
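
The hunk above also calls filterRowsByPKSelection and filterUpdatesByPKSelection, which are not part of this compare view. Below is a minimal sketch of what they could look like, inferred only from the call sites and from the TableOptions comment further down (a disabled operation drops everything; a nil or empty PK selection means "sync all rows of that type"). The actual helpers in this branch may differ.

// Sketch only: assumes it lives in package sync next to the code above and
// reuses the existing connection.UpdateRow type and the fmt package.
func filterRowsByPKSelection(pkCol string, rows []map[string]interface{}, enabled bool, selectedPKs []string) []map[string]interface{} {
    if !enabled {
        // Operation not selected for this table: keep nothing.
        return nil
    }
    if len(selectedPKs) == 0 {
        // No explicit selection: keep every row.
        return rows
    }
    selected := make(map[string]struct{}, len(selectedPKs))
    for _, pk := range selectedPKs {
        selected[pk] = struct{}{}
    }
    filtered := make([]map[string]interface{}, 0, len(rows))
    for _, row := range rows {
        if _, ok := selected[fmt.Sprintf("%v", row[pkCol])]; ok {
            filtered = append(filtered, row)
        }
    }
    return filtered
}

func filterUpdatesByPKSelection(pkCol string, updates []connection.UpdateRow, enabled bool, selectedPKs []string) []connection.UpdateRow {
    if !enabled {
        return nil
    }
    if len(selectedPKs) == 0 {
        return updates
    }
    selected := make(map[string]struct{}, len(selectedPKs))
    for _, pk := range selectedPKs {
        selected[pk] = struct{}{}
    }
    filtered := make([]connection.UpdateRow, 0, len(updates))
    for _, u := range updates {
        // The update rows built above carry the primary key under Keys[pkCol].
        if _, ok := selected[fmt.Sprintf("%v", u.Keys[pkCol])]; ok {
            filtered = append(filtered, u)
        }
    }
    return filtered
}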

30
internal/sync/sync_events.go
Normal file
@@ -0,0 +1,30 @@
package sync

const (
    EventSyncStart    = "sync:start"
    EventSyncProgress = "sync:progress"
    EventSyncLog      = "sync:log"
    EventSyncDone     = "sync:done"
)

type SyncLogEvent struct {
    JobID   string `json:"jobId"`
    Level   string `json:"level"` // info/warn/error
    Message string `json:"message"`
    Ts      int64  `json:"ts"` // Unix milli
}

type SyncProgressEvent struct {
    JobID   string `json:"jobId"`
    Percent int    `json:"percent"`
    Current int    `json:"current"` // tables completed
    Total   int    `json:"total"`   // total tables
    Table   string `json:"table,omitempty"`
    Stage   string `json:"stage,omitempty"`
}

type Reporter struct {
    OnLog      func(event SyncLogEvent)
    OnProgress func(event SyncProgressEvent)
}
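
For reference, a minimal sketch of how a caller might wire the Reporter defined above. How the SyncEngine actually receives it (constructor argument, field assignment) is not shown in this compare view, and the EventSync* constants suggest the real callbacks forward these payloads to the frontend rather than to a plain logger; the standard library log package is used here only for illustration.

// Illustrative wiring only; the real caller presumably emits the payloads
// to the UI under the EventSync* names defined above.
reporter := Reporter{
    OnLog: func(e SyncLogEvent) {
        log.Printf("[%s] job=%s %s", e.Level, e.JobID, e.Message)
    },
    OnProgress: func(e SyncProgressEvent) {
        log.Printf("job=%s %d%% (%d/%d) table=%s stage=%s",
            e.JobID, e.Percent, e.Current, e.Total, e.Table, e.Stage)
    },
}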

13
internal/sync/table_options.go
Normal file
@@ -0,0 +1,13 @@
package sync

// TableOptions controls which operations to apply per table, and optional row selection.
// Note: if Selected*PKs is not specified, all differing rows of that type are synced; an explicitly empty array also means all.
type TableOptions struct {
    Insert bool `json:"insert,omitempty"`
    Update bool `json:"update,omitempty"`
    Delete bool `json:"delete,omitempty"`

    SelectedInsertPKs []string `json:"selectedInsertPks,omitempty"`
    SelectedUpdatePKs []string `json:"selectedUpdatePks,omitempty"`
    SelectedDeletePKs []string `json:"selectedDeletePks,omitempty"`
}
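
The sync engine looks these options up per table via config.TableOptions[tableName], so a caller supplies a map keyed by table name. An illustrative value follows; the table names are made up and the surrounding SyncConfig type is not shown in this compare view.

// Hypothetical example values; field semantics follow the struct above.
options := map[string]TableOptions{
    "users":  {Insert: true, Update: true},                                      // insert and update only
    "orders": {Insert: true, Delete: true, SelectedDeletePKs: []string{"1001"}}, // deletes restricted to one PK
}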