From b707c742030cdaf6f3b0e280e8f00eddb8f1bcb9 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Fri, 15 May 2026 22:04:20 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat(connection):=20=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E8=BF=9E=E6=8E=A5=20SSL=20=E8=AF=81=E4=B9=A6=E6=96=87?= =?UTF-8?q?=E4=BB=B6=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 CA 证书、客户端证书和私钥路径配置 - 为 MySQL、PostgreSQL、ClickHouse、MongoDB、Redis 等连接接入 TLS 证书 - 修正 SSL 模式下证书校验、明文回退和 DER 证书兼容问题 - 补充证书路径保存、RPC 传递和 DSN 生成回归测试 Refs #463 --- frontend/src/components/ConnectionModal.tsx | 369 +++++++++++++++--- frontend/src/store.test.ts | 28 ++ frontend/src/store.ts | 1 + frontend/src/types.ts | 1 + .../src/utils/connectionRpcConfig.test.ts | 21 + frontend/wailsjs/go/app/App.d.ts | 4 + frontend/wailsjs/go/app/App.js | 8 + frontend/wailsjs/go/models.ts | 2 + internal/app/methods_file.go | 55 +++ internal/connection/types.go | 1 + internal/db/clickhouse_impl.go | 19 +- internal/db/diros_impl.go | 2 +- internal/db/dsn_test.go | 170 +++++++- internal/db/highgo_impl.go | 1 + internal/db/kingbase_impl.go | 3 +- internal/db/mariadb_impl.go | 2 +- internal/db/mongodb_impl.go | 16 +- internal/db/mongodb_impl_v1.go | 16 +- internal/db/mysql_impl.go | 61 ++- internal/db/oceanbase_impl.go | 2 +- internal/db/postgres_impl.go | 1 + internal/db/sqlserver_impl.go | 3 + internal/db/ssl_mode.go | 65 ++- internal/db/starrocks_impl.go | 2 +- internal/db/vastbase_impl.go | 1 + internal/redis/redis_impl.go | 10 +- internal/redis/ssl_mode.go | 28 +- internal/tlsconfig/tlsconfig.go | 62 +++ internal/tlsconfig/tlsconfig_test.go | 126 ++++++ 29 files changed, 965 insertions(+), 115 deletions(-) create mode 100644 internal/tlsconfig/tlsconfig.go create mode 100644 internal/tlsconfig/tlsconfig_test.go diff --git a/frontend/src/components/ConnectionModal.tsx b/frontend/src/components/ConnectionModal.tsx index 61ef57e..0f892c9 100644 --- a/frontend/src/components/ConnectionModal.tsx +++ b/frontend/src/components/ConnectionModal.tsx @@ -92,6 +92,7 @@ import { TestConnection, RedisConnect, SelectDatabaseFile, + SelectCertificateFile, SelectSSHKeyFile, TestJVMConnection, } from "../../wailsjs/go/app/App"; @@ -283,6 +284,69 @@ const supportsSSLForType = (type: string) => .toLowerCase(), ); +const sslCAPathSupportedTypes = new Set([ + "mysql", + "mariadb", + "oceanbase", + "diros", + "starrocks", + "sphinx", + "clickhouse", + "postgres", + "sqlserver", + "kingbase", + "highgo", + "vastbase", + "opengauss", + "mongodb", + "redis", +]); + +const sslClientCertificateSupportedTypes = new Set([ + "mysql", + "mariadb", + "oceanbase", + "diros", + "starrocks", + "sphinx", + "dameng", + "clickhouse", + "postgres", + "kingbase", + "highgo", + "vastbase", + "opengauss", + "mongodb", + "redis", +]); + +const supportsSSLCAPathForType = (type: string) => + sslCAPathSupportedTypes.has( + String(type || "") + .trim() + .toLowerCase(), + ); + +const supportsSSLClientCertificateForType = (type: string) => + sslClientCertificateSupportedTypes.has( + String(type || "") + .trim() + .toLowerCase(), + ); + +const isPostgresCompatibleSSLType = (type: string) => + [ + "postgres", + "kingbase", + "highgo", + "vastbase", + "opengauss", + ].includes( + String(type || "") + .trim() + .toLowerCase(), + ); + const isFileDatabaseType = (type: string) => type === "sqlite" || type === "duckdb"; @@ -394,6 +458,9 @@ const ConnectionModal: React.FC<{ const [driverStatusLoaded, setDriverStatusLoaded] = useState(false); const [selectingDbFile, setSelectingDbFile] = useState(false); const [selectingSSHKey, setSelectingSSHKey] = useState(false); + const [selectingCertificateField, setSelectingCertificateField] = useState< + "sslCAPath" | "sslCertPath" | "sslKeyPath" | null + >(null); const [clearSecrets, setClearSecrets] = useState( createEmptyConnectionSecretClearState, ); @@ -445,17 +512,24 @@ const ConnectionModal: React.FC<{ const isMySQLLike = isMySQLCompatibleType(dbType) && !isOceanBaseOracle; const supportsConnectionParams = supportsConnectionParamsForType(dbType); const isSSLType = supportsSSLForType(dbType); + const supportsSSLCAPath = supportsSSLCAPathForType(dbType); + const supportsSSLClientCertificate = + supportsSSLClientCertificateForType(dbType); const sslHintText = isMySQLLike - ? "当 MySQL/MariaDB/Doris/Sphinx 开启安全传输策略时,请启用 SSL;本地自签证书场景可先用 Preferred 或 Skip Verify。" + ? "MySQL 兼容数据源支持 CA 证书、客户端证书与私钥;本地自签证书场景可先用 Preferred 或 Skip Verify。" : isOceanBaseOracle - ? "OceanBase Oracle 租户使用 Oracle 协议连接,SSL 参数按 Oracle 驱动规则传递。" + ? "OceanBase Oracle 租户使用 Oracle 协议连接;如需 Wallet,请在高级参数中配置 Oracle 驱动参数。" : dbType === "dameng" ? "达梦驱动启用 SSL 需要客户端证书与私钥路径(sslCertPath / sslKeyPath)。" : dbType === "sqlserver" - ? "SQL Server 推荐在生产环境使用 Required,并关闭 TrustServerCertificate。" + ? "SQL Server 可配置服务端证书/CA 文件;生产环境建议使用 Required,并关闭 TrustServerCertificate。" : dbType === "mongodb" - ? "MongoDB 可通过 TLS 保护连接,证书校验异常时可先用 Skip Verify 验证连通性。" - : "建议优先使用 Required;仅在测试环境或自签证书场景使用 Skip Verify。"; + ? "MongoDB 支持 CA 证书、客户端证书与私钥;证书校验异常时可先用 Skip Verify 验证连通性。" + : dbType === "oracle" + ? "Oracle PEM 证书请优先使用 Wallet 并在高级参数中配置 WALLET;这里仅控制 SSL 开关与校验策略。" + : dbType === "tdengine" + ? "TDengine 当前仅配置 WSS 与校验策略;证书文件请通过服务端信任链处理。" + : "支持的驱动可配置 CA 证书、客户端证书与私钥;仅在测试环境或自签证书场景使用 Skip Verify。"; const getSectionBg = (darkHex: string) => { if (!darkMode) { @@ -1339,10 +1413,102 @@ const ConnectionModal: React.FC<{ clickHouseProtocol: "http", useSSL: isHttps, sslMode: isHttps ? (skipVerify ? "skip-verify" : "required") : "disable", + ...extractSSLPathValuesFromParams(parsed.params, "clickhouse"), connectionParams: serializeConnectionParams(parsed.params), }; }; + const firstConnectionParamValue = ( + params: URLSearchParams, + names: string[], + ): string => { + for (const name of names) { + const value = String(params.get(name) || "").trim(); + if (value) return value; + } + return ""; + }; + + const extractSSLPathValuesFromParams = ( + params: URLSearchParams, + type: string, + ): Record => { + const caPath = firstConnectionParamValue(params, [ + "sslCAPath", + "ssl_ca_path", + "sslrootcert", + "sslRootCert", + "tlsCAFile", + "caFile", + "certificate", + "servercertificate", + "serverCertificate", + ]); + const certPath = firstConnectionParamValue(params, [ + "sslCertPath", + "ssl_cert_path", + "SSL_CERT_PATH", + "sslcert", + "sslCert", + "tlsCertificateFile", + ]); + const keyPath = firstConnectionParamValue(params, [ + "sslKeyPath", + "ssl_key_path", + "SSL_KEY_PATH", + "sslkey", + "sslKey", + "tlsKeyFile", + ]); + return { + ...(supportsSSLCAPathForType(type) && caPath ? { sslCAPath: caPath } : {}), + ...(supportsSSLClientCertificateForType(type) && certPath ? { sslCertPath: certPath } : {}), + ...(supportsSSLClientCertificateForType(type) && keyPath ? { sslKeyPath: keyPath } : {}), + }; + }; + + const appendSSLPathParamsForUri = ( + params: URLSearchParams, + type: string, + values: Record, + ) => { + const caPath = String(values.sslCAPath || "").trim(); + const certPath = String(values.sslCertPath || "").trim(); + const keyPath = String(values.sslKeyPath || "").trim(); + const mode = String(values.sslMode || "preferred") + .trim() + .toLowerCase(); + if (supportsSSLCAPathForType(type) && caPath) { + if (isPostgresCompatibleSSLType(type)) { + if (mode !== "skip-verify" && mode !== "disable") { + params.set("sslrootcert", caPath); + } + } else if (type === "sqlserver") { + params.set("certificate", caPath); + } else { + params.set("sslCAPath", caPath); + } + } + if (supportsSSLClientCertificateForType(type) && certPath) { + if (type === "dameng") { + params.set("SSL_CERT_PATH", certPath); + } else if (isPostgresCompatibleSSLType(type)) { + params.set("sslcert", certPath); + } else { + params.set("sslCertPath", certPath); + } + } + if (supportsSSLClientCertificateForType(type) && keyPath) { + if (type === "dameng") { + params.set("SSL_KEY_PATH", keyPath); + } else if (isPostgresCompatibleSSLType(type)) { + params.set("sslkey", keyPath); + } else { + params.set("sslKeyPath", keyPath); + } + } + }; + const parseUriToValues = ( uriText: string, type: string, @@ -1419,6 +1585,7 @@ const ConnectionModal: React.FC<{ database: parsed.database || "", useSSL: sslMode !== "disable", sslMode, + ...extractSSLPathValuesFromParams(parsed.params, type), oceanBaseProtocol: parsedOceanBaseProtocol, mysqlTopology: parsedOceanBaseProtocol === "oracle" @@ -1491,6 +1658,7 @@ const ConnectionModal: React.FC<{ ? "skip-verify" : "required" : "disable", + ...extractSSLPathValuesFromParams(parsed.params, type), redisTopology: hostList.length > 1 || topologyParam === "cluster" ? "cluster" @@ -1564,6 +1732,7 @@ const ConnectionModal: React.FC<{ ? "skip-verify" : "required" : "disable", + ...extractSSLPathValuesFromParams(parsed.params, type), mongoTopology: hostList.length > 1 || !!parsed.params.get("replicaSet") ? "replica" @@ -1616,6 +1785,7 @@ const ConnectionModal: React.FC<{ } if (supportsSSLForType(type)) { + Object.assign(parsedValues, extractSSLPathValuesFromParams(parsed.params, type)); const normalizeBool = (raw: unknown) => { const text = String(raw ?? "") .trim() @@ -1891,6 +2061,7 @@ const ConnectionModal: React.FC<{ params.set("tls", "preferred"); } } + appendSSLPathParamsForUri(params, type, values); if (Number.isFinite(timeout) && timeout > 0) { params.set("timeout", String(timeout)); } @@ -1939,6 +2110,7 @@ const ConnectionModal: React.FC<{ params.set("skip_verify", "true"); } } + appendSSLPathParamsForUri(params, type, values); const query = params.toString(); const scheme = values.useSSL ? "rediss" : "redis"; return `${scheme}://${redisAuth}${hosts.join(",")}${dbPath}${query ? `?${query}` : ""}`; @@ -1997,6 +2169,7 @@ const ConnectionModal: React.FC<{ params.delete("tlsInsecure"); } } + appendSSLPathParamsForUri(params, type, values); if (Number.isFinite(timeout) && timeout > 0) { params.set("connectTimeoutMS", String(timeout * 1000)); params.set("serverSelectionTimeoutMS", String(timeout * 1000)); @@ -2025,20 +2198,23 @@ const ConnectionModal: React.FC<{ const mode = String(values.sslMode || "preferred") .trim() .toLowerCase(); - if ( - type === "postgres" || - type === "kingbase" || - type === "highgo" || - type === "vastbase" || - type === "opengauss" - ) { - params.set("sslmode", "require"); + if (isPostgresCompatibleSSLType(type)) { + params.set( + "sslmode", + mode === "skip-verify" + ? "require" + : String(values.sslCAPath || "").trim() + ? "verify-ca" + : "require", + ); + appendSSLPathParamsForUri(params, type, values); } else if (type === "sqlserver") { params.set("encrypt", "true"); params.set( "TrustServerCertificate", mode === "skip-verify" || mode === "preferred" ? "true" : "false", ); + appendSSLPathParamsForUri(params, type, values); } else if (type === "clickhouse") { if (clickHouseProtocol === "http") { if (mode === "skip-verify" || mode === "preferred") { @@ -2050,11 +2226,9 @@ const ConnectionModal: React.FC<{ params.set("skip_verify", "true"); } } + appendSSLPathParamsForUri(params, type, values); } else if (type === "dameng") { - const certPath = String(values.sslCertPath || "").trim(); - const keyPath = String(values.sslKeyPath || "").trim(); - if (certPath) params.set("SSL_CERT_PATH", certPath); - if (keyPath) params.set("SSL_KEY_PATH", keyPath); + appendSSLPathParamsForUri(params, type, values); } else if (type === "oracle") { params.set("SSL", "TRUE"); params.set("SSL VERIFY", mode === "required" ? "TRUE" : "FALSE"); @@ -2065,13 +2239,7 @@ const ConnectionModal: React.FC<{ } } } else if (supportsSSLForType(type)) { - if ( - type === "postgres" || - type === "kingbase" || - type === "highgo" || - type === "vastbase" || - type === "opengauss" - ) { + if (isPostgresCompatibleSSLType(type)) { params.set("sslmode", "disable"); } else if (type === "sqlserver") { params.set("encrypt", "disable"); @@ -2182,6 +2350,34 @@ const ConnectionModal: React.FC<{ } }; + const handleSelectCertificateFile = async ( + fieldName: "sslCAPath" | "sslCertPath" | "sslKeyPath", + certKind: "ca" | "client-cert" | "client-key", + ) => { + if (selectingCertificateField) { + return; + } + try { + setSelectingCertificateField(fieldName); + const currentPath = String(form.getFieldValue(fieldName) || "").trim(); + const res = await SelectCertificateFile(currentPath, certKind); + if (res?.success) { + const data = res.data || {}; + const selectedPath = + typeof data === "string" ? data : String(data.path || "").trim(); + if (selectedPath) { + form.setFieldValue(fieldName, selectedPath); + } + } else if (res?.message !== "已取消") { + message.error(`选择证书文件失败: ${res?.message || "未知错误"}`); + } + } catch (e: any) { + message.error(`选择证书文件失败: ${e?.message || String(e)}`); + } finally { + setSelectingCertificateField(null); + } + }; + const handleSelectDatabaseFile = async () => { if (selectingDbFile) { return; @@ -2317,6 +2513,7 @@ const ConnectionModal: React.FC<{ includeRedisDatabases: initialValues.includeRedisDatabases, useSSL: !!config.useSSL, sslMode: config.sslMode || "preferred", + sslCAPath: config.sslCAPath || "", sslCertPath: config.sslCertPath || "", sslKeyPath: config.sslKeyPath || "", useSSH: config.useSSH, @@ -3166,6 +3363,9 @@ const ConnectionModal: React.FC<{ ? "disable" : "preferred"; const effectiveUseSSL = sslCapableType && !!mergedValues.useSSL; + const sslCAPath = sslCapableType + ? String(mergedValues.sslCAPath || "").trim() + : ""; const sslCertPath = sslCapableType ? String(mergedValues.sslCertPath || "").trim() : ""; @@ -3175,6 +3375,9 @@ const ConnectionModal: React.FC<{ if (type === "dameng" && effectiveUseSSL && (!sslCertPath || !sslKeyPath)) { throw new Error("达梦启用 SSL 时必须填写证书路径与私钥路径"); } + if (effectiveUseSSL && supportsSSLClientCertificateForType(type) && (!!sslCertPath !== !!sslKeyPath)) { + throw new Error("TLS 客户端证书与私钥路径需要同时填写"); + } let primaryHost = "localhost"; let primaryPort = defaultPort; @@ -3369,6 +3572,7 @@ const ConnectionModal: React.FC<{ database: mergedValues.database || "", useSSL: effectiveUseSSL, sslMode: effectiveUseSSL ? sslMode : "disable", + sslCAPath: sslCAPath, sslCertPath: sslCertPath, sslKeyPath: sslKeyPath, useSSH: !!mergedValues.useSSH, @@ -3438,6 +3642,7 @@ const ConnectionModal: React.FC<{ database: "", useSSL: false, sslMode: undefined, + sslCAPath: undefined, sslCertPath: undefined, sslKeyPath: undefined, useSSH: false, @@ -3501,6 +3706,7 @@ const ConnectionModal: React.FC<{ database: "", useSSL: false, sslMode: "preferred", + sslCAPath: "", sslCertPath: "", sslKeyPath: "", useSSH: false, @@ -3551,6 +3757,7 @@ const ConnectionModal: React.FC<{ port: defaultPort, useSSL: sslCapableType ? false : undefined, sslMode: sslCapableType ? "preferred" : undefined, + sslCAPath: sslCapableType ? "" : undefined, sslCertPath: sslCapableType ? "" : undefined, sslKeyPath: sslCapableType ? "" : undefined, useHttpTunnel: false, @@ -5556,41 +5763,84 @@ const ConnectionModal: React.FC<{ ], })} - {dbType === "dameng" && ( - <> - - - - - - - + {(supportsSSLCAPath || supportsSSLClientCertificate) && ( +
+ {supportsSSLCAPath && ( + + + + + + + + + )} + {supportsSSLClientCertificate && ( + <> + + + + + + + + + + + + + + + + + + )} +
)} {sslHintText} @@ -6216,6 +6466,7 @@ const ConnectionModal: React.FC<{ user: "root", useSSL: false, sslMode: "preferred", + sslCAPath: "", sslCertPath: "", sslKeyPath: "", useSSH: false, diff --git a/frontend/src/store.test.ts b/frontend/src/store.test.ts index b41215f..fc2fb9c 100644 --- a/frontend/src/store.test.ts +++ b/frontend/src/store.test.ts @@ -262,6 +262,34 @@ describe('store appearance persistence', () => { expect(config?.port).toBe(9030); }); + it('preserves SSL certificate paths for SSL-capable saved connections', async () => { + const { useStore } = await importStore(); + + useStore.getState().replaceConnections([ + { + id: 'postgres-ssl', + name: 'Postgres SSL', + config: { + id: 'postgres-ssl', + type: 'postgres', + host: 'db.local', + port: 5432, + user: 'postgres', + useSSL: true, + sslMode: 'required', + sslCAPath: 'C:/certs/ca.pem', + sslCertPath: 'C:/certs/client-cert.pem', + sslKeyPath: 'C:/certs/client-key.pem', + }, + }, + ]); + + const config = useStore.getState().connections[0]?.config; + expect(config?.sslCAPath).toBe('C:/certs/ca.pem'); + expect(config?.sslCertPath).toBe('C:/certs/client-cert.pem'); + expect(config?.sslKeyPath).toBe('C:/certs/client-key.pem'); + }); + it('normalizes OceanBase protocol override when replacing saved connections', async () => { const { useStore } = await importStore(); diff --git a/frontend/src/store.ts b/frontend/src/store.ts index 876eb0c..a6e3283 100644 --- a/frontend/src/store.ts +++ b/frontend/src/store.ts @@ -529,6 +529,7 @@ const sanitizeConnectionConfig = (value: unknown): ConnectionConfig => { database: toTrimmedString(raw.database), useSSL: sslCapable ? !!raw.useSSL : false, sslMode: sslCapable ? sslMode : "disable", + sslCAPath: sslCapable ? toTrimmedString(raw.sslCAPath) : "", sslCertPath: sslCapable ? toTrimmedString(raw.sslCertPath) : "", sslKeyPath: sslCapable ? toTrimmedString(raw.sslKeyPath) : "", useSSH: !!raw.useSSH, diff --git a/frontend/src/types.ts b/frontend/src/types.ts index 031a8f6..89ba5a9 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -284,6 +284,7 @@ export interface ConnectionConfig { database?: string; useSSL?: boolean; sslMode?: "preferred" | "required" | "skip-verify" | "disable"; + sslCAPath?: string; sslCertPath?: string; sslKeyPath?: string; useSSH?: boolean; diff --git a/frontend/src/utils/connectionRpcConfig.test.ts b/frontend/src/utils/connectionRpcConfig.test.ts index b8f64b8..bbf69ad 100644 --- a/frontend/src/utils/connectionRpcConfig.test.ts +++ b/frontend/src/utils/connectionRpcConfig.test.ts @@ -148,6 +148,27 @@ describe('buildRpcConnectionConfig', () => { expect(result.connectionParams).toBe('characterEncoding=utf8&useSSL=false'); }); + it('preserves SSL certificate path fields for RPC calls', () => { + const result = buildRpcConnectionConfig({ + id: 'conn-postgres-ssl', + type: 'postgres', + host: 'db.local', + port: 5432, + user: 'postgres', + useSSL: true, + sslMode: 'required', + sslCAPath: 'C:/certs/ca.pem', + sslCertPath: 'C:/certs/client-cert.pem', + sslKeyPath: 'C:/certs/client-key.pem', + } as any); + + expect(result.useSSL).toBe(true); + expect(result.sslMode).toBe('required'); + expect(result.sslCAPath).toBe('C:/certs/ca.pem'); + expect(result.sslCertPath).toBe('C:/certs/client-cert.pem'); + expect(result.sslKeyPath).toBe('C:/certs/client-key.pem'); + }); + it('fills default nested config blocks needed by RPC calls', () => { const result = buildRpcConnectionConfig({ id: 'conn-redis', diff --git a/frontend/wailsjs/go/app/App.d.ts b/frontend/wailsjs/go/app/App.d.ts index c30e837..999ee30 100755 --- a/frontend/wailsjs/go/app/App.d.ts +++ b/frontend/wailsjs/go/app/App.d.ts @@ -242,6 +242,8 @@ export function RenameTable(arg1:connection.ConnectionConfig,arg2:string,arg3:st export function RenameView(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise; +export function ResetWebViewZoom():Promise; + export function ResolveDriverDownloadDirectory(arg1:string):Promise; export function ResolveDriverPackageDownloadURL(arg1:string,arg2:string):Promise; @@ -256,6 +258,8 @@ export function SaveConnection(arg1:connection.SavedConnectionInput):Promise; +export function SelectCertificateFile(arg1:string,arg2:string):Promise; + export function SelectDataRootDirectory(arg1:string):Promise; export function SelectDatabaseFile(arg1:string,arg2:string):Promise; diff --git a/frontend/wailsjs/go/app/App.js b/frontend/wailsjs/go/app/App.js index 0f60a94..8665175 100755 --- a/frontend/wailsjs/go/app/App.js +++ b/frontend/wailsjs/go/app/App.js @@ -474,6 +474,10 @@ export function RenameView(arg1, arg2, arg3, arg4) { return window['go']['app']['App']['RenameView'](arg1, arg2, arg3, arg4); } +export function ResetWebViewZoom() { + return window['go']['app']['App']['ResetWebViewZoom'](); +} + export function ResolveDriverDownloadDirectory(arg1) { return window['go']['app']['App']['ResolveDriverDownloadDirectory'](arg1); } @@ -502,6 +506,10 @@ export function SaveGlobalProxy(arg1) { return window['go']['app']['App']['SaveGlobalProxy'](arg1); } +export function SelectCertificateFile(arg1, arg2) { + return window['go']['app']['App']['SelectCertificateFile'](arg1, arg2); +} + export function SelectDataRootDirectory(arg1) { return window['go']['app']['App']['SelectDataRootDirectory'](arg1); } diff --git a/frontend/wailsjs/go/models.ts b/frontend/wailsjs/go/models.ts index e4dd23c..32611bb 100755 --- a/frontend/wailsjs/go/models.ts +++ b/frontend/wailsjs/go/models.ts @@ -673,6 +673,7 @@ export namespace connection { database: string; useSSL?: boolean; sslMode?: string; + sslCAPath?: string; sslCertPath?: string; sslKeyPath?: string; useSSH: boolean; @@ -718,6 +719,7 @@ export namespace connection { this.database = source["database"]; this.useSSL = source["useSSL"]; this.sslMode = source["sslMode"]; + this.sslCAPath = source["sslCAPath"]; this.sslCertPath = source["sslCertPath"]; this.sslKeyPath = source["sslKeyPath"]; this.useSSH = source["useSSH"]; diff --git a/internal/app/methods_file.go b/internal/app/methods_file.go index 0432d3a..9ba929e 100644 --- a/internal/app/methods_file.go +++ b/internal/app/methods_file.go @@ -553,6 +553,61 @@ func (a *App) SelectSSHKeyFile(currentPath string) connection.QueryResult { return connection.QueryResult{Success: true, Data: map[string]interface{}{"path": selection}} } +func (a *App) SelectCertificateFile(currentPath string, certKind string) connection.QueryResult { + defaultDir := strings.TrimSpace(currentPath) + if defaultDir == "" { + if home, err := os.UserHomeDir(); err == nil { + defaultDir = home + } + } + if filepath.Ext(defaultDir) != "" { + defaultDir = filepath.Dir(defaultDir) + } + if defaultDir != "" && !filepath.IsAbs(defaultDir) { + if abs, err := filepath.Abs(defaultDir); err == nil { + defaultDir = abs + } + } + + kind := strings.ToLower(strings.TrimSpace(certKind)) + title := "选择 TLS 证书文件" + displayName := "证书文件" + switch kind { + case "ca": + title = "选择 CA/服务端证书文件" + case "client-cert": + title = "选择客户端证书文件" + case "client-key": + title = "选择客户端私钥文件" + displayName = "私钥文件" + } + + selection, err := runtime.OpenFileDialog(a.ctx, runtime.OpenDialogOptions{ + Title: title, + DefaultDirectory: defaultDir, + Filters: []runtime.FileFilter{ + { + DisplayName: displayName, + Pattern: "*.pem;*.crt;*.cer;*.cert;*.key", + }, + { + DisplayName: "所有文件", + Pattern: "*", + }, + }, + }) + if err != nil { + return connection.QueryResult{Success: false, Message: err.Error()} + } + if strings.TrimSpace(selection) == "" { + return connection.QueryResult{Success: false, Message: "已取消"} + } + if abs, err := filepath.Abs(selection); err == nil { + selection = abs + } + return connection.QueryResult{Success: true, Data: map[string]interface{}{"path": selection}} +} + func (a *App) SelectDatabaseFile(currentPath string, driverType string) connection.QueryResult { defaultDir := strings.TrimSpace(currentPath) if defaultDir == "" { diff --git a/internal/connection/types.go b/internal/connection/types.go index 359ddf5..53a5c2f 100644 --- a/internal/connection/types.go +++ b/internal/connection/types.go @@ -89,6 +89,7 @@ type ConnectionConfig struct { Database string `json:"database"` UseSSL bool `json:"useSSL,omitempty"` // MySQL-like SSL/TLS switch SSLMode string `json:"sslMode,omitempty"` // preferred | required | skip-verify | disable + SSLCAPath string `json:"sslCAPath,omitempty"` // TLS root CA / server certificate path SSLCertPath string `json:"sslCertPath,omitempty"` // TLS client certificate path (e.g., Dameng) SSLKeyPath string `json:"sslKeyPath,omitempty"` // TLS client private key path (e.g., Dameng) UseSSH bool `json:"useSSH"` diff --git a/internal/db/clickhouse_impl.go b/internal/db/clickhouse_impl.go index cb2bcc3..7c3c3f3 100644 --- a/internal/db/clickhouse_impl.go +++ b/internal/db/clickhouse_impl.go @@ -167,7 +167,7 @@ func defaultClickHousePortForScheme(scheme string) int { } } -func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig) *clickhouse.Options { +func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig) (*clickhouse.Options, error) { connectTimeout := getConnectTimeout(config) readTimeout := connectTimeout if readTimeout < minClickHouseReadTimeout { @@ -187,11 +187,15 @@ func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig DialTimeout: connectTimeout, ReadTimeout: readTimeout, } - if tlsConfig := resolveGenericTLSConfig(config); tlsConfig != nil { + tlsConfig, err := resolveGenericTLSConfig(config) + if err != nil { + return nil, err + } + if tlsConfig != nil { opts.TLS = tlsConfig } applyClickHouseConnectionParams(opts, config) - return opts + return opts, nil } func parseClickHouseDurationParam(raw string) (time.Duration, bool) { @@ -549,7 +553,14 @@ func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error { protocolConfig := withClickHouseProtocol(attempt, protocol) logger.Infof("ClickHouse 连接尝试:第%d组/%d 协议=%s 地址=%s:%d SSL=%t", idx+1, len(attempts), clickHouseProtocolName(protocol), protocolConfig.Host, protocolConfig.Port, protocolConfig.UseSSL) - c.conn = clickhouse.OpenDB(c.buildClickHouseOptions(protocolConfig)) + opts, err := c.buildClickHouseOptions(protocolConfig) + if err != nil { + failures = append(failures, fmt.Sprintf("第%d次 TLS 配置失败(protocol=%s): %v", idx+1, protocol.String(), err)) + logger.Warnf("ClickHouse TLS 配置失败:第%d组/%d 协议=%s 地址=%s:%d SSL=%t 原因=%v", + idx+1, len(attempts), clickHouseProtocolName(protocol), protocolConfig.Host, protocolConfig.Port, protocolConfig.UseSSL, err) + continue + } + c.conn = clickhouse.OpenDB(opts) if err := c.Ping(); err != nil { failureMessage := clickHouseAttemptFailureMessage(protocol, err) failures = append(failures, fmt.Sprintf("第%d次连接验证失败(protocol=%s): %s", idx+1, protocol.String(), failureMessage)) diff --git a/internal/db/diros_impl.go b/internal/db/diros_impl.go index 38d7664..4fa0c27 100644 --- a/internal/db/diros_impl.go +++ b/internal/db/diros_impl.go @@ -139,7 +139,7 @@ func (d *DirosDB) getDSN(config connection.ConnectionConfig) (string, error) { protocol = netName } - return buildMySQLCompatibleDSN(config, protocol, address, database), nil + return buildMySQLCompatibleDSN(config, protocol, address, database) } func resolveDirosCredential(config connection.ConnectionConfig, addressIndex int) (string, string) { diff --git a/internal/db/dsn_test.go b/internal/db/dsn_test.go index 9a3a5f7..8c4953c 100644 --- a/internal/db/dsn_test.go +++ b/internal/db/dsn_test.go @@ -55,6 +55,78 @@ func TestPostgresDSN_SSLModeRequireWhenEnabled(t *testing.T) { } } +func TestPostgresDSN_AppendsSSLPathParams(t *testing.T) { + p := &PostgresDB{} + cfg := connection.ConnectionConfig{ + Type: "postgres", + Host: "127.0.0.1", + Port: 5432, + User: "user", + Password: "pass", + Database: "db", + UseSSL: true, + SSLMode: "required", + SSLCAPath: "C:\\certs\\ca.pem", + SSLCertPath: "C:\\certs\\client-cert.pem", + SSLKeyPath: "C:\\certs\\client-key.pem", + } + + dsn := p.getDSN(cfg) + parsed, err := url.Parse(dsn) + if err != nil { + t.Fatalf("parse postgres dsn: %v", err) + } + query := parsed.Query() + if got := query.Get("sslmode"); got != "verify-ca" { + t.Fatalf("sslmode = %q, want verify-ca", got) + } + if got := query.Get("sslrootcert"); got != cfg.SSLCAPath { + t.Fatalf("sslrootcert = %q, want %q", got, cfg.SSLCAPath) + } + if got := query.Get("sslcert"); got != cfg.SSLCertPath { + t.Fatalf("sslcert = %q, want %q", got, cfg.SSLCertPath) + } + if got := query.Get("sslkey"); got != cfg.SSLKeyPath { + t.Fatalf("sslkey = %q, want %q", got, cfg.SSLKeyPath) + } +} + +func TestPostgresDSN_SkipVerifyOmitsSSLRootCert(t *testing.T) { + p := &PostgresDB{} + cfg := connection.ConnectionConfig{ + Type: "postgres", + Host: "127.0.0.1", + Port: 5432, + User: "user", + Password: "pass", + Database: "db", + UseSSL: true, + SSLMode: "skip-verify", + SSLCAPath: "C:\\certs\\ca.pem", + SSLCertPath: "C:\\certs\\client-cert.pem", + SSLKeyPath: "C:\\certs\\client-key.pem", + } + + dsn := p.getDSN(cfg) + parsed, err := url.Parse(dsn) + if err != nil { + t.Fatalf("parse postgres dsn: %v", err) + } + query := parsed.Query() + if got := query.Get("sslmode"); got != "require" { + t.Fatalf("sslmode = %q, want require", got) + } + if got := query.Get("sslrootcert"); got != "" { + t.Fatalf("sslrootcert should be omitted for skip-verify, got %q", got) + } + if got := query.Get("sslcert"); got != cfg.SSLCertPath { + t.Fatalf("sslcert = %q, want %q", got, cfg.SSLCertPath) + } + if got := query.Get("sslkey"); got != cfg.SSLKeyPath { + t.Fatalf("sslkey = %q, want %q", got, cfg.SSLKeyPath) + } +} + func TestPostgresDSN_MergesConnectionParams(t *testing.T) { p := &PostgresDB{} cfg := connection.ConnectionConfig{ @@ -109,6 +181,65 @@ func TestMySQLDSN_UsesTLSParamWhenSSLEnabled(t *testing.T) { } } +func TestMySQLDSN_UsesCustomTLSConfigWhenCertificatePathsAreConfigured(t *testing.T) { + m := &MySQLDB{} + cfg := connection.ConnectionConfig{ + Type: "mysql", + Host: "127.0.0.1", + Port: 3306, + User: "root", + Password: "pass", + Database: "db", + UseSSL: true, + SSLMode: "required", + SSLCAPath: "../../third_party/highgo-pq/certs/root.crt", + SSLCertPath: "../../third_party/highgo-pq/certs/postgresql.crt", + SSLKeyPath: "../../third_party/highgo-pq/certs/postgresql.key", + } + + dsn, err := m.getDSN(cfg) + if err != nil { + t.Fatalf("getDSN failed: %v", err) + } + if strings.Contains(dsn, "tls=true") { + t.Fatalf("dsn 应使用自定义 TLS 配置名而不是 tls=true:%s", dsn) + } + if !strings.Contains(dsn, "tls=gonavi-") { + t.Fatalf("dsn 缺少自定义 TLS 配置名:%s", dsn) + } + if strings.Contains(dsn, "allowFallbackToPlaintext=true") { + t.Fatalf("required 模式不应启用明文回退:%s", dsn) + } +} + +func TestMySQLDSN_PreservesPreferredFallbackWithCustomTLSConfig(t *testing.T) { + m := &MySQLDB{} + cfg := connection.ConnectionConfig{ + Type: "mysql", + Host: "127.0.0.1", + Port: 3306, + User: "root", + Password: "pass", + Database: "db", + UseSSL: true, + SSLMode: "preferred", + SSLCAPath: "../../third_party/highgo-pq/certs/root.crt", + SSLCertPath: "../../third_party/highgo-pq/certs/postgresql.crt", + SSLKeyPath: "../../third_party/highgo-pq/certs/postgresql.key", + } + + dsn, err := m.getDSN(cfg) + if err != nil { + t.Fatalf("getDSN failed: %v", err) + } + if !strings.Contains(dsn, "tls=gonavi-") { + t.Fatalf("dsn 缺少自定义 TLS 配置名:%s", dsn) + } + if !strings.Contains(dsn, "allowFallbackToPlaintext=true") { + t.Fatalf("preferred 自定义 TLS 配置应保留明文回退:%s", dsn) + } +} + func TestOracleDSN_EscapesUserAndPassword(t *testing.T) { o := &OracleDB{} cfg := connection.ConnectionConfig{ @@ -380,6 +511,30 @@ func TestSQLServerDSN_MergesConnectionParams(t *testing.T) { } } +func TestSQLServerDSN_AppendsCertificateParam(t *testing.T) { + s := &SqlServerDB{} + cfg := connection.ConnectionConfig{ + Type: "sqlserver", + Host: "127.0.0.1", + Port: 1433, + User: "sa", + Password: "pass", + Database: "master", + UseSSL: true, + SSLMode: "required", + SSLCAPath: "C:\\certs\\sqlserver-ca.pem", + } + + dsn := s.getDSN(cfg) + parsed, err := url.Parse(dsn) + if err != nil { + t.Fatalf("parse sqlserver dsn: %v", err) + } + if got := parsed.Query().Get("certificate"); got != cfg.SSLCAPath { + t.Fatalf("certificate = %q, want %q", got, cfg.SSLCAPath) + } +} + func TestClickHouseOptions_UsesStructuredTimeoutAndAuth(t *testing.T) { c := &ClickHouseDB{} cfg := normalizeClickHouseConfig(connection.ConnectionConfig{ @@ -392,7 +547,10 @@ func TestClickHouseOptions_UsesStructuredTimeoutAndAuth(t *testing.T) { Timeout: 15, }) - opts := c.buildClickHouseOptions(cfg) + opts, err := c.buildClickHouseOptions(cfg) + if err != nil { + t.Fatalf("buildClickHouseOptions failed: %v", err) + } if opts == nil { t.Fatal("options 为空") } @@ -438,7 +596,10 @@ func TestClickHouseOptions_MergesConnectionParamsIntoOptionsAndSettings(t *testi ConnectionParams: "max_execution_time=60&compress=lz4&read_timeout=10s", }) - opts := c.buildClickHouseOptions(cfg) + opts, err := c.buildClickHouseOptions(cfg) + if err != nil { + t.Fatalf("buildClickHouseOptions failed: %v", err) + } if opts == nil { t.Fatal("options 为空") } @@ -465,7 +626,10 @@ func TestClickHouseOptions_ReadTimeoutUsesLargerConfiguredTimeout(t *testing.T) Timeout: 900, }) - opts := c.buildClickHouseOptions(cfg) + opts, err := c.buildClickHouseOptions(cfg) + if err != nil { + t.Fatalf("buildClickHouseOptions failed: %v", err) + } if opts == nil { t.Fatal("options 为空") } diff --git a/internal/db/highgo_impl.go b/internal/db/highgo_impl.go index 4c53fab..a8455f3 100644 --- a/internal/db/highgo_impl.go +++ b/internal/db/highgo_impl.go @@ -43,6 +43,7 @@ func (h *HighGoDB) getDSN(config connection.ConnectionConfig) string { u.User = url.UserPassword(config.User, config.Password) q := url.Values{} q.Set("sslmode", resolvePostgresSSLMode(config)) + applyPostgresSSLPathParams(q, config) q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config))) mergeConnectionParamsFromConfigWithAllowlist(q, config, highGoConnectionParamNames, "postgres", "postgresql", "highgo") u.RawQuery = q.Encode() diff --git a/internal/db/kingbase_impl.go b/internal/db/kingbase_impl.go index 1d7f660..2e4d498 100644 --- a/internal/db/kingbase_impl.go +++ b/internal/db/kingbase_impl.go @@ -70,10 +70,11 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string { params.Set("password", config.Password) params.Set("dbname", config.Database) params.Set("sslmode", resolvePostgresSSLMode(config)) + applyPostgresSSLPathParams(params, config) params.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config))) mergeConnectionParamsFromConfigWithAllowlist(params, config, kingbaseConnectionParamNames, "kingbase") - preferred := []string{"host", "port", "user", "password", "dbname", "sslmode", "connect_timeout"} + preferred := []string{"host", "port", "user", "password", "dbname", "sslmode", "sslrootcert", "sslcert", "sslkey", "connect_timeout"} seen := make(map[string]struct{}, len(params)) parts := make([]string, 0, len(params)) for _, key := range preferred { diff --git a/internal/db/mariadb_impl.go b/internal/db/mariadb_impl.go index 8ab24bb..719caee 100644 --- a/internal/db/mariadb_impl.go +++ b/internal/db/mariadb_impl.go @@ -36,7 +36,7 @@ func (m *MariaDB) getDSN(config connection.ConnectionConfig) (string, error) { protocol = netName } - return buildMySQLCompatibleDSN(config, protocol, address, database), nil + return buildMySQLCompatibleDSN(config, protocol, address, database) } func (m *MariaDB) Connect(config connection.ConnectionConfig) error { diff --git a/internal/db/mongodb_impl.go b/internal/db/mongodb_impl.go index b336c01..82daf6c 100644 --- a/internal/db/mongodb_impl.go +++ b/internal/db/mongodb_impl.go @@ -4,7 +4,6 @@ package db import ( "context" - "crypto/tls" "fmt" "net" "net/url" @@ -414,12 +413,15 @@ func (m *MongoDB) Connect(config connection.ConnectionConfig) error { } uri := m.getURI(attemptConfig) clientOpts := options.Client().ApplyURI(uri) - tlsEnabled, tlsInsecure := resolveMongoTLSSettings(attemptConfig) - if tlsEnabled { - clientOpts.SetTLSConfig(&tls.Config{ - MinVersion: tls.VersionTLS12, - InsecureSkipVerify: tlsInsecure, - }) + tlsConfig, tlsErr := resolveGenericTLSConfig(attemptConfig) + if tlsErr != nil { + detail := fmt.Sprintf("%s %sTLS 配置失败: %v", sslLabel, authLabel, tlsErr) + errorDetails = append(errorDetails, detail) + logger.Warnf("MongoDB TLS 配置失败:%d/%d 模式=%s 凭据=%s 错误=%v", attemptNo, totalAttempts, sslLabel, authLabel, tlsErr) + continue + } + if tlsConfig != nil { + clientOpts.SetTLSConfig(tlsConfig) } if attemptConfig.UseProxy { clientOpts.SetDialer(&mongoProxyDialer{proxyConfig: attemptConfig.Proxy}) diff --git a/internal/db/mongodb_impl_v1.go b/internal/db/mongodb_impl_v1.go index d6472d8..2b610b0 100644 --- a/internal/db/mongodb_impl_v1.go +++ b/internal/db/mongodb_impl_v1.go @@ -4,7 +4,6 @@ package db import ( "context" - "crypto/tls" "fmt" "net" "net/url" @@ -415,12 +414,15 @@ func (m *MongoDBV1) Connect(config connection.ConnectionConfig) error { } uri := m.getURI(attemptConfig) clientOpts := options.Client().ApplyURI(uri) - tlsEnabled, tlsInsecure := resolveMongoTLSSettings(attemptConfig) - if tlsEnabled { - clientOpts.SetTLSConfig(&tls.Config{ - MinVersion: tls.VersionTLS12, - InsecureSkipVerify: tlsInsecure, - }) + tlsConfig, tlsErr := resolveGenericTLSConfig(attemptConfig) + if tlsErr != nil { + detail := fmt.Sprintf("%s %sTLS 配置失败: %v", sslLabel, authLabel, tlsErr) + errorDetails = append(errorDetails, detail) + logger.Warnf("MongoDB TLS 配置失败:%d/%d 模式=%s 凭据=%s 错误=%v", attemptNo, totalAttempts, sslLabel, authLabel, tlsErr) + continue + } + if tlsConfig != nil { + clientOpts.SetTLSConfig(tlsConfig) } if attemptConfig.UseProxy { clientOpts.SetDialer(&mongoProxyDialer{proxyConfig: attemptConfig.Proxy}) diff --git a/internal/db/mysql_impl.go b/internal/db/mysql_impl.go index e2dd5f4..895b728 100644 --- a/internal/db/mysql_impl.go +++ b/internal/db/mysql_impl.go @@ -17,7 +17,7 @@ import ( "GoNavi-Wails/internal/ssh" "GoNavi-Wails/internal/utils" - _ "github.com/go-sql-driver/mysql" + mysql "github.com/go-sql-driver/mysql" ) type MySQLDB struct { @@ -115,19 +115,19 @@ var mysqlSupportedDriverParamNames = map[string]string{ // OceanBase Oracle 租户 MySQL wire 路径用它注入 OBClient 私有 capability attribute; // 普通 mysql/mariadb 用户也能在此声明 program_name 等元数据。 "connectionattributes": "connectionAttributes", - "interpolateparams": "interpolateParams", - "loc": "loc", - "maxallowedpacket": "maxAllowedPacket", - "multistatements": "multiStatements", - "parsetime": "parseTime", - "readtimeout": "readTimeout", - "rejectreadonly": "rejectReadOnly", - "serverpubkey": "serverPubKey", - "sql_mode": "sql_mode", - "timetruncate": "timeTruncate", - "timeout": "timeout", - "tls": "tls", - "writetimeout": "writeTimeout", + "interpolateparams": "interpolateParams", + "loc": "loc", + "maxallowedpacket": "maxAllowedPacket", + "multistatements": "multiStatements", + "parsetime": "parseTime", + "readtimeout": "readTimeout", + "rejectreadonly": "rejectReadOnly", + "serverpubkey": "serverPubKey", + "sql_mode": "sql_mode", + "timetruncate": "timeTruncate", + "timeout": "timeout", + "tls": "tls", + "writetimeout": "writeTimeout", } var mysqlBoolDriverParamNames = map[string]struct{}{ @@ -274,15 +274,40 @@ func mergeMySQLConnectionParams(params url.Values, values url.Values) { } } -func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, address, database string) string { +func resolveMySQLTLSParam(config connection.ConnectionConfig) (string, bool, error) { + mode := resolveMySQLTLSMode(config) + if mode == "false" || !hasTLSCertificatePaths(config) { + return mode, false, nil + } + tlsConfig, err := resolveGenericTLSConfig(config) + if err != nil { + return "", false, err + } + if tlsConfig == nil { + return mode, false, nil + } + name := mysqlTLSConfigName(config) + if err := mysql.RegisterTLSConfig(name, tlsConfig); err != nil && !strings.Contains(strings.ToLower(err.Error()), "already registered") { + return "", false, fmt.Errorf("注册 MySQL TLS 证书配置失败:%w", err) + } + return name, normalizeSSLModeValue(config.SSLMode) == sslModePreferred, nil +} + +func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, address, database string) (string, error) { timeout := getConnectTimeoutSeconds(config) - tlsMode := resolveMySQLTLSMode(config) + tlsMode, allowFallbackToPlaintext, err := resolveMySQLTLSParam(config) + if err != nil { + return "", err + } params := url.Values{} params.Set("charset", "utf8mb4") params.Set("parseTime", "True") params.Set("loc", "Local") params.Set("timeout", fmt.Sprintf("%ds", timeout)) params.Set("tls", tlsMode) + if allowFallbackToPlaintext { + params.Set("allowFallbackToPlaintext", "true") + } params.Set("multiStatements", "true") if parsed, ok := parseMySQLCompatibleURI(config.URI, "mysql", "doris", "diros", "oceanbase"); ok { mergeMySQLConnectionParams(params, parsed.Query()) @@ -291,7 +316,7 @@ func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, addre return fmt.Sprintf( "%s:%s@%s(%s)/%s?%s", config.User, config.Password, protocol, address, database, params.Encode(), - ) + ), nil } func parseHostPortWithDefault(raw string, defaultPort int) (string, int, bool) { @@ -502,7 +527,7 @@ func (m *MySQLDB) getDSN(config connection.ConnectionConfig) (string, error) { protocol = netName } - return buildMySQLCompatibleDSN(config, protocol, address, database), nil + return buildMySQLCompatibleDSN(config, protocol, address, database) } func resolveMySQLCredential(config connection.ConnectionConfig, addressIndex int) (string, string) { diff --git a/internal/db/oceanbase_impl.go b/internal/db/oceanbase_impl.go index 927d2ea..68f44fd 100644 --- a/internal/db/oceanbase_impl.go +++ b/internal/db/oceanbase_impl.go @@ -185,7 +185,7 @@ func (o *OceanBaseDB) getDSN(config connection.ConnectionConfig) (string, error) protocol = netName } - return buildMySQLCompatibleDSN(config, protocol, address, database), nil + return buildMySQLCompatibleDSN(config, protocol, address, database) } func normalizeOceanBaseProtocol(raw string) string { diff --git a/internal/db/postgres_impl.go b/internal/db/postgres_impl.go index fcbc7c4..d746f01 100644 --- a/internal/db/postgres_impl.go +++ b/internal/db/postgres_impl.go @@ -63,6 +63,7 @@ func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string { u.User = url.UserPassword(config.User, config.Password) q := url.Values{} q.Set("sslmode", resolvePostgresSSLMode(config)) + applyPostgresSSLPathParams(q, config) q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config))) mergeConnectionParamsFromConfigWithAllowlist(q, config, postgresConnectionParamNames, "postgres", "postgresql", "opengauss") u.RawQuery = q.Encode() diff --git a/internal/db/sqlserver_impl.go b/internal/db/sqlserver_impl.go index 57f5cdb..2d45384 100644 --- a/internal/db/sqlserver_impl.go +++ b/internal/db/sqlserver_impl.go @@ -50,6 +50,9 @@ func (s *SqlServerDB) getDSN(config connection.ConnectionConfig) string { encrypt, trustServerCertificate := resolveSQLServerTLSSettings(config) q.Set("encrypt", encrypt) q.Set("trustservercertificate", trustServerCertificate) + if strings.TrimSpace(config.SSLCAPath) != "" { + q.Set("certificate", strings.TrimSpace(config.SSLCAPath)) + } mergeConnectionParamsFromConfigWithAllowlist(q, config, sqlServerConnectionParamNames, "sqlserver") u.RawQuery = q.Encode() diff --git a/internal/db/ssl_mode.go b/internal/db/ssl_mode.go index 050db53..1686340 100644 --- a/internal/db/ssl_mode.go +++ b/internal/db/ssl_mode.go @@ -1,10 +1,13 @@ package db import ( + "crypto/sha256" "crypto/tls" + "encoding/hex" "strings" "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/tlsconfig" ) const ( @@ -61,15 +64,37 @@ func resolveMySQLTLSMode(config connection.ConnectionConfig) string { } } +func hasTLSCertificatePaths(config connection.ConnectionConfig) bool { + return strings.TrimSpace(config.SSLCAPath) != "" || + strings.TrimSpace(config.SSLCertPath) != "" || + strings.TrimSpace(config.SSLKeyPath) != "" +} + +func mysqlTLSConfigName(config connection.ConnectionConfig) string { + sum := sha256.Sum256([]byte(strings.Join([]string{ + normalizedSSLMode(config), + strings.TrimSpace(config.SSLCAPath), + strings.TrimSpace(config.SSLCertPath), + strings.TrimSpace(config.SSLKeyPath), + }, "\x00"))) + return "gonavi-" + hex.EncodeToString(sum[:8]) +} + func resolvePostgresSSLMode(config connection.ConnectionConfig) string { switch normalizedSSLMode(config) { case sslModeDisable: return "disable" case sslModeRequired: + if strings.TrimSpace(config.SSLCAPath) != "" { + return "verify-ca" + } return "require" case sslModeSkipVerify: return "require" default: + if strings.TrimSpace(config.SSLCAPath) != "" { + return "verify-ca" + } return "require" } } @@ -87,17 +112,47 @@ func resolveSQLServerTLSSettings(config connection.ConnectionConfig) (encrypt st } } -func resolveGenericTLSConfig(config connection.ConnectionConfig) *tls.Config { +func applyPostgresSSLPathParams(params interface{ Set(string, string) }, config connection.ConnectionConfig) { + mode := normalizedSSLMode(config) + if mode != sslModeDisable && mode != sslModeSkipVerify && strings.TrimSpace(config.SSLCAPath) != "" { + params.Set("sslrootcert", strings.TrimSpace(config.SSLCAPath)) + } + if mode != sslModeDisable && strings.TrimSpace(config.SSLCertPath) != "" { + params.Set("sslcert", strings.TrimSpace(config.SSLCertPath)) + } + if mode != sslModeDisable && strings.TrimSpace(config.SSLKeyPath) != "" { + params.Set("sslkey", strings.TrimSpace(config.SSLKeyPath)) + } +} + +func resolveGenericTLSConfig(config connection.ConnectionConfig) (*tls.Config, error) { switch normalizedSSLMode(config) { case sslModeDisable: - return nil + return nil, nil case sslModeRequired: - return &tls.Config{MinVersion: tls.VersionTLS12} + return tlsconfig.BuildClientConfig(tlsconfig.ClientConfigOptions{ + Enabled: true, + CAPath: config.SSLCAPath, + CertPath: config.SSLCertPath, + KeyPath: config.SSLKeyPath, + }) case sslModeSkipVerify: - return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true} + return tlsconfig.BuildClientConfig(tlsconfig.ClientConfigOptions{ + Enabled: true, + InsecureSkipVerify: true, + CAPath: config.SSLCAPath, + CertPath: config.SSLCertPath, + KeyPath: config.SSLKeyPath, + }) default: // Preferred: 先尝试 TLS(为提升兼容性默认跳过证书校验),失败时由调用方按需回退明文。 - return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true} + return tlsconfig.BuildClientConfig(tlsconfig.ClientConfigOptions{ + Enabled: true, + InsecureSkipVerify: true, + CAPath: config.SSLCAPath, + CertPath: config.SSLCertPath, + KeyPath: config.SSLKeyPath, + }) } } diff --git a/internal/db/starrocks_impl.go b/internal/db/starrocks_impl.go index f25167a..af2df6d 100644 --- a/internal/db/starrocks_impl.go +++ b/internal/db/starrocks_impl.go @@ -214,7 +214,7 @@ func (s *StarRocksDB) getDSN(config connection.ConnectionConfig) (string, error) protocol = netName } - return buildMySQLCompatibleDSN(config, protocol, address, database), nil + return buildMySQLCompatibleDSN(config, protocol, address, database) } func resolveStarRocksCredential(config connection.ConnectionConfig, addressIndex int) (string, string) { diff --git a/internal/db/vastbase_impl.go b/internal/db/vastbase_impl.go index bd2b501..ef0c497 100644 --- a/internal/db/vastbase_impl.go +++ b/internal/db/vastbase_impl.go @@ -42,6 +42,7 @@ func (v *VastbaseDB) getDSN(config connection.ConnectionConfig) string { u.User = url.UserPassword(config.User, config.Password) q := url.Values{} q.Set("sslmode", resolvePostgresSSLMode(config)) + applyPostgresSSLPathParams(q, config) q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config))) mergeConnectionParamsFromConfigWithAllowlist(q, config, postgresConnectionParamNames, "postgres", "postgresql", "vastbase") u.RawQuery = q.Encode() diff --git a/internal/redis/redis_impl.go b/internal/redis/redis_impl.go index 677bbf6..e45cac9 100644 --- a/internal/redis/redis_impl.go +++ b/internal/redis/redis_impl.go @@ -277,7 +277,10 @@ func (r *RedisClientImpl) Connect(config connection.ConnectionConfig) error { var failures []string for idx, attempt := range attempts { var tlsConfig *tls.Config - if cfg := resolveRedisTLSConfig(attempt); cfg != nil { + if cfg, err := resolveRedisTLSConfig(attempt); err != nil { + failures = append(failures, fmt.Sprintf("第%d次 TLS 配置失败: %v", idx+1, err)) + continue + } else if cfg != nil { if host, _, err := net.SplitHostPort(seedAddrs[0]); err == nil && host != "" { cfg.ServerName = host } @@ -332,7 +335,10 @@ func (r *RedisClientImpl) Connect(config connection.ConnectionConfig) error { var failures []string for idx, attempt := range attempts { var tlsConfig *tls.Config - if cfg := resolveRedisTLSConfig(attempt); cfg != nil { + if cfg, err := resolveRedisTLSConfig(attempt); err != nil { + failures = append(failures, fmt.Sprintf("第%d次 TLS 配置失败: %v", idx+1, err)) + continue + } else if cfg != nil { if host, _, err := net.SplitHostPort(addr); err == nil && host != "" { cfg.ServerName = host } diff --git a/internal/redis/ssl_mode.go b/internal/redis/ssl_mode.go index 11f3e9f..1e78034 100644 --- a/internal/redis/ssl_mode.go +++ b/internal/redis/ssl_mode.go @@ -5,6 +5,7 @@ import ( "strings" "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/tlsconfig" ) func normalizeRedisSSLMode(raw string) string { @@ -41,15 +42,32 @@ func withRedisSSLDisabled(config connection.ConnectionConfig) connection.Connect return next } -func resolveRedisTLSConfig(config connection.ConnectionConfig) *tls.Config { +func resolveRedisTLSConfig(config connection.ConnectionConfig) (*tls.Config, error) { switch redisSSLMode(config) { case "disable": - return nil + return nil, nil case "required": - return &tls.Config{MinVersion: tls.VersionTLS12} + return tlsconfig.BuildClientConfig(tlsconfig.ClientConfigOptions{ + Enabled: true, + CAPath: config.SSLCAPath, + CertPath: config.SSLCertPath, + KeyPath: config.SSLKeyPath, + }) case "skip-verify": - return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true} + return tlsconfig.BuildClientConfig(tlsconfig.ClientConfigOptions{ + Enabled: true, + InsecureSkipVerify: true, + CAPath: config.SSLCAPath, + CertPath: config.SSLCertPath, + KeyPath: config.SSLKeyPath, + }) default: - return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true} + return tlsconfig.BuildClientConfig(tlsconfig.ClientConfigOptions{ + Enabled: true, + InsecureSkipVerify: true, + CAPath: config.SSLCAPath, + CertPath: config.SSLCertPath, + KeyPath: config.SSLKeyPath, + }) } } diff --git a/internal/tlsconfig/tlsconfig.go b/internal/tlsconfig/tlsconfig.go new file mode 100644 index 0000000..9644e76 --- /dev/null +++ b/internal/tlsconfig/tlsconfig.go @@ -0,0 +1,62 @@ +package tlsconfig + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "os" + "strings" +) + +type ClientConfigOptions struct { + Enabled bool + InsecureSkipVerify bool + CAPath string + CertPath string + KeyPath string +} + +func BuildClientConfig(options ClientConfigOptions) (*tls.Config, error) { + if !options.Enabled { + return nil, nil + } + + cfg := &tls.Config{ + MinVersion: tls.VersionTLS12, + InsecureSkipVerify: options.InsecureSkipVerify, + } + + caPath := strings.TrimSpace(options.CAPath) + if caPath != "" { + pemBytes, err := os.ReadFile(caPath) + if err != nil { + return nil, fmt.Errorf("读取 TLS CA 证书失败(%s):%w", caPath, err) + } + pool := x509.NewCertPool() + if ok := pool.AppendCertsFromPEM(pemBytes); !ok { + certs, err := x509.ParseCertificates(pemBytes) + if err != nil || len(certs) == 0 { + return nil, fmt.Errorf("TLS CA 证书不是有效的 PEM/DER 文件:%s", caPath) + } + for _, cert := range certs { + pool.AddCert(cert) + } + } + cfg.RootCAs = pool + } + + certPath := strings.TrimSpace(options.CertPath) + keyPath := strings.TrimSpace(options.KeyPath) + if (certPath == "") != (keyPath == "") { + return nil, fmt.Errorf("TLS 客户端证书和私钥需要同时配置") + } + if certPath != "" { + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, fmt.Errorf("加载 TLS 客户端证书失败(cert=%s key=%s):%w", certPath, keyPath, err) + } + cfg.Certificates = []tls.Certificate{cert} + } + + return cfg, nil +} diff --git a/internal/tlsconfig/tlsconfig_test.go b/internal/tlsconfig/tlsconfig_test.go new file mode 100644 index 0000000..c6348a4 --- /dev/null +++ b/internal/tlsconfig/tlsconfig_test.go @@ -0,0 +1,126 @@ +package tlsconfig + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "testing" + "time" +) + +func TestBuildClientConfigLoadsCAAndClientCertificate(t *testing.T) { + dir := t.TempDir() + certPath, keyPath, _ := writeSelfSignedCertificate(t, dir) + + cfg, err := BuildClientConfig(ClientConfigOptions{ + Enabled: true, + CAPath: certPath, + CertPath: certPath, + KeyPath: keyPath, + }) + if err != nil { + t.Fatalf("BuildClientConfig failed: %v", err) + } + if cfg == nil { + t.Fatal("config is nil") + } + if cfg.RootCAs == nil { + t.Fatal("RootCAs is nil") + } + if len(cfg.Certificates) != 1 { + t.Fatalf("Certificates length = %d, want 1", len(cfg.Certificates)) + } +} + +func TestBuildClientConfigLoadsDERCA(t *testing.T) { + dir := t.TempDir() + _, _, derBytes := writeSelfSignedCertificate(t, dir) + derPath := filepath.Join(dir, "ca.cer") + if err := os.WriteFile(derPath, derBytes, 0600); err != nil { + t.Fatalf("write der cert: %v", err) + } + + cfg, err := BuildClientConfig(ClientConfigOptions{ + Enabled: true, + CAPath: derPath, + }) + if err != nil { + t.Fatalf("BuildClientConfig failed: %v", err) + } + if cfg == nil { + t.Fatal("config is nil") + } + if cfg.RootCAs == nil { + t.Fatal("RootCAs is nil") + } +} + +func TestBuildClientConfigRequiresClientCertificateAndKeyTogether(t *testing.T) { + _, err := BuildClientConfig(ClientConfigOptions{ + Enabled: true, + CertPath: "client.pem", + }) + if err == nil { + t.Fatal("expected error") + } +} + +func writeSelfSignedCertificate(t *testing.T, dir string) (string, string, []byte) { + t.Helper() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate key: %v", err) + } + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "GoNavi Test", + }, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IsCA: true, + } + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + if err != nil { + t.Fatalf("create certificate: %v", err) + } + + certPath := filepath.Join(dir, "cert.pem") + certFile, err := os.Create(certPath) + if err != nil { + t.Fatalf("create cert file: %v", err) + } + if err := pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { + _ = certFile.Close() + t.Fatalf("write cert: %v", err) + } + if err := certFile.Close(); err != nil { + t.Fatalf("close cert file: %v", err) + } + + keyPath := filepath.Join(dir, "key.pem") + keyFile, err := os.Create(keyPath) + if err != nil { + t.Fatalf("create key file: %v", err) + } + keyBytes := x509.MarshalPKCS1PrivateKey(privateKey) + if err := pem.Encode(keyFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyBytes}); err != nil { + _ = keyFile.Close() + t.Fatalf("write key: %v", err) + } + if err := keyFile.Close(); err != nil { + t.Fatalf("close key file: %v", err) + } + + return certPath, keyPath, derBytes +}