mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-11 18:39:48 +08:00
Merge pull request #35 from Syngnat/release/0.2.1
- 前端改用通用 DB API,避免强制走 MySQL 接口导致 PostgreSQL 等连接异常 - 后端统一各数据源 timeout(Ping 超时 + 连接参数注入) - DSN 生成兼容特殊字符密码(Postgres/Oracle/达梦/金仓) - 增加文件日志与错误链输出,连接失败提示日志路径便于排障
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -15,4 +15,5 @@ dist/
|
||||
.DS_Store
|
||||
.gemini-clipboard
|
||||
GoNavi-Wails
|
||||
GoNavi-Wails.exe
|
||||
GoNavi-Wails.exe
|
||||
.ace-tool/
|
||||
|
||||
8
frontend/package-lock.json
generated
8
frontend/package-lock.json
generated
@@ -27,6 +27,7 @@
|
||||
"@types/react": "^18.2.43",
|
||||
"@types/react-dom": "^18.2.17",
|
||||
"@types/react-resizable": "^3.0.8",
|
||||
"@types/uuid": "^9.0.7",
|
||||
"@vitejs/plugin-react": "^4.2.1",
|
||||
"typescript": "^5.2.2",
|
||||
"vite": "^5.0.8"
|
||||
@@ -1565,6 +1566,13 @@
|
||||
"optional": true,
|
||||
"peer": true
|
||||
},
|
||||
"node_modules/@types/uuid": {
|
||||
"version": "9.0.8",
|
||||
"resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.8.tgz",
|
||||
"integrity": "sha512-jg+97EGIcY9AGHJJRaaPVgetKDsrTgbRjQ5Msgjh/DQKEFl0DtyRr/VCOyD1T2R1MNeWPK/u7JoGhlDZnKBAfA==",
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@vitejs/plugin-react": {
|
||||
"version": "4.7.0",
|
||||
"resolved": "https://registry.npmjs.org/@vitejs/plugin-react/-/plugin-react-4.7.0.tgz",
|
||||
|
||||
@@ -1 +1 @@
|
||||
c1af19c07654ec9f98628c358ae49b1a
|
||||
5b8157374dae5f9340e31b2d0bd2c00e
|
||||
@@ -1,8 +1,8 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { Modal, Form, Input, InputNumber, Button, message, Checkbox, Divider, Select, Alert, Card, Row, Col, Typography } from 'antd';
|
||||
import { Modal, Form, Input, InputNumber, Button, message, Checkbox, Divider, Select, Alert, Card, Row, Col, Typography, Collapse } from 'antd';
|
||||
import { DatabaseOutlined, ConsoleSqlOutlined, FileTextOutlined, CloudServerOutlined, AppstoreAddOutlined } from '@ant-design/icons';
|
||||
import { useStore } from '../store';
|
||||
import { MySQLConnect, MySQLGetDatabases } from '../../wailsjs/go/app/App';
|
||||
import { DBConnect, DBGetDatabases, TestConnection } from '../../wailsjs/go/app/App';
|
||||
import { SavedConnection } from '../types';
|
||||
|
||||
const { Meta } = Card;
|
||||
@@ -42,7 +42,8 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
|
||||
sshPassword: initialValues.config.ssh?.password,
|
||||
sshKeyPath: initialValues.config.ssh?.keyPath,
|
||||
driver: (initialValues.config as any).driver,
|
||||
dsn: (initialValues.config as any).dsn
|
||||
dsn: (initialValues.config as any).dsn,
|
||||
timeout: (initialValues.config as any).timeout || 30
|
||||
});
|
||||
setUseSSH(initialValues.config.useSSH || false);
|
||||
setDbType(initialValues.config.type);
|
||||
@@ -63,7 +64,7 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
|
||||
|
||||
const config = await buildConfig(values);
|
||||
|
||||
const res = await MySQLConnect(config as any);
|
||||
const res = await DBConnect(config as any);
|
||||
setLoading(false);
|
||||
|
||||
if (res.success) {
|
||||
@@ -101,11 +102,11 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
|
||||
setLoading(true);
|
||||
setTestResult(null);
|
||||
const config = await buildConfig(values);
|
||||
const res = await (window as any).go.app.App.TestConnection(config);
|
||||
const res = await TestConnection(config as any);
|
||||
setLoading(false);
|
||||
if (res.success) {
|
||||
setTestResult({ type: 'success', message: res.message });
|
||||
const dbRes = await MySQLGetDatabases(config as any);
|
||||
const dbRes = await DBGetDatabases(config as any);
|
||||
if (dbRes.success) {
|
||||
const dbs = (dbRes.data as any[]).map((row: any) => row.Database || row.database);
|
||||
setDbList(dbs);
|
||||
@@ -137,7 +138,8 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
|
||||
useSSH: !!values.useSSH,
|
||||
ssh: sshConfig,
|
||||
driver: values.driver,
|
||||
dsn: values.dsn
|
||||
dsn: values.dsn,
|
||||
timeout: Number(values.timeout || 30)
|
||||
};
|
||||
};
|
||||
|
||||
@@ -196,7 +198,7 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
|
||||
<Form
|
||||
form={form}
|
||||
layout="vertical"
|
||||
initialValues={{ type: 'mysql', host: 'localhost', port: 3306, user: 'root', useSSH: false, sshPort: 22 }}
|
||||
initialValues={{ type: 'mysql', host: 'localhost', port: 3306, user: 'root', useSSH: false, sshPort: 22, timeout: 30 }}
|
||||
onValuesChange={(changed) => {
|
||||
if (testResult) setTestResult(null); // Clear result on change
|
||||
if (changed.useSSH !== undefined) setUseSSH(changed.useSSH);
|
||||
@@ -282,6 +284,26 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
|
||||
</Form.Item>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Divider style={{ margin: '12px 0' }} />
|
||||
|
||||
<Collapse
|
||||
ghost
|
||||
items={[{
|
||||
key: 'advanced',
|
||||
label: '高级连接',
|
||||
children: (
|
||||
<Form.Item
|
||||
name="timeout"
|
||||
label="连接超时 (秒)"
|
||||
help="数据库连接超时时间,默认 30 秒"
|
||||
rules={[{ type: 'number', min: 1, max: 300, message: '超时时间范围: 1-300 秒' }]}
|
||||
>
|
||||
<InputNumber style={{ width: '100%' }} min={1} max={300} placeholder="30" />
|
||||
</Form.Item>
|
||||
)
|
||||
}]}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
@@ -334,4 +356,4 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
|
||||
);
|
||||
};
|
||||
|
||||
export default ConnectionModal;
|
||||
export default ConnectionModal;
|
||||
|
||||
@@ -2,7 +2,7 @@ import React, { useEffect, useState, useCallback } from 'react';
|
||||
import { message } from 'antd';
|
||||
import { TabData, ColumnDefinition } from '../types';
|
||||
import { useStore } from '../store';
|
||||
import { MySQLQuery, DBGetColumns } from '../../wailsjs/go/app/App';
|
||||
import { DBQuery, DBGetColumns } from '../../wailsjs/go/app/App';
|
||||
import DataGrid from './DataGrid';
|
||||
|
||||
const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
@@ -41,6 +41,13 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
};
|
||||
|
||||
const quoteIdent = (ident: string) => {
|
||||
if (!ident) return ident;
|
||||
if (config.type === 'mysql') return `\`${ident.replace(/`/g, '``')}\``;
|
||||
return `"${ident.replace(/"/g, '""')}"`;
|
||||
};
|
||||
const escapeLiteral = (val: string) => val.replace(/'/g, "''");
|
||||
|
||||
const dbName = tab.dbName || '';
|
||||
const tableName = tab.tableName || '';
|
||||
|
||||
@@ -48,27 +55,27 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
filterConditions.forEach(cond => {
|
||||
if (cond.column && cond.value) {
|
||||
if (cond.op === 'LIKE') {
|
||||
whereParts.push(`\`${cond.column}\` LIKE '%${cond.value}%'`);
|
||||
whereParts.push(`${quoteIdent(cond.column)} LIKE '%${escapeLiteral(cond.value)}%'`);
|
||||
} else {
|
||||
whereParts.push(`\`${cond.column}\` ${cond.op} '${cond.value}'`);
|
||||
whereParts.push(`${quoteIdent(cond.column)} ${cond.op} '${escapeLiteral(cond.value)}'`);
|
||||
}
|
||||
}
|
||||
});
|
||||
const whereSQL = whereParts.length > 0 ? `WHERE ${whereParts.join(' AND ')}` : "";
|
||||
|
||||
const countSql = `SELECT COUNT(*) as total FROM \`${tableName}\` ${whereSQL}`;
|
||||
const countSql = `SELECT COUNT(*) as total FROM ${quoteIdent(tableName)} ${whereSQL}`;
|
||||
|
||||
let sql = `SELECT * FROM \`${tableName}\` ${whereSQL}`;
|
||||
let sql = `SELECT * FROM ${quoteIdent(tableName)} ${whereSQL}`;
|
||||
if (sortInfo && sortInfo.order) {
|
||||
sql += ` ORDER BY \`${sortInfo.columnKey}\` ${sortInfo.order === 'ascend' ? 'ASC' : 'DESC'}`;
|
||||
sql += ` ORDER BY ${quoteIdent(sortInfo.columnKey)} ${sortInfo.order === 'ascend' ? 'ASC' : 'DESC'}`;
|
||||
}
|
||||
const offset = (page - 1) * size;
|
||||
sql += ` LIMIT ${size} OFFSET ${offset}`;
|
||||
|
||||
const startTime = Date.now();
|
||||
try {
|
||||
const pCount = MySQLQuery(config as any, dbName, countSql);
|
||||
const pData = MySQLQuery(config as any, dbName, sql);
|
||||
const pCount = DBQuery(config as any, dbName, countSql);
|
||||
const pData = DBQuery(config as any, dbName, sql);
|
||||
|
||||
let pCols = null;
|
||||
if (pkColumns.length === 0) {
|
||||
@@ -183,4 +190,4 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
);
|
||||
};
|
||||
|
||||
export default DataViewer;
|
||||
export default DataViewer;
|
||||
|
||||
@@ -5,7 +5,7 @@ import { PlayCircleOutlined, SaveOutlined, FormatPainterOutlined, SettingOutline
|
||||
import { format } from 'sql-formatter';
|
||||
import { TabData, ColumnDefinition } from '../types';
|
||||
import { useStore } from '../store';
|
||||
import { MySQLQuery, DBGetTables, DBGetAllColumns, MySQLGetDatabases, DBGetColumns } from '../../wailsjs/go/app/App';
|
||||
import { DBQuery, DBGetTables, DBGetAllColumns, DBGetDatabases, DBGetColumns } from '../../wailsjs/go/app/App';
|
||||
import DataGrid from './DataGrid';
|
||||
|
||||
const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
@@ -60,7 +60,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
};
|
||||
|
||||
const res = await MySQLGetDatabases(config as any);
|
||||
const res = await DBGetDatabases(config as any);
|
||||
if (res.success && Array.isArray(res.data)) {
|
||||
const dbs = res.data.map((row: any) => row.Database || row.database);
|
||||
setDbList(dbs);
|
||||
@@ -252,7 +252,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
|
||||
const startTime = Date.now();
|
||||
try {
|
||||
const res = await MySQLQuery(config as any, currentDb, query);
|
||||
const res = await DBQuery(config as any, currentDb, query);
|
||||
const duration = Date.now() - startTime;
|
||||
|
||||
addSqlLog({
|
||||
@@ -421,4 +421,4 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
);
|
||||
};
|
||||
|
||||
export default QueryEditor;
|
||||
export default QueryEditor;
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import React, { useEffect, useState, useMemo, useRef } from 'react';
|
||||
import { Tree, message, Dropdown, MenuProps, Input, Button, Modal, Form, Badge } from 'antd';
|
||||
import {
|
||||
DatabaseOutlined,
|
||||
TableOutlined,
|
||||
ConsoleSqlOutlined,
|
||||
import {
|
||||
DatabaseOutlined,
|
||||
TableOutlined,
|
||||
ConsoleSqlOutlined,
|
||||
HddOutlined,
|
||||
FolderOpenOutlined,
|
||||
FileTextOutlined,
|
||||
@@ -23,10 +23,10 @@ import {
|
||||
ReloadOutlined,
|
||||
DeleteOutlined,
|
||||
DisconnectOutlined
|
||||
} from '@ant-design/icons';
|
||||
import { useStore } from '../store';
|
||||
import { SavedConnection } from '../types';
|
||||
import { MySQLGetDatabases, MySQLGetTables, MySQLShowCreateTable, ExportTable, OpenSQLFile, CreateDatabase } from '../../wailsjs/go/app/App';
|
||||
} from '@ant-design/icons';
|
||||
import { useStore } from '../store';
|
||||
import { SavedConnection } from '../types';
|
||||
import { DBGetDatabases, DBGetTables, DBShowCreateTable, ExportTable, OpenSQLFile, CreateDatabase } from '../../wailsjs/go/app/App';
|
||||
|
||||
const { Search } = Input;
|
||||
|
||||
@@ -116,21 +116,21 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
});
|
||||
};
|
||||
|
||||
const loadDatabases = async (node: any) => {
|
||||
const conn = node.dataRef as SavedConnection;
|
||||
const config = {
|
||||
...conn.config,
|
||||
const loadDatabases = async (node: any) => {
|
||||
const conn = node.dataRef as SavedConnection;
|
||||
const config = {
|
||||
...conn.config,
|
||||
port: Number(conn.config.port),
|
||||
password: conn.config.password || "",
|
||||
database: conn.config.database || "",
|
||||
useSSH: conn.config.useSSH || false,
|
||||
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
};
|
||||
const res = await MySQLGetDatabases(config as any);
|
||||
if (res.success) {
|
||||
setConnectionStates(prev => ({ ...prev, [conn.id]: 'success' }));
|
||||
let dbs = (res.data as any[]).map((row: any) => ({
|
||||
title: row.Database || row.database,
|
||||
useSSH: conn.config.useSSH || false,
|
||||
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
};
|
||||
const res = await DBGetDatabases(config as any);
|
||||
if (res.success) {
|
||||
setConnectionStates(prev => ({ ...prev, [conn.id]: 'success' }));
|
||||
let dbs = (res.data as any[]).map((row: any) => ({
|
||||
title: row.Database || row.database,
|
||||
key: `${conn.id}-${row.Database || row.database}`,
|
||||
icon: <DatabaseOutlined />,
|
||||
type: 'database' as const,
|
||||
@@ -150,9 +150,9 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
}
|
||||
};
|
||||
|
||||
const loadTables = async (node: any) => {
|
||||
const conn = node.dataRef; // has dbName
|
||||
const dbName = conn.dbName;
|
||||
const loadTables = async (node: any) => {
|
||||
const conn = node.dataRef; // has dbName
|
||||
const dbName = conn.dbName;
|
||||
const key = node.key;
|
||||
|
||||
const dbQueries = savedQueries.filter(q => q.connectionId === conn.id && q.dbName === dbName);
|
||||
@@ -178,13 +178,13 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
port: Number(conn.config.port),
|
||||
password: conn.config.password || "",
|
||||
database: conn.config.database || "",
|
||||
useSSH: conn.config.useSSH || false,
|
||||
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
};
|
||||
const res = await MySQLGetTables(config as any, conn.dbName);
|
||||
if (res.success) {
|
||||
setConnectionStates(prev => ({ ...prev, [key as string]: 'success' }));
|
||||
const tables = (res.data as any[]).map((row: any) => {
|
||||
useSSH: conn.config.useSSH || false,
|
||||
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
};
|
||||
const res = await DBGetTables(config as any, conn.dbName);
|
||||
if (res.success) {
|
||||
setConnectionStates(prev => ({ ...prev, [key as string]: 'success' }));
|
||||
const tables = (res.data as any[]).map((row: any) => {
|
||||
const tableName = Object.values(row)[0] as string;
|
||||
return {
|
||||
title: tableName,
|
||||
@@ -345,13 +345,13 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
}
|
||||
};
|
||||
|
||||
const handleCopyStructure = async (node: any) => {
|
||||
const { config, dbName, tableName } = node.dataRef;
|
||||
const res = await MySQLShowCreateTable({
|
||||
...config,
|
||||
port: Number(config.port),
|
||||
password: config.password || "",
|
||||
database: config.database || "",
|
||||
const handleCopyStructure = async (node: any) => {
|
||||
const { config, dbName, tableName } = node.dataRef;
|
||||
const res = await DBShowCreateTable({
|
||||
...config,
|
||||
port: Number(config.port),
|
||||
password: config.password || "",
|
||||
database: config.database || "",
|
||||
useSSH: config.useSSH || false,
|
||||
ssh: config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
} as any, dbName, tableName);
|
||||
|
||||
@@ -7,7 +7,7 @@ import { CSS } from '@dnd-kit/utilities';
|
||||
import { Resizable } from 'react-resizable';
|
||||
import { TabData, ColumnDefinition, IndexDefinition, ForeignKeyDefinition, TriggerDefinition } from '../types';
|
||||
import { useStore } from '../store';
|
||||
import { DBGetColumns, DBGetIndexes, MySQLQuery, DBGetForeignKeys, DBGetTriggers, DBShowCreateTable } from '../../wailsjs/go/app/App';
|
||||
import { DBGetColumns, DBGetIndexes, DBQuery, DBGetForeignKeys, DBGetTriggers, DBShowCreateTable } from '../../wailsjs/go/app/App';
|
||||
|
||||
// Need styles for react-resizable
|
||||
import 'react-resizable/css/styles.css';
|
||||
@@ -518,15 +518,15 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
}
|
||||
};
|
||||
|
||||
const handleExecuteSave = async () => {
|
||||
const conn = connections.find(c => c.id === tab.connectionId);
|
||||
if (!conn) return;
|
||||
const config = { ...conn.config, port: Number(conn.config.port), password: conn.config.password || "", database: conn.config.database || "", useSSH: conn.config.useSSH || false, ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" } };
|
||||
const res = await MySQLQuery(config as any, tab.dbName || '', previewSql);
|
||||
if (res.success) {
|
||||
message.success(isNewTable ? "表创建成功!" : "表结构修改成功!");
|
||||
setIsPreviewOpen(false);
|
||||
if (!isNewTable) {
|
||||
const handleExecuteSave = async () => {
|
||||
const conn = connections.find(c => c.id === tab.connectionId);
|
||||
if (!conn) return;
|
||||
const config = { ...conn.config, port: Number(conn.config.port), password: conn.config.password || "", database: conn.config.database || "", useSSH: conn.config.useSSH || false, ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" } };
|
||||
const res = await DBQuery(config as any, tab.dbName || '', previewSql);
|
||||
if (res.success) {
|
||||
message.success(isNewTable ? "表创建成功!" : "表结构修改成功!");
|
||||
setIsPreviewOpen(false);
|
||||
if (!isNewTable) {
|
||||
fetchData();
|
||||
} else {
|
||||
// TODO: Close tab or reload sidebar?
|
||||
@@ -730,4 +730,4 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
);
|
||||
};
|
||||
|
||||
export default TableDesigner;
|
||||
export default TableDesigner;
|
||||
|
||||
@@ -79,6 +79,7 @@ export namespace connection {
|
||||
ssh: SSHConfig;
|
||||
driver?: string;
|
||||
dsn?: string;
|
||||
timeout?: number;
|
||||
|
||||
static createFrom(source: any = {}) {
|
||||
return new ConnectionConfig(source);
|
||||
@@ -96,6 +97,7 @@ export namespace connection {
|
||||
this.ssh = this.convertValues(source["ssh"], SSHConfig);
|
||||
this.driver = source["driver"];
|
||||
this.dsn = source["dsn"];
|
||||
this.timeout = source["timeout"];
|
||||
}
|
||||
|
||||
convertValues(a: any, classs: any, asMap: boolean = false): any {
|
||||
|
||||
@@ -2,11 +2,18 @@ package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/db"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
)
|
||||
|
||||
// App struct
|
||||
@@ -27,55 +34,149 @@ func NewApp() *App {
|
||||
// so we can call the runtime methods
|
||||
func (a *App) Startup(ctx context.Context) {
|
||||
a.ctx = ctx
|
||||
logger.Init()
|
||||
logger.Infof("应用启动完成")
|
||||
}
|
||||
|
||||
// Shutdown is called when the app terminates
|
||||
func (a *App) Shutdown(ctx context.Context) {
|
||||
logger.Infof("应用开始关闭,准备释放资源")
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
for _, dbInst := range a.dbCache {
|
||||
dbInst.Close()
|
||||
if err := dbInst.Close(); err != nil {
|
||||
logger.Error(err, "关闭数据库连接失败")
|
||||
}
|
||||
}
|
||||
logger.Infof("资源释放完成,应用已关闭")
|
||||
logger.Close()
|
||||
}
|
||||
|
||||
// Helper: Generate a unique key for the connection config
|
||||
func getCacheKey(config connection.ConnectionConfig) string {
|
||||
sshPart := ""
|
||||
if config.UseSSH {
|
||||
sshPart = fmt.Sprintf("|ssh:%s@%s:%d|%s", config.SSH.User, config.SSH.Host, config.SSH.Port, config.SSH.KeyPath)
|
||||
// We don't include SSH password in key string to avoid log exposure if key is logged,
|
||||
// but for cache uniqueness it is critical.
|
||||
// Let's include a hash or just the value if we assume internal use.
|
||||
// Including value for correctness.
|
||||
sshPart += "|" + config.SSH.Password
|
||||
if !config.UseSSH {
|
||||
config.SSH = connection.SSHConfig{}
|
||||
}
|
||||
return fmt.Sprintf("%s|%s:%s@%s:%d|%s%s", config.Type, config.User, config.Password, config.Host, config.Port, config.Database, sshPart)
|
||||
// 保持与驱动默认一致,避免同一连接被重复缓存
|
||||
if config.Type == "postgres" && config.Database == "" {
|
||||
config.Database = "postgres"
|
||||
}
|
||||
|
||||
b, _ := json.Marshal(config)
|
||||
sum := sha256.Sum256(b)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func wrapConnectError(config connection.ConnectionConfig, err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var netErr net.Error
|
||||
if errors.Is(err, context.DeadlineExceeded) || (errors.As(err, &netErr) && netErr.Timeout()) {
|
||||
dbName := config.Database
|
||||
if dbName == "" {
|
||||
dbName = "(default)"
|
||||
}
|
||||
err = fmt.Errorf("数据库连接超时:%s %s:%d/%s:%w", config.Type, config.Host, config.Port, dbName, err)
|
||||
}
|
||||
|
||||
return withLogHint{err: err, logPath: logger.Path()}
|
||||
}
|
||||
|
||||
type withLogHint struct {
|
||||
err error
|
||||
logPath string
|
||||
}
|
||||
|
||||
func (e withLogHint) Error() string {
|
||||
if strings.TrimSpace(e.logPath) == "" {
|
||||
return e.err.Error()
|
||||
}
|
||||
return fmt.Sprintf("%s(详细日志:%s)", e.err.Error(), e.logPath)
|
||||
}
|
||||
|
||||
func (e withLogHint) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
func formatConnSummary(config connection.ConnectionConfig) string {
|
||||
timeoutSeconds := config.Timeout
|
||||
if timeoutSeconds <= 0 {
|
||||
timeoutSeconds = 30
|
||||
}
|
||||
|
||||
dbName := config.Database
|
||||
if strings.TrimSpace(dbName) == "" {
|
||||
dbName = "(default)"
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("类型=%s 地址=%s:%d 数据库=%s 用户=%s 超时=%ds",
|
||||
config.Type, config.Host, config.Port, dbName, config.User, timeoutSeconds))
|
||||
|
||||
if config.UseSSH {
|
||||
b.WriteString(fmt.Sprintf(" SSH=%s:%d 用户=%s", config.SSH.Host, config.SSH.Port, config.SSH.User))
|
||||
}
|
||||
|
||||
if config.Type == "custom" {
|
||||
driver := strings.TrimSpace(config.Driver)
|
||||
if driver == "" {
|
||||
driver = "(未配置)"
|
||||
}
|
||||
dsnState := "未配置"
|
||||
if strings.TrimSpace(config.DSN) != "" {
|
||||
dsnState = fmt.Sprintf("已配置(长度=%d)", len(config.DSN))
|
||||
}
|
||||
b.WriteString(fmt.Sprintf(" 驱动=%s DSN=%s", driver, dsnState))
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// Helper: Get or create a database connection
|
||||
func (a *App) getDatabase(config connection.ConnectionConfig) (db.Database, error) {
|
||||
key := getCacheKey(config)
|
||||
shortKey := key
|
||||
if len(shortKey) > 12 {
|
||||
shortKey = shortKey[:12]
|
||||
}
|
||||
if config.UseSSH && config.Type != "mysql" {
|
||||
logger.Warnf("当前仅 MySQL 支持内置 SSH 直连,其他类型请使用本地端口转发:%s", formatConnSummary(config))
|
||||
}
|
||||
logger.Infof("获取数据库连接:%s 缓存Key=%s", formatConnSummary(config), shortKey)
|
||||
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
if dbInst, ok := a.dbCache[key]; ok {
|
||||
logger.Infof("命中连接缓存,开始检测可用性:缓存Key=%s", shortKey)
|
||||
if err := dbInst.Ping(); err == nil {
|
||||
logger.Infof("缓存连接可用:缓存Key=%s", shortKey)
|
||||
return dbInst, nil
|
||||
} else {
|
||||
logger.Error(err, "缓存连接不可用,准备重建:缓存Key=%s", shortKey)
|
||||
}
|
||||
if err := dbInst.Close(); err != nil {
|
||||
logger.Error(err, "关闭失效缓存连接失败:缓存Key=%s", shortKey)
|
||||
}
|
||||
dbInst.Close()
|
||||
delete(a.dbCache, key)
|
||||
}
|
||||
|
||||
logger.Infof("创建数据库驱动实例:类型=%s 缓存Key=%s", config.Type, shortKey)
|
||||
dbInst, err := db.NewDatabase(config.Type)
|
||||
if err != nil {
|
||||
logger.Error(err, "创建数据库驱动实例失败:类型=%s 缓存Key=%s", config.Type, shortKey)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := dbInst.Connect(config); err != nil {
|
||||
return nil, err
|
||||
wrapped := wrapConnectError(config, err)
|
||||
logger.Error(wrapped, "建立数据库连接失败:%s 缓存Key=%s", formatConnSummary(config), shortKey)
|
||||
return nil, wrapped
|
||||
}
|
||||
|
||||
a.dbCache[key] = dbInst
|
||||
logger.Infof("数据库连接成功并写入缓存:%s 缓存Key=%s", formatConnSummary(config), shortKey)
|
||||
return dbInst, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
)
|
||||
|
||||
// Generic DB Methods
|
||||
@@ -13,18 +14,22 @@ func (a *App) DBConnect(config connection.ConnectionConfig) connection.QueryResu
|
||||
// getDatabase checks cache and Pings. If valid, reuses. If not, connects.
|
||||
_, err := a.getDatabase(config)
|
||||
if err != nil {
|
||||
logger.Error(err, "DBConnect 连接失败:%s", formatConnSummary(config))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
logger.Infof("DBConnect 连接成功:%s", formatConnSummary(config))
|
||||
return connection.QueryResult{Success: true, Message: "连接成功"}
|
||||
}
|
||||
|
||||
func (a *App) TestConnection(config connection.ConnectionConfig) connection.QueryResult {
|
||||
_, err := a.getDatabase(config)
|
||||
if err != nil {
|
||||
logger.Error(err, "TestConnection 连接测试失败:%s", formatConnSummary(config))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
logger.Infof("TestConnection 连接测试成功:%s", formatConnSummary(config))
|
||||
return connection.QueryResult{Success: true, Message: "连接成功"}
|
||||
}
|
||||
|
||||
@@ -37,9 +42,11 @@ func (a *App) CreateDatabase(config connection.ConnectionConfig, dbName string)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("CREATE DATABASE `%%s` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci", dbName)
|
||||
escapedDbName := strings.ReplaceAll(dbName, "`", "``")
|
||||
query := fmt.Sprintf("CREATE DATABASE `%s` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci", escapedDbName)
|
||||
if runConfig.Type == "postgres" {
|
||||
query = fmt.Sprintf("CREATE DATABASE \"%%s\"", dbName)
|
||||
escapedDbName = strings.ReplaceAll(dbName, `"`, `""`)
|
||||
query = fmt.Sprintf("CREATE DATABASE \"%s\"", escapedDbName)
|
||||
}
|
||||
|
||||
_, err = dbInst.Exec(query)
|
||||
@@ -83,6 +90,7 @@ func (a *App) DBQuery(config connection.ConnectionConfig, dbName string, query s
|
||||
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQuery 获取连接失败:%s", formatConnSummary(runConfig))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
@@ -90,26 +98,39 @@ func (a *App) DBQuery(config connection.ConnectionConfig, dbName string, query s
|
||||
if strings.HasPrefix(lowerQuery, "select") || strings.HasPrefix(lowerQuery, "show") || strings.HasPrefix(lowerQuery, "describe") || strings.HasPrefix(lowerQuery, "explain") {
|
||||
data, columns, err := dbInst.Query(query)
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQuery 查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Data: data, Fields: columns}
|
||||
} else {
|
||||
affected, err := dbInst.Exec(query)
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQuery 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Data: map[string]int64{"affectedRows": affected}}
|
||||
}
|
||||
}
|
||||
|
||||
func sqlSnippet(query string) string {
|
||||
q := strings.TrimSpace(query)
|
||||
const max = 200
|
||||
if len(q) <= max {
|
||||
return q
|
||||
}
|
||||
return q[:max] + "..."
|
||||
}
|
||||
|
||||
func (a *App) DBGetDatabases(config connection.ConnectionConfig) connection.QueryResult {
|
||||
dbInst, err := a.getDatabase(config)
|
||||
if err != nil {
|
||||
logger.Error(err, "DBGetDatabases 获取连接失败:%s", formatConnSummary(config))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
dbs, err := dbInst.GetDatabases()
|
||||
if err != nil {
|
||||
logger.Error(err, "DBGetDatabases 获取数据库列表失败:%s", formatConnSummary(config))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
@@ -129,11 +150,13 @@ func (a *App) DBGetTables(config connection.ConnectionConfig, dbName string) con
|
||||
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
logger.Error(err, "DBGetTables 获取连接失败:%s", formatConnSummary(runConfig))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
tables, err := dbInst.GetTables(dbName)
|
||||
if err != nil {
|
||||
logger.Error(err, "DBGetTables 获取表列表失败:%s", formatConnSummary(runConfig))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
@@ -153,11 +176,13 @@ func (a *App) DBShowCreateTable(config connection.ConnectionConfig, dbName strin
|
||||
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
logger.Error(err, "DBShowCreateTable 获取连接失败:%s", formatConnSummary(runConfig))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
sqlStr, err := dbInst.GetCreateStatement(dbName, tableName)
|
||||
if err != nil {
|
||||
logger.Error(err, "DBShowCreateTable 获取建表语句失败:%s 表=%s", formatConnSummary(runConfig), tableName)
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
@@ -257,4 +282,4 @@ func (a *App) DBGetAllColumns(config connection.ConnectionConfig, dbName string)
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Data: cols}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/db"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
|
||||
"github.com/wailsapp/wails/v2/pkg/runtime"
|
||||
)
|
||||
@@ -181,7 +182,7 @@ func (a *App) ImportData(config connection.ConnectionConfig, dbName, tableName s
|
||||
_, err := dbInst.Exec(query)
|
||||
if err != nil {
|
||||
errCount++
|
||||
fmt.Println("Import Error:", err)
|
||||
logger.Error(err, "导入数据失败:表=%s", tableName)
|
||||
} else {
|
||||
successCount++
|
||||
}
|
||||
@@ -404,4 +405,4 @@ func (a *App) ExportData(data []map[string]interface{}, columns []string, defaul
|
||||
}
|
||||
|
||||
return connection.QueryResult{Success: true, Message: "Export successful"}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ type ConnectionConfig struct {
|
||||
SSH SSHConfig `json:"ssh"`
|
||||
Driver string `json:"driver,omitempty"` // For custom connection
|
||||
DSN string `json:"dsn,omitempty"` // For custom connection
|
||||
Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds (default: 30)
|
||||
}
|
||||
|
||||
// QueryResult is the standard response format for Wails methods
|
||||
|
||||
@@ -11,8 +11,9 @@ import (
|
||||
)
|
||||
|
||||
type CustomDB struct {
|
||||
conn *sql.DB
|
||||
driver string
|
||||
conn *sql.DB
|
||||
driver string
|
||||
pingTimeout time.Duration
|
||||
}
|
||||
|
||||
func (c *CustomDB) Connect(config connection.ConnectionConfig) error {
|
||||
@@ -25,11 +26,15 @@ func (c *CustomDB) Connect(config connection.ConnectionConfig) error {
|
||||
|
||||
db, err := sql.Open(config.Driver, config.DSN)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
}
|
||||
c.conn = db
|
||||
c.driver = config.Driver
|
||||
return c.Ping()
|
||||
c.pingTimeout = getConnectTimeout(config)
|
||||
if err := c.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CustomDB) Close() error {
|
||||
@@ -43,7 +48,11 @@ func (c *CustomDB) Ping() error {
|
||||
if c.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(5 * time.Second)
|
||||
timeout := c.pingTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||||
defer cancel()
|
||||
return c.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,9 @@ package db
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -14,14 +17,15 @@ import (
|
||||
)
|
||||
|
||||
type DamengDB struct {
|
||||
conn *sql.DB
|
||||
conn *sql.DB
|
||||
pingTimeout time.Duration
|
||||
}
|
||||
|
||||
func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
|
||||
// dm://user:password@host:port?schema=...
|
||||
// or dm://user:password@host:port
|
||||
|
||||
address := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
address := net.JoinHostPort(config.Host, strconv.Itoa(config.Port))
|
||||
if config.UseSSH {
|
||||
// SSH logic similar to others, assumes port forwarding
|
||||
_, err := ssh.RegisterSSHNetwork(config.SSH)
|
||||
@@ -32,21 +36,36 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
}
|
||||
|
||||
dsn := fmt.Sprintf("dm://%s:%s@%s", config.User, config.Password, address)
|
||||
escapedPassword := url.PathEscape(config.Password)
|
||||
q := url.Values{}
|
||||
if config.Database != "" {
|
||||
dsn += fmt.Sprintf("?schema=%s", config.Database)
|
||||
q.Set("schema", config.Database)
|
||||
}
|
||||
return dsn
|
||||
if escapedPassword != config.Password {
|
||||
// 达梦驱动要求:密码包含特殊字符时,password 需 PathEscape,并添加 escapeProcess=true 让驱动解码。
|
||||
q.Set("escapeProcess", "true")
|
||||
}
|
||||
|
||||
dsn := fmt.Sprintf("dm://%s:%s@%s", config.User, escapedPassword, address)
|
||||
encoded := q.Encode()
|
||||
if encoded == "" {
|
||||
return dsn
|
||||
}
|
||||
return dsn + "?" + encoded
|
||||
}
|
||||
|
||||
func (d *DamengDB) Connect(config connection.ConnectionConfig) error {
|
||||
dsn := d.getDSN(config)
|
||||
db, err := sql.Open("dm", dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
}
|
||||
d.conn = db
|
||||
return d.Ping()
|
||||
d.pingTimeout = getConnectTimeout(config)
|
||||
if err := d.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DamengDB) Close() error {
|
||||
@@ -60,7 +79,11 @@ func (d *DamengDB) Ping() error {
|
||||
if d.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(5 * time.Second)
|
||||
timeout := d.pingTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||||
defer cancel()
|
||||
return d.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
97
internal/db/dsn_test.go
Normal file
97
internal/db/dsn_test.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
func TestPostgresDSN_EscapesPassword(t *testing.T) {
|
||||
p := &PostgresDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "postgres",
|
||||
Host: "127.0.0.1",
|
||||
Port: 5432,
|
||||
User: "user",
|
||||
Password: "p@ss:wo/rd",
|
||||
Database: "db",
|
||||
}
|
||||
|
||||
dsn := p.getDSN(cfg)
|
||||
if strings.Contains(dsn, cfg.Password) {
|
||||
t.Fatalf("dsn 包含原始密码:%s", dsn)
|
||||
}
|
||||
if !strings.Contains(dsn, "p%40ss%3Awo%2Frd") {
|
||||
t.Fatalf("dsn 未正确转义密码:%s", dsn)
|
||||
}
|
||||
if !strings.Contains(dsn, "sslmode=disable") {
|
||||
t.Fatalf("dsn 缺少 sslmode 参数:%s", dsn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOracleDSN_EscapesUserAndPassword(t *testing.T) {
|
||||
o := &OracleDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "oracle",
|
||||
Host: "127.0.0.1",
|
||||
Port: 1521,
|
||||
User: "u@ser",
|
||||
Password: "p@ss:wo/rd",
|
||||
Database: "svc/name",
|
||||
}
|
||||
|
||||
dsn := o.getDSN(cfg)
|
||||
if strings.Contains(dsn, cfg.Password) {
|
||||
t.Fatalf("dsn 包含原始密码:%s", dsn)
|
||||
}
|
||||
if !strings.Contains(dsn, "u%40ser") || !strings.Contains(dsn, "p%40ss%3Awo%2Frd") {
|
||||
t.Fatalf("dsn 未正确转义 user/password:%s", dsn)
|
||||
}
|
||||
if !strings.Contains(dsn, "/svc%2Fname") {
|
||||
t.Fatalf("dsn 未正确转义 service:%s", dsn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDamengDSN_EscapesPasswordAndEnablesEscapeProcess(t *testing.T) {
|
||||
d := &DamengDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "dameng",
|
||||
Host: "127.0.0.1",
|
||||
Port: 5236,
|
||||
User: "SYSDBA",
|
||||
Password: "p@ss:wo/rd",
|
||||
Database: "DBName",
|
||||
}
|
||||
|
||||
dsn := d.getDSN(cfg)
|
||||
if strings.Contains(dsn, cfg.Password) {
|
||||
t.Fatalf("dsn 包含原始密码:%s", dsn)
|
||||
}
|
||||
if strings.Contains(dsn, "wo/rd") || !strings.Contains(dsn, "wo%2Frd") {
|
||||
t.Fatalf("dsn 未按达梦驱动要求转义密码(至少应转义 '/'):%s", dsn)
|
||||
}
|
||||
if !strings.Contains(dsn, "escapeProcess=true") {
|
||||
t.Fatalf("dsn 缺少 escapeProcess=true:%s", dsn)
|
||||
}
|
||||
if !strings.Contains(dsn, "schema=DBName") {
|
||||
t.Fatalf("dsn 缺少 schema 参数:%s", dsn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKingbaseDSN_QuotesPasswordWithSpaces(t *testing.T) {
|
||||
k := &KingbaseDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "kingbase",
|
||||
Host: "127.0.0.1",
|
||||
Port: 54321,
|
||||
User: "system",
|
||||
Password: "p@ss word",
|
||||
Database: "TEST",
|
||||
}
|
||||
|
||||
dsn := k.getDSN(cfg)
|
||||
if !strings.Contains(dsn, "password='p@ss word'") {
|
||||
t.Fatalf("dsn 未对包含空格的密码进行引号包裹:%s", dsn)
|
||||
}
|
||||
}
|
||||
@@ -14,7 +14,40 @@ import (
|
||||
)
|
||||
|
||||
type KingbaseDB struct {
|
||||
conn *sql.DB
|
||||
conn *sql.DB
|
||||
pingTimeout time.Duration
|
||||
}
|
||||
|
||||
func quoteConnValue(v string) string {
|
||||
if v == "" {
|
||||
return "''"
|
||||
}
|
||||
|
||||
needsQuote := false
|
||||
for _, r := range v {
|
||||
switch r {
|
||||
case ' ', '\t', '\n', '\r', '\v', '\f', '\'', '\\':
|
||||
needsQuote = true
|
||||
}
|
||||
if needsQuote {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !needsQuote {
|
||||
return v
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(v) + 2)
|
||||
b.WriteByte('\'')
|
||||
for _, r := range v {
|
||||
if r == '\\' || r == '\'' {
|
||||
b.WriteByte('\\')
|
||||
}
|
||||
b.WriteRune(r)
|
||||
}
|
||||
b.WriteByte('\'')
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
|
||||
@@ -39,8 +72,14 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
|
||||
// Construct DSN
|
||||
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
address, port, config.User, config.Password, config.Database)
|
||||
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable connect_timeout=%d",
|
||||
quoteConnValue(address),
|
||||
port,
|
||||
quoteConnValue(config.User),
|
||||
quoteConnValue(config.Password),
|
||||
quoteConnValue(config.Database),
|
||||
getConnectTimeoutSeconds(config),
|
||||
)
|
||||
|
||||
return dsn
|
||||
}
|
||||
@@ -50,10 +89,14 @@ func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
|
||||
// Open using "kingbase" driver
|
||||
db, err := sql.Open("kingbase", dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
}
|
||||
k.conn = db
|
||||
return k.Ping()
|
||||
k.pingTimeout = getConnectTimeout(config)
|
||||
if err := k.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) Close() error {
|
||||
@@ -67,7 +110,11 @@ func (k *KingbaseDB) Ping() error {
|
||||
if k.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(5 * time.Second)
|
||||
timeout := k.pingTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||||
defer cancel()
|
||||
return k.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
|
||||
@@ -14,7 +15,8 @@ import (
|
||||
)
|
||||
|
||||
type MySQLDB struct {
|
||||
conn *sql.DB
|
||||
conn *sql.DB
|
||||
pingTimeout time.Duration
|
||||
}
|
||||
|
||||
func (m *MySQLDB) getDSN(config connection.ConnectionConfig) string {
|
||||
@@ -27,23 +29,31 @@ func (m *MySQLDB) getDSN(config connection.ConnectionConfig) string {
|
||||
if err == nil {
|
||||
protocol = netName
|
||||
address = fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
} else {
|
||||
logger.Warnf("注册 SSH 网络失败,将尝试直连:地址=%s:%d 用户=%s,原因:%v", config.Host, config.Port, config.User, err)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
config.User, config.Password, protocol, address, database)
|
||||
timeout := getConnectTimeoutSeconds(config)
|
||||
|
||||
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds",
|
||||
config.User, config.Password, protocol, address, database, timeout)
|
||||
}
|
||||
|
||||
func (m *MySQLDB) Connect(config connection.ConnectionConfig) error {
|
||||
dsn := m.getDSN(config)
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
}
|
||||
m.conn = db
|
||||
m.pingTimeout = getConnectTimeout(config)
|
||||
|
||||
// Force verification
|
||||
return m.Ping()
|
||||
if err := m.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MySQLDB) Close() error {
|
||||
@@ -57,7 +67,11 @@ func (m *MySQLDB) Ping() error {
|
||||
if m.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(5 * time.Second)
|
||||
timeout := m.pingTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||||
defer cancel()
|
||||
return m.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,9 @@ package db
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -14,7 +17,8 @@ import (
|
||||
)
|
||||
|
||||
type OracleDB struct {
|
||||
conn *sql.DB
|
||||
conn *sql.DB
|
||||
pingTimeout time.Duration
|
||||
}
|
||||
|
||||
func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
|
||||
@@ -24,7 +28,6 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
|
||||
database = config.User // Default to user service/schema if empty?
|
||||
}
|
||||
|
||||
address := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
if config.UseSSH {
|
||||
_, err := ssh.RegisterSSHNetwork(config.SSH)
|
||||
if err == nil {
|
||||
@@ -47,19 +50,28 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
}
|
||||
|
||||
// go-ora url structure: oracle://user:password@address:port/service_name
|
||||
return fmt.Sprintf("oracle://%s:%s@%s/%s",
|
||||
config.User, config.Password, address, database)
|
||||
u := &url.URL{
|
||||
Scheme: "oracle",
|
||||
Host: net.JoinHostPort(config.Host, strconv.Itoa(config.Port)),
|
||||
Path: "/" + database,
|
||||
}
|
||||
u.User = url.UserPassword(config.User, config.Password)
|
||||
u.RawPath = "/" + url.PathEscape(database)
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
|
||||
dsn := o.getDSN(config)
|
||||
db, err := sql.Open("oracle", dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
}
|
||||
o.conn = db
|
||||
return o.Ping()
|
||||
o.pingTimeout = getConnectTimeout(config)
|
||||
if err := o.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *OracleDB) Close() error {
|
||||
@@ -73,7 +85,11 @@ func (o *OracleDB) Ping() error {
|
||||
if o.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(5 * time.Second)
|
||||
timeout := o.pingTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||||
defer cancel()
|
||||
return o.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,9 @@ package db
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
@@ -12,37 +15,45 @@ import (
|
||||
)
|
||||
|
||||
type PostgresDB struct {
|
||||
conn *sql.DB
|
||||
conn *sql.DB
|
||||
pingTimeout time.Duration
|
||||
}
|
||||
|
||||
func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
|
||||
// postgres://user:password@host:port/dbname?sslmode=disable
|
||||
host := config.Host
|
||||
port := config.Port
|
||||
// SSH placeholder kept from original
|
||||
if config.UseSSH {
|
||||
// Logic to be implemented
|
||||
}
|
||||
|
||||
dbname := config.Database
|
||||
if dbname == "" {
|
||||
dbname = "postgres" // Default DB
|
||||
}
|
||||
|
||||
return fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable",
|
||||
config.User, config.Password, host, port, dbname)
|
||||
u := &url.URL{
|
||||
Scheme: "postgres",
|
||||
Host: net.JoinHostPort(config.Host, strconv.Itoa(config.Port)),
|
||||
Path: "/" + dbname,
|
||||
}
|
||||
u.User = url.UserPassword(config.User, config.Password)
|
||||
q := url.Values{}
|
||||
q.Set("sslmode", "disable")
|
||||
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func (p *PostgresDB) Connect(config connection.ConnectionConfig) error {
|
||||
dsn := p.getDSN(config)
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
}
|
||||
p.conn = db
|
||||
|
||||
p.pingTimeout = getConnectTimeout(config)
|
||||
|
||||
// Force verification
|
||||
return p.Ping()
|
||||
if err := p.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgresDB) Close() error {
|
||||
@@ -56,7 +67,11 @@ func (p *PostgresDB) Ping() error {
|
||||
if p.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(5 * time.Second)
|
||||
timeout := p.pingTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||||
defer cancel()
|
||||
return p.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
@@ -12,19 +12,24 @@ import (
|
||||
)
|
||||
|
||||
type SQLiteDB struct {
|
||||
conn *sql.DB
|
||||
conn *sql.DB
|
||||
pingTimeout time.Duration
|
||||
}
|
||||
|
||||
func (s *SQLiteDB) Connect(config connection.ConnectionConfig) error {
|
||||
dsn := config.Host
|
||||
db, err := sql.Open("sqlite", dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
}
|
||||
s.conn = db
|
||||
s.pingTimeout = getConnectTimeout(config)
|
||||
|
||||
// Force verification
|
||||
return s.Ping()
|
||||
if err := s.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SQLiteDB) Close() error {
|
||||
@@ -38,7 +43,11 @@ func (s *SQLiteDB) Ping() error {
|
||||
if s.conn == nil {
|
||||
return fmt.Errorf("connection not open")
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(5 * time.Second)
|
||||
timeout := s.pingTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||||
defer cancel()
|
||||
return s.conn.PingContext(ctx)
|
||||
}
|
||||
|
||||
22
internal/db/timeout.go
Normal file
22
internal/db/timeout.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
const defaultConnectTimeoutSeconds = 30
|
||||
|
||||
func getConnectTimeoutSeconds(config connection.ConnectionConfig) int {
|
||||
timeoutSeconds := config.Timeout
|
||||
if timeoutSeconds <= 0 {
|
||||
timeoutSeconds = defaultConnectTimeoutSeconds
|
||||
}
|
||||
return timeoutSeconds
|
||||
}
|
||||
|
||||
func getConnectTimeout(config connection.ConnectionConfig) time.Duration {
|
||||
return time.Duration(getConnectTimeoutSeconds(config)) * time.Second
|
||||
}
|
||||
|
||||
197
internal/logger/logger.go
Normal file
197
internal/logger/logger.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
envLogDir = "GONAVI_LOG_DIR"
|
||||
appDirName = "GoNavi"
|
||||
|
||||
logFileName = "gonavi.log"
|
||||
logRotateMaxBytes = 10 * 1024 * 1024 // 10MB
|
||||
logRotateMaxBackups = 10
|
||||
)
|
||||
|
||||
var (
|
||||
once sync.Once
|
||||
logMu sync.Mutex
|
||||
logInst *log.Logger
|
||||
logFile *os.File
|
||||
logPath string
|
||||
)
|
||||
|
||||
func Init() {
|
||||
once.Do(func() {
|
||||
path, out := initOutput()
|
||||
logMu.Lock()
|
||||
defer logMu.Unlock()
|
||||
logPath = path
|
||||
logInst = log.New(out, "", log.Ldate|log.Ltime|log.Lmicroseconds)
|
||||
logInst.Printf("[信息] 日志初始化完成,日志文件:%s", logPath)
|
||||
})
|
||||
}
|
||||
|
||||
func Path() string {
|
||||
Init()
|
||||
logMu.Lock()
|
||||
defer logMu.Unlock()
|
||||
return logPath
|
||||
}
|
||||
|
||||
func Close() {
|
||||
Init()
|
||||
logMu.Lock()
|
||||
defer logMu.Unlock()
|
||||
if logInst != nil {
|
||||
logInst.SetOutput(os.Stderr)
|
||||
}
|
||||
if logFile != nil {
|
||||
_ = logFile.Close()
|
||||
logFile = nil
|
||||
}
|
||||
}
|
||||
|
||||
func Infof(format string, args ...any) {
|
||||
printf("信息", format, args...)
|
||||
}
|
||||
|
||||
func Warnf(format string, args ...any) {
|
||||
printf("警告", format, args...)
|
||||
}
|
||||
|
||||
func Errorf(format string, args ...any) {
|
||||
printf("错误", format, args...)
|
||||
}
|
||||
|
||||
func Error(err error, format string, args ...any) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
if err == nil {
|
||||
Errorf("%s", msg)
|
||||
return
|
||||
}
|
||||
Errorf("%s;错误链:%s", msg, ErrorChain(err))
|
||||
}
|
||||
|
||||
func ErrorChain(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var parts []string
|
||||
seen := map[string]struct{}{}
|
||||
cur := err
|
||||
truncated := false
|
||||
for i := 0; cur != nil && i < 20; i++ {
|
||||
s := cur.Error()
|
||||
if _, ok := seen[s]; !ok {
|
||||
seen[s] = struct{}{}
|
||||
parts = append(parts, s)
|
||||
}
|
||||
cur = errors.Unwrap(cur)
|
||||
}
|
||||
if cur != nil {
|
||||
truncated = true
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
return err.Error()
|
||||
}
|
||||
if truncated {
|
||||
parts = append(parts, "(错误链过长,已截断)")
|
||||
}
|
||||
return strings.Join(parts, " -> ")
|
||||
}
|
||||
|
||||
func printf(level string, format string, args ...any) {
|
||||
Init()
|
||||
logMu.Lock()
|
||||
inst := logInst
|
||||
logMu.Unlock()
|
||||
if inst == nil {
|
||||
return
|
||||
}
|
||||
inst.Printf("[%s] %s", level, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func initOutput() (string, io.Writer) {
|
||||
dir := strings.TrimSpace(os.Getenv(envLogDir))
|
||||
if dir == "" {
|
||||
base, err := os.UserConfigDir()
|
||||
if err != nil || strings.TrimSpace(base) == "" {
|
||||
base = os.TempDir()
|
||||
}
|
||||
dir = filepath.Join(base, appDirName, "logs")
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return filepath.Join(dir, logFileName), os.Stderr
|
||||
}
|
||||
|
||||
path := filepath.Join(dir, logFileName)
|
||||
rotateIfNeeded(path, dir)
|
||||
|
||||
f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
return path, os.Stderr
|
||||
}
|
||||
logFile = f
|
||||
return path, f
|
||||
}
|
||||
|
||||
func rotateIfNeeded(path, dir string) {
|
||||
fi, err := os.Stat(path)
|
||||
if err != nil || fi.IsDir() {
|
||||
return
|
||||
}
|
||||
if fi.Size() < logRotateMaxBytes {
|
||||
return
|
||||
}
|
||||
|
||||
ts := time.Now().Format("20060102-150405")
|
||||
rotated := filepath.Join(dir, fmt.Sprintf("gonavi-%s.log", ts))
|
||||
if err := os.Rename(path, rotated); err != nil {
|
||||
return
|
||||
}
|
||||
cleanupOldLogs(dir)
|
||||
}
|
||||
|
||||
func cleanupOldLogs(dir string) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
type item struct {
|
||||
name string
|
||||
path string
|
||||
}
|
||||
var logs []item
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := e.Name()
|
||||
if !strings.HasPrefix(name, "gonavi-") || !strings.HasSuffix(name, ".log") {
|
||||
continue
|
||||
}
|
||||
logs = append(logs, item{name: name, path: filepath.Join(dir, name)})
|
||||
}
|
||||
|
||||
sort.Slice(logs, func(i, j int) bool { return logs[i].name > logs[j].name })
|
||||
if len(logs) <= logRotateMaxBackups {
|
||||
return
|
||||
}
|
||||
for _, it := range logs[logRotateMaxBackups:] {
|
||||
_ = os.Remove(it.path)
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"golang.org/x/crypto/ssh"
|
||||
@@ -19,18 +20,49 @@ type ViaSSHDialer struct {
|
||||
}
|
||||
|
||||
func (d *ViaSSHDialer) Dial(ctx context.Context, addr string) (net.Conn, error) {
|
||||
return d.sshClient.Dial("tcp", addr)
|
||||
return dialContext(ctx, d.sshClient, "tcp", addr)
|
||||
}
|
||||
|
||||
func dialContext(ctx context.Context, client *ssh.Client, network, addr string) (net.Conn, error) {
|
||||
type result struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
ch := make(chan result, 1)
|
||||
go func() {
|
||||
c, err := client.Dial(network, addr)
|
||||
ch <- result{conn: c, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
go func() {
|
||||
r := <-ch
|
||||
if r.conn != nil {
|
||||
_ = r.conn.Close()
|
||||
}
|
||||
}()
|
||||
return nil, ctx.Err()
|
||||
case r := <-ch:
|
||||
return r.conn, r.err
|
||||
}
|
||||
}
|
||||
|
||||
// connectSSH establishes an SSH connection and returns a Dialer
|
||||
func connectSSH(config connection.SSHConfig) (*ssh.Client, error) {
|
||||
logger.Infof("开始建立 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
|
||||
authMethods := []ssh.AuthMethod{}
|
||||
|
||||
if config.KeyPath != "" {
|
||||
key, err := os.ReadFile(config.KeyPath)
|
||||
if err == nil {
|
||||
if err != nil {
|
||||
logger.Warnf("读取 SSH 私钥失败:路径=%s,原因:%v", config.KeyPath, err)
|
||||
} else {
|
||||
signer, err := ssh.ParsePrivateKey(key)
|
||||
if err == nil {
|
||||
if err != nil {
|
||||
logger.Warnf("解析 SSH 私钥失败:路径=%s,原因:%v", config.KeyPath, err)
|
||||
} else {
|
||||
authMethods = append(authMethods, ssh.PublicKeys(signer))
|
||||
}
|
||||
}
|
||||
@@ -39,6 +71,9 @@ func connectSSH(config connection.SSHConfig) (*ssh.Client, error) {
|
||||
if config.Password != "" {
|
||||
authMethods = append(authMethods, ssh.Password(config.Password))
|
||||
}
|
||||
if len(authMethods) == 0 {
|
||||
logger.Warnf("SSH 未配置认证方式(密码或私钥)")
|
||||
}
|
||||
|
||||
sshConfig := &ssh.ClientConfig{
|
||||
User: config.User,
|
||||
@@ -48,7 +83,13 @@ func connectSSH(config connection.SSHConfig) (*ssh.Client, error) {
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
return ssh.Dial("tcp", addr, sshConfig)
|
||||
client, err := ssh.Dial("tcp", addr, sshConfig)
|
||||
if err != nil {
|
||||
logger.Error(err, "SSH 连接建立失败:地址=%s 用户=%s", addr, config.User)
|
||||
return nil, err
|
||||
}
|
||||
logger.Infof("SSH 连接建立成功:地址=%s 用户=%s", addr, config.User)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// RegisterSSHNetwork registers a unique network name for a specific SSH tunnel
|
||||
@@ -61,9 +102,10 @@ func RegisterSSHNetwork(sshConfig connection.SSHConfig) (string, error) {
|
||||
|
||||
// Generate unique network name
|
||||
netName := fmt.Sprintf("ssh_%s_%d", sshConfig.Host, time.Now().UnixNano())
|
||||
logger.Infof("注册 SSH 网络:%s(地址=%s:%d 用户=%s)", netName, sshConfig.Host, sshConfig.Port, sshConfig.User)
|
||||
|
||||
mysql.RegisterDialContext(netName, func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
return client.Dial("tcp", addr)
|
||||
return dialContext(ctx, client, "tcp", addr)
|
||||
})
|
||||
|
||||
return netName, nil
|
||||
|
||||
@@ -3,7 +3,9 @@ package sync
|
||||
import (
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/db"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SyncConfig defines the parameters for a synchronization task
|
||||
@@ -35,9 +37,11 @@ func NewSyncEngine() *SyncEngine {
|
||||
// CompareAndSync performs the synchronization
|
||||
func (s *SyncEngine) RunSync(config SyncConfig) SyncResult {
|
||||
result := SyncResult{Success: true, Logs: []string{}}
|
||||
logger.Infof("开始数据同步:源=%s 目标=%s 表数量=%d", formatConnSummaryForSync(config.SourceConfig), formatConnSummaryForSync(config.TargetConfig), len(config.Tables))
|
||||
|
||||
sourceDB, err := db.NewDatabase(config.SourceConfig.Type)
|
||||
if err != nil {
|
||||
logger.Error(err, "初始化源数据库驱动失败:类型=%s", config.SourceConfig.Type)
|
||||
return s.fail(result, "初始化源数据库驱动失败: "+err.Error())
|
||||
}
|
||||
if config.SourceConfig.Type == "custom" {
|
||||
@@ -46,12 +50,14 @@ func (s *SyncEngine) RunSync(config SyncConfig) SyncResult {
|
||||
|
||||
targetDB, err := db.NewDatabase(config.TargetConfig.Type)
|
||||
if err != nil {
|
||||
logger.Error(err, "初始化目标数据库驱动失败:类型=%s", config.TargetConfig.Type)
|
||||
return s.fail(result, "初始化目标数据库驱动失败: "+err.Error())
|
||||
}
|
||||
|
||||
// Connect Source
|
||||
result.Logs = append(result.Logs, fmt.Sprintf("正在连接源数据库: %s...", config.SourceConfig.Host))
|
||||
if err := sourceDB.Connect(config.SourceConfig); err != nil {
|
||||
logger.Error(err, "源数据库连接失败:%s", formatConnSummaryForSync(config.SourceConfig))
|
||||
return s.fail(result, "源数据库连接失败: "+err.Error())
|
||||
}
|
||||
defer sourceDB.Close()
|
||||
@@ -59,6 +65,7 @@ func (s *SyncEngine) RunSync(config SyncConfig) SyncResult {
|
||||
// Connect Target
|
||||
result.Logs = append(result.Logs, fmt.Sprintf("正在连接目标数据库: %s...", config.TargetConfig.Host))
|
||||
if err := targetDB.Connect(config.TargetConfig); err != nil {
|
||||
logger.Error(err, "目标数据库连接失败:%s", formatConnSummaryForSync(config.TargetConfig))
|
||||
return s.fail(result, "目标数据库连接失败: "+err.Error())
|
||||
}
|
||||
defer targetDB.Close()
|
||||
@@ -70,6 +77,7 @@ func (s *SyncEngine) RunSync(config SyncConfig) SyncResult {
|
||||
// 1. Get Columns & PKs (Naive approach: assume same schema)
|
||||
cols, err := sourceDB.GetColumns(config.SourceConfig.Database, tableName)
|
||||
if err != nil {
|
||||
logger.Error(err, "获取源表列信息失败:表=%s", tableName)
|
||||
result.Logs = append(result.Logs, fmt.Sprintf("获取表 %s 的列信息失败: %v", tableName, err))
|
||||
continue
|
||||
}
|
||||
@@ -91,12 +99,14 @@ func (s *SyncEngine) RunSync(config SyncConfig) SyncResult {
|
||||
// TODO: Implement paging/streaming
|
||||
sourceRows, _, err := sourceDB.Query(fmt.Sprintf("SELECT * FROM %s", tableName))
|
||||
if err != nil {
|
||||
logger.Error(err, "读取源表失败:表=%s", tableName)
|
||||
result.Logs = append(result.Logs, fmt.Sprintf("读取源表 %s 失败: %v", tableName, err))
|
||||
continue
|
||||
}
|
||||
|
||||
targetRows, _, err := targetDB.Query(fmt.Sprintf("SELECT * FROM %s", tableName))
|
||||
if err != nil {
|
||||
logger.Error(err, "读取目标表失败:表=%s", tableName)
|
||||
// Table might not exist in target?
|
||||
// Check if error is "table not found" -> Try to Create?
|
||||
// For now, assume table exists.
|
||||
@@ -171,6 +181,21 @@ func (s *SyncEngine) RunSync(config SyncConfig) SyncResult {
|
||||
return result
|
||||
}
|
||||
|
||||
func formatConnSummaryForSync(config connection.ConnectionConfig) string {
|
||||
timeoutSeconds := config.Timeout
|
||||
if timeoutSeconds <= 0 {
|
||||
timeoutSeconds = 30
|
||||
}
|
||||
|
||||
dbName := strings.TrimSpace(config.Database)
|
||||
if dbName == "" {
|
||||
dbName = "(default)"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("类型=%s 地址=%s:%d 数据库=%s 用户=%s 超时=%ds",
|
||||
config.Type, config.Host, config.Port, dbName, config.User, timeoutSeconds)
|
||||
}
|
||||
|
||||
func (s *SyncEngine) fail(res SyncResult, msg string) SyncResult {
|
||||
res.Success = false
|
||||
res.Message = msg
|
||||
|
||||
Reference in New Issue
Block a user