feat(connection): 支持连接 SSL 证书文件配置

- 新增 CA 证书、客户端证书和私钥路径配置
- 为 MySQL、PostgreSQL、ClickHouse、MongoDB、Redis 等连接接入 TLS 证书
- 修正 SSL 模式下证书校验、明文回退和 DER 证书兼容问题
- 补充证书路径保存、RPC 传递和 DSN 生成回归测试
Refs #463
This commit is contained in:
Syngnat
2026-05-15 22:04:20 +08:00
parent acb119d80e
commit b707c74203
29 changed files with 965 additions and 115 deletions

View File

@@ -92,6 +92,7 @@ import {
TestConnection,
RedisConnect,
SelectDatabaseFile,
SelectCertificateFile,
SelectSSHKeyFile,
TestJVMConnection,
} from "../../wailsjs/go/app/App";
@@ -283,6 +284,69 @@ const supportsSSLForType = (type: string) =>
.toLowerCase(),
);
const sslCAPathSupportedTypes = new Set([
"mysql",
"mariadb",
"oceanbase",
"diros",
"starrocks",
"sphinx",
"clickhouse",
"postgres",
"sqlserver",
"kingbase",
"highgo",
"vastbase",
"opengauss",
"mongodb",
"redis",
]);
const sslClientCertificateSupportedTypes = new Set([
"mysql",
"mariadb",
"oceanbase",
"diros",
"starrocks",
"sphinx",
"dameng",
"clickhouse",
"postgres",
"kingbase",
"highgo",
"vastbase",
"opengauss",
"mongodb",
"redis",
]);
const supportsSSLCAPathForType = (type: string) =>
sslCAPathSupportedTypes.has(
String(type || "")
.trim()
.toLowerCase(),
);
const supportsSSLClientCertificateForType = (type: string) =>
sslClientCertificateSupportedTypes.has(
String(type || "")
.trim()
.toLowerCase(),
);
const isPostgresCompatibleSSLType = (type: string) =>
[
"postgres",
"kingbase",
"highgo",
"vastbase",
"opengauss",
].includes(
String(type || "")
.trim()
.toLowerCase(),
);
const isFileDatabaseType = (type: string) =>
type === "sqlite" || type === "duckdb";
@@ -394,6 +458,9 @@ const ConnectionModal: React.FC<{
const [driverStatusLoaded, setDriverStatusLoaded] = useState(false);
const [selectingDbFile, setSelectingDbFile] = useState(false);
const [selectingSSHKey, setSelectingSSHKey] = useState(false);
const [selectingCertificateField, setSelectingCertificateField] = useState<
"sslCAPath" | "sslCertPath" | "sslKeyPath" | null
>(null);
const [clearSecrets, setClearSecrets] = useState<ConnectionSecretClearState>(
createEmptyConnectionSecretClearState,
);
@@ -445,17 +512,24 @@ const ConnectionModal: React.FC<{
const isMySQLLike = isMySQLCompatibleType(dbType) && !isOceanBaseOracle;
const supportsConnectionParams = supportsConnectionParamsForType(dbType);
const isSSLType = supportsSSLForType(dbType);
const supportsSSLCAPath = supportsSSLCAPathForType(dbType);
const supportsSSLClientCertificate =
supportsSSLClientCertificateForType(dbType);
const sslHintText = isMySQLLike
? "MySQL/MariaDB/Doris/Sphinx 开启安全传输策略时,请启用 SSL;本地自签证书场景可先用 Preferred 或 Skip Verify。"
? "MySQL 兼容数据源支持 CA 证书、客户端证书与私钥;本地自签证书场景可先用 Preferred 或 Skip Verify。"
: isOceanBaseOracle
? "OceanBase Oracle 租户使用 Oracle 协议连接SSL 参数按 Oracle 驱动规则传递。"
? "OceanBase Oracle 租户使用 Oracle 协议连接;如需 Wallet请在高级参数中配置 Oracle 驱动参数。"
: dbType === "dameng"
? "达梦驱动启用 SSL 需要客户端证书与私钥路径sslCertPath / sslKeyPath。"
: dbType === "sqlserver"
? "SQL Server 推荐在生产环境使用 Required并关闭 TrustServerCertificate。"
? "SQL Server 可配置服务端证书/CA 文件;生产环境建议使用 Required并关闭 TrustServerCertificate。"
: dbType === "mongodb"
? "MongoDB 可通过 TLS 保护连接,证书校验异常时可先用 Skip Verify 验证连通性。"
: "建议优先使用 Required仅在测试环境或自签证书场景使用 Skip Verify。";
? "MongoDB 支持 CA 证书、客户端证书与私钥;证书校验异常时可先用 Skip Verify 验证连通性。"
: dbType === "oracle"
? "Oracle PEM 证书请优先使用 Wallet 并在高级参数中配置 WALLET这里仅控制 SSL 开关与校验策略。"
: dbType === "tdengine"
? "TDengine 当前仅配置 WSS 与校验策略;证书文件请通过服务端信任链处理。"
: "支持的驱动可配置 CA 证书、客户端证书与私钥;仅在测试环境或自签证书场景使用 Skip Verify。";
const getSectionBg = (darkHex: string) => {
if (!darkMode) {
@@ -1339,10 +1413,102 @@ const ConnectionModal: React.FC<{
clickHouseProtocol: "http",
useSSL: isHttps,
sslMode: isHttps ? (skipVerify ? "skip-verify" : "required") : "disable",
...extractSSLPathValuesFromParams(parsed.params, "clickhouse"),
connectionParams: serializeConnectionParams(parsed.params),
};
};
const firstConnectionParamValue = (
params: URLSearchParams,
names: string[],
): string => {
for (const name of names) {
const value = String(params.get(name) || "").trim();
if (value) return value;
}
return "";
};
const extractSSLPathValuesFromParams = (
params: URLSearchParams,
type: string,
): Record<string, string> => {
const caPath = firstConnectionParamValue(params, [
"sslCAPath",
"ssl_ca_path",
"sslrootcert",
"sslRootCert",
"tlsCAFile",
"caFile",
"certificate",
"servercertificate",
"serverCertificate",
]);
const certPath = firstConnectionParamValue(params, [
"sslCertPath",
"ssl_cert_path",
"SSL_CERT_PATH",
"sslcert",
"sslCert",
"tlsCertificateFile",
]);
const keyPath = firstConnectionParamValue(params, [
"sslKeyPath",
"ssl_key_path",
"SSL_KEY_PATH",
"sslkey",
"sslKey",
"tlsKeyFile",
]);
return {
...(supportsSSLCAPathForType(type) && caPath ? { sslCAPath: caPath } : {}),
...(supportsSSLClientCertificateForType(type) && certPath ? { sslCertPath: certPath } : {}),
...(supportsSSLClientCertificateForType(type) && keyPath ? { sslKeyPath: keyPath } : {}),
};
};
const appendSSLPathParamsForUri = (
params: URLSearchParams,
type: string,
values: Record<string, any>,
) => {
const caPath = String(values.sslCAPath || "").trim();
const certPath = String(values.sslCertPath || "").trim();
const keyPath = String(values.sslKeyPath || "").trim();
const mode = String(values.sslMode || "preferred")
.trim()
.toLowerCase();
if (supportsSSLCAPathForType(type) && caPath) {
if (isPostgresCompatibleSSLType(type)) {
if (mode !== "skip-verify" && mode !== "disable") {
params.set("sslrootcert", caPath);
}
} else if (type === "sqlserver") {
params.set("certificate", caPath);
} else {
params.set("sslCAPath", caPath);
}
}
if (supportsSSLClientCertificateForType(type) && certPath) {
if (type === "dameng") {
params.set("SSL_CERT_PATH", certPath);
} else if (isPostgresCompatibleSSLType(type)) {
params.set("sslcert", certPath);
} else {
params.set("sslCertPath", certPath);
}
}
if (supportsSSLClientCertificateForType(type) && keyPath) {
if (type === "dameng") {
params.set("SSL_KEY_PATH", keyPath);
} else if (isPostgresCompatibleSSLType(type)) {
params.set("sslkey", keyPath);
} else {
params.set("sslKeyPath", keyPath);
}
}
};
const parseUriToValues = (
uriText: string,
type: string,
@@ -1419,6 +1585,7 @@ const ConnectionModal: React.FC<{
database: parsed.database || "",
useSSL: sslMode !== "disable",
sslMode,
...extractSSLPathValuesFromParams(parsed.params, type),
oceanBaseProtocol: parsedOceanBaseProtocol,
mysqlTopology:
parsedOceanBaseProtocol === "oracle"
@@ -1491,6 +1658,7 @@ const ConnectionModal: React.FC<{
? "skip-verify"
: "required"
: "disable",
...extractSSLPathValuesFromParams(parsed.params, type),
redisTopology:
hostList.length > 1 || topologyParam === "cluster"
? "cluster"
@@ -1564,6 +1732,7 @@ const ConnectionModal: React.FC<{
? "skip-verify"
: "required"
: "disable",
...extractSSLPathValuesFromParams(parsed.params, type),
mongoTopology:
hostList.length > 1 || !!parsed.params.get("replicaSet")
? "replica"
@@ -1616,6 +1785,7 @@ const ConnectionModal: React.FC<{
}
if (supportsSSLForType(type)) {
Object.assign(parsedValues, extractSSLPathValuesFromParams(parsed.params, type));
const normalizeBool = (raw: unknown) => {
const text = String(raw ?? "")
.trim()
@@ -1891,6 +2061,7 @@ const ConnectionModal: React.FC<{
params.set("tls", "preferred");
}
}
appendSSLPathParamsForUri(params, type, values);
if (Number.isFinite(timeout) && timeout > 0) {
params.set("timeout", String(timeout));
}
@@ -1939,6 +2110,7 @@ const ConnectionModal: React.FC<{
params.set("skip_verify", "true");
}
}
appendSSLPathParamsForUri(params, type, values);
const query = params.toString();
const scheme = values.useSSL ? "rediss" : "redis";
return `${scheme}://${redisAuth}${hosts.join(",")}${dbPath}${query ? `?${query}` : ""}`;
@@ -1997,6 +2169,7 @@ const ConnectionModal: React.FC<{
params.delete("tlsInsecure");
}
}
appendSSLPathParamsForUri(params, type, values);
if (Number.isFinite(timeout) && timeout > 0) {
params.set("connectTimeoutMS", String(timeout * 1000));
params.set("serverSelectionTimeoutMS", String(timeout * 1000));
@@ -2025,20 +2198,23 @@ const ConnectionModal: React.FC<{
const mode = String(values.sslMode || "preferred")
.trim()
.toLowerCase();
if (
type === "postgres" ||
type === "kingbase" ||
type === "highgo" ||
type === "vastbase" ||
type === "opengauss"
) {
params.set("sslmode", "require");
if (isPostgresCompatibleSSLType(type)) {
params.set(
"sslmode",
mode === "skip-verify"
? "require"
: String(values.sslCAPath || "").trim()
? "verify-ca"
: "require",
);
appendSSLPathParamsForUri(params, type, values);
} else if (type === "sqlserver") {
params.set("encrypt", "true");
params.set(
"TrustServerCertificate",
mode === "skip-verify" || mode === "preferred" ? "true" : "false",
);
appendSSLPathParamsForUri(params, type, values);
} else if (type === "clickhouse") {
if (clickHouseProtocol === "http") {
if (mode === "skip-verify" || mode === "preferred") {
@@ -2050,11 +2226,9 @@ const ConnectionModal: React.FC<{
params.set("skip_verify", "true");
}
}
appendSSLPathParamsForUri(params, type, values);
} else if (type === "dameng") {
const certPath = String(values.sslCertPath || "").trim();
const keyPath = String(values.sslKeyPath || "").trim();
if (certPath) params.set("SSL_CERT_PATH", certPath);
if (keyPath) params.set("SSL_KEY_PATH", keyPath);
appendSSLPathParamsForUri(params, type, values);
} else if (type === "oracle") {
params.set("SSL", "TRUE");
params.set("SSL VERIFY", mode === "required" ? "TRUE" : "FALSE");
@@ -2065,13 +2239,7 @@ const ConnectionModal: React.FC<{
}
}
} else if (supportsSSLForType(type)) {
if (
type === "postgres" ||
type === "kingbase" ||
type === "highgo" ||
type === "vastbase" ||
type === "opengauss"
) {
if (isPostgresCompatibleSSLType(type)) {
params.set("sslmode", "disable");
} else if (type === "sqlserver") {
params.set("encrypt", "disable");
@@ -2182,6 +2350,34 @@ const ConnectionModal: React.FC<{
}
};
const handleSelectCertificateFile = async (
fieldName: "sslCAPath" | "sslCertPath" | "sslKeyPath",
certKind: "ca" | "client-cert" | "client-key",
) => {
if (selectingCertificateField) {
return;
}
try {
setSelectingCertificateField(fieldName);
const currentPath = String(form.getFieldValue(fieldName) || "").trim();
const res = await SelectCertificateFile(currentPath, certKind);
if (res?.success) {
const data = res.data || {};
const selectedPath =
typeof data === "string" ? data : String(data.path || "").trim();
if (selectedPath) {
form.setFieldValue(fieldName, selectedPath);
}
} else if (res?.message !== "已取消") {
message.error(`选择证书文件失败: ${res?.message || "未知错误"}`);
}
} catch (e: any) {
message.error(`选择证书文件失败: ${e?.message || String(e)}`);
} finally {
setSelectingCertificateField(null);
}
};
const handleSelectDatabaseFile = async () => {
if (selectingDbFile) {
return;
@@ -2317,6 +2513,7 @@ const ConnectionModal: React.FC<{
includeRedisDatabases: initialValues.includeRedisDatabases,
useSSL: !!config.useSSL,
sslMode: config.sslMode || "preferred",
sslCAPath: config.sslCAPath || "",
sslCertPath: config.sslCertPath || "",
sslKeyPath: config.sslKeyPath || "",
useSSH: config.useSSH,
@@ -3166,6 +3363,9 @@ const ConnectionModal: React.FC<{
? "disable"
: "preferred";
const effectiveUseSSL = sslCapableType && !!mergedValues.useSSL;
const sslCAPath = sslCapableType
? String(mergedValues.sslCAPath || "").trim()
: "";
const sslCertPath = sslCapableType
? String(mergedValues.sslCertPath || "").trim()
: "";
@@ -3175,6 +3375,9 @@ const ConnectionModal: React.FC<{
if (type === "dameng" && effectiveUseSSL && (!sslCertPath || !sslKeyPath)) {
throw new Error("达梦启用 SSL 时必须填写证书路径与私钥路径");
}
if (effectiveUseSSL && supportsSSLClientCertificateForType(type) && (!!sslCertPath !== !!sslKeyPath)) {
throw new Error("TLS 客户端证书与私钥路径需要同时填写");
}
let primaryHost = "localhost";
let primaryPort = defaultPort;
@@ -3369,6 +3572,7 @@ const ConnectionModal: React.FC<{
database: mergedValues.database || "",
useSSL: effectiveUseSSL,
sslMode: effectiveUseSSL ? sslMode : "disable",
sslCAPath: sslCAPath,
sslCertPath: sslCertPath,
sslKeyPath: sslKeyPath,
useSSH: !!mergedValues.useSSH,
@@ -3438,6 +3642,7 @@ const ConnectionModal: React.FC<{
database: "",
useSSL: false,
sslMode: undefined,
sslCAPath: undefined,
sslCertPath: undefined,
sslKeyPath: undefined,
useSSH: false,
@@ -3501,6 +3706,7 @@ const ConnectionModal: React.FC<{
database: "",
useSSL: false,
sslMode: "preferred",
sslCAPath: "",
sslCertPath: "",
sslKeyPath: "",
useSSH: false,
@@ -3551,6 +3757,7 @@ const ConnectionModal: React.FC<{
port: defaultPort,
useSSL: sslCapableType ? false : undefined,
sslMode: sslCapableType ? "preferred" : undefined,
sslCAPath: sslCapableType ? "" : undefined,
sslCertPath: sslCapableType ? "" : undefined,
sslKeyPath: sslCapableType ? "" : undefined,
useHttpTunnel: false,
@@ -5556,41 +5763,84 @@ const ConnectionModal: React.FC<{
],
})}
</div>
{dbType === "dameng" && (
<>
<Form.Item
name="sslCertPath"
label="客户端证书路径 (SSL_CERT_PATH)"
rules={[
{
required: true,
message: "达梦 SSL 需要证书路径",
},
]}
style={{ marginBottom: 8 }}
>
<Input
{...noAutoCapInputProps}
placeholder="例如: C:\certs\client-cert.pem"
/>
</Form.Item>
<Form.Item
name="sslKeyPath"
label="客户端私钥路径 (SSL_KEY_PATH)"
rules={[
{
required: true,
message: "达梦 SSL 需要私钥路径",
},
]}
style={{ marginBottom: 8 }}
>
<Input
{...noAutoCapInputProps}
placeholder="例如: C:\certs\client-key.pem"
/>
</Form.Item>
</>
{(supportsSSLCAPath || supportsSSLClientCertificate) && (
<div style={{ display: "grid", gap: 8, marginBottom: 12 }}>
{supportsSSLCAPath && (
<Form.Item
label={dbType === "sqlserver" ? "服务端证书/CA 路径" : "CA 证书路径"}
style={{ marginBottom: 0 }}
>
<Space.Compact style={{ width: "100%" }}>
<Form.Item name="sslCAPath" noStyle>
<Input
{...noAutoCapInputProps}
placeholder="例如: C:\certs\ca.pem"
/>
</Form.Item>
<Button
onClick={() => handleSelectCertificateFile("sslCAPath", "ca")}
loading={selectingCertificateField === "sslCAPath"}
>
...
</Button>
</Space.Compact>
</Form.Item>
)}
{supportsSSLClientCertificate && (
<>
<Form.Item
label={dbType === "dameng" ? "客户端证书路径 (SSL_CERT_PATH)" : "客户端证书路径"}
rules={[
{
required: dbType === "dameng",
message: "达梦 SSL 需要证书路径",
},
]}
style={{ marginBottom: 0 }}
>
<Space.Compact style={{ width: "100%" }}>
<Form.Item name="sslCertPath" noStyle>
<Input
{...noAutoCapInputProps}
placeholder="例如: C:\certs\client-cert.pem"
/>
</Form.Item>
<Button
onClick={() => handleSelectCertificateFile("sslCertPath", "client-cert")}
loading={selectingCertificateField === "sslCertPath"}
>
...
</Button>
</Space.Compact>
</Form.Item>
<Form.Item
label={dbType === "dameng" ? "客户端私钥路径 (SSL_KEY_PATH)" : "客户端私钥路径"}
rules={[
{
required: dbType === "dameng",
message: "达梦 SSL 需要私钥路径",
},
]}
style={{ marginBottom: 0 }}
>
<Space.Compact style={{ width: "100%" }}>
<Form.Item name="sslKeyPath" noStyle>
<Input
{...noAutoCapInputProps}
placeholder="例如: C:\certs\client-key.pem"
/>
</Form.Item>
<Button
onClick={() => handleSelectCertificateFile("sslKeyPath", "client-key")}
loading={selectingCertificateField === "sslKeyPath"}
>
...
</Button>
</Space.Compact>
</Form.Item>
</>
)}
</div>
)}
<Text type="secondary" style={{ fontSize: 12 }}>
{sslHintText}
@@ -6216,6 +6466,7 @@ const ConnectionModal: React.FC<{
user: "root",
useSSL: false,
sslMode: "preferred",
sslCAPath: "",
sslCertPath: "",
sslKeyPath: "",
useSSH: false,

View File

@@ -262,6 +262,34 @@ describe('store appearance persistence', () => {
expect(config?.port).toBe(9030);
});
it('preserves SSL certificate paths for SSL-capable saved connections', async () => {
const { useStore } = await importStore();
useStore.getState().replaceConnections([
{
id: 'postgres-ssl',
name: 'Postgres SSL',
config: {
id: 'postgres-ssl',
type: 'postgres',
host: 'db.local',
port: 5432,
user: 'postgres',
useSSL: true,
sslMode: 'required',
sslCAPath: 'C:/certs/ca.pem',
sslCertPath: 'C:/certs/client-cert.pem',
sslKeyPath: 'C:/certs/client-key.pem',
},
},
]);
const config = useStore.getState().connections[0]?.config;
expect(config?.sslCAPath).toBe('C:/certs/ca.pem');
expect(config?.sslCertPath).toBe('C:/certs/client-cert.pem');
expect(config?.sslKeyPath).toBe('C:/certs/client-key.pem');
});
it('normalizes OceanBase protocol override when replacing saved connections', async () => {
const { useStore } = await importStore();

View File

@@ -529,6 +529,7 @@ const sanitizeConnectionConfig = (value: unknown): ConnectionConfig => {
database: toTrimmedString(raw.database),
useSSL: sslCapable ? !!raw.useSSL : false,
sslMode: sslCapable ? sslMode : "disable",
sslCAPath: sslCapable ? toTrimmedString(raw.sslCAPath) : "",
sslCertPath: sslCapable ? toTrimmedString(raw.sslCertPath) : "",
sslKeyPath: sslCapable ? toTrimmedString(raw.sslKeyPath) : "",
useSSH: !!raw.useSSH,

View File

@@ -284,6 +284,7 @@ export interface ConnectionConfig {
database?: string;
useSSL?: boolean;
sslMode?: "preferred" | "required" | "skip-verify" | "disable";
sslCAPath?: string;
sslCertPath?: string;
sslKeyPath?: string;
useSSH?: boolean;

View File

@@ -148,6 +148,27 @@ describe('buildRpcConnectionConfig', () => {
expect(result.connectionParams).toBe('characterEncoding=utf8&useSSL=false');
});
it('preserves SSL certificate path fields for RPC calls', () => {
const result = buildRpcConnectionConfig({
id: 'conn-postgres-ssl',
type: 'postgres',
host: 'db.local',
port: 5432,
user: 'postgres',
useSSL: true,
sslMode: 'required',
sslCAPath: 'C:/certs/ca.pem',
sslCertPath: 'C:/certs/client-cert.pem',
sslKeyPath: 'C:/certs/client-key.pem',
} as any);
expect(result.useSSL).toBe(true);
expect(result.sslMode).toBe('required');
expect(result.sslCAPath).toBe('C:/certs/ca.pem');
expect(result.sslCertPath).toBe('C:/certs/client-cert.pem');
expect(result.sslKeyPath).toBe('C:/certs/client-key.pem');
});
it('fills default nested config blocks needed by RPC calls', () => {
const result = buildRpcConnectionConfig({
id: 'conn-redis',

View File

@@ -242,6 +242,8 @@ export function RenameTable(arg1:connection.ConnectionConfig,arg2:string,arg3:st
export function RenameView(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
export function ResetWebViewZoom():Promise<connection.QueryResult>;
export function ResolveDriverDownloadDirectory(arg1:string):Promise<connection.QueryResult>;
export function ResolveDriverPackageDownloadURL(arg1:string,arg2:string):Promise<connection.QueryResult>;
@@ -256,6 +258,8 @@ export function SaveConnection(arg1:connection.SavedConnectionInput):Promise<con
export function SaveGlobalProxy(arg1:connection.SaveGlobalProxyInput):Promise<connection.GlobalProxyView>;
export function SelectCertificateFile(arg1:string,arg2:string):Promise<connection.QueryResult>;
export function SelectDataRootDirectory(arg1:string):Promise<connection.QueryResult>;
export function SelectDatabaseFile(arg1:string,arg2:string):Promise<connection.QueryResult>;

View File

@@ -474,6 +474,10 @@ export function RenameView(arg1, arg2, arg3, arg4) {
return window['go']['app']['App']['RenameView'](arg1, arg2, arg3, arg4);
}
export function ResetWebViewZoom() {
return window['go']['app']['App']['ResetWebViewZoom']();
}
export function ResolveDriverDownloadDirectory(arg1) {
return window['go']['app']['App']['ResolveDriverDownloadDirectory'](arg1);
}
@@ -502,6 +506,10 @@ export function SaveGlobalProxy(arg1) {
return window['go']['app']['App']['SaveGlobalProxy'](arg1);
}
export function SelectCertificateFile(arg1, arg2) {
return window['go']['app']['App']['SelectCertificateFile'](arg1, arg2);
}
export function SelectDataRootDirectory(arg1) {
return window['go']['app']['App']['SelectDataRootDirectory'](arg1);
}

View File

@@ -673,6 +673,7 @@ export namespace connection {
database: string;
useSSL?: boolean;
sslMode?: string;
sslCAPath?: string;
sslCertPath?: string;
sslKeyPath?: string;
useSSH: boolean;
@@ -718,6 +719,7 @@ export namespace connection {
this.database = source["database"];
this.useSSL = source["useSSL"];
this.sslMode = source["sslMode"];
this.sslCAPath = source["sslCAPath"];
this.sslCertPath = source["sslCertPath"];
this.sslKeyPath = source["sslKeyPath"];
this.useSSH = source["useSSH"];

View File

@@ -553,6 +553,61 @@ func (a *App) SelectSSHKeyFile(currentPath string) connection.QueryResult {
return connection.QueryResult{Success: true, Data: map[string]interface{}{"path": selection}}
}
func (a *App) SelectCertificateFile(currentPath string, certKind string) connection.QueryResult {
defaultDir := strings.TrimSpace(currentPath)
if defaultDir == "" {
if home, err := os.UserHomeDir(); err == nil {
defaultDir = home
}
}
if filepath.Ext(defaultDir) != "" {
defaultDir = filepath.Dir(defaultDir)
}
if defaultDir != "" && !filepath.IsAbs(defaultDir) {
if abs, err := filepath.Abs(defaultDir); err == nil {
defaultDir = abs
}
}
kind := strings.ToLower(strings.TrimSpace(certKind))
title := "选择 TLS 证书文件"
displayName := "证书文件"
switch kind {
case "ca":
title = "选择 CA/服务端证书文件"
case "client-cert":
title = "选择客户端证书文件"
case "client-key":
title = "选择客户端私钥文件"
displayName = "私钥文件"
}
selection, err := runtime.OpenFileDialog(a.ctx, runtime.OpenDialogOptions{
Title: title,
DefaultDirectory: defaultDir,
Filters: []runtime.FileFilter{
{
DisplayName: displayName,
Pattern: "*.pem;*.crt;*.cer;*.cert;*.key",
},
{
DisplayName: "所有文件",
Pattern: "*",
},
},
})
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
if strings.TrimSpace(selection) == "" {
return connection.QueryResult{Success: false, Message: "已取消"}
}
if abs, err := filepath.Abs(selection); err == nil {
selection = abs
}
return connection.QueryResult{Success: true, Data: map[string]interface{}{"path": selection}}
}
func (a *App) SelectDatabaseFile(currentPath string, driverType string) connection.QueryResult {
defaultDir := strings.TrimSpace(currentPath)
if defaultDir == "" {

View File

@@ -89,6 +89,7 @@ type ConnectionConfig struct {
Database string `json:"database"`
UseSSL bool `json:"useSSL,omitempty"` // MySQL-like SSL/TLS switch
SSLMode string `json:"sslMode,omitempty"` // preferred | required | skip-verify | disable
SSLCAPath string `json:"sslCAPath,omitempty"` // TLS root CA / server certificate path
SSLCertPath string `json:"sslCertPath,omitempty"` // TLS client certificate path (e.g., Dameng)
SSLKeyPath string `json:"sslKeyPath,omitempty"` // TLS client private key path (e.g., Dameng)
UseSSH bool `json:"useSSH"`

View File

@@ -167,7 +167,7 @@ func defaultClickHousePortForScheme(scheme string) int {
}
}
func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig) *clickhouse.Options {
func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig) (*clickhouse.Options, error) {
connectTimeout := getConnectTimeout(config)
readTimeout := connectTimeout
if readTimeout < minClickHouseReadTimeout {
@@ -187,11 +187,15 @@ func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig
DialTimeout: connectTimeout,
ReadTimeout: readTimeout,
}
if tlsConfig := resolveGenericTLSConfig(config); tlsConfig != nil {
tlsConfig, err := resolveGenericTLSConfig(config)
if err != nil {
return nil, err
}
if tlsConfig != nil {
opts.TLS = tlsConfig
}
applyClickHouseConnectionParams(opts, config)
return opts
return opts, nil
}
func parseClickHouseDurationParam(raw string) (time.Duration, bool) {
@@ -549,7 +553,14 @@ func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error {
protocolConfig := withClickHouseProtocol(attempt, protocol)
logger.Infof("ClickHouse 连接尝试:第%d组/%d 协议=%s 地址=%s:%d SSL=%t",
idx+1, len(attempts), clickHouseProtocolName(protocol), protocolConfig.Host, protocolConfig.Port, protocolConfig.UseSSL)
c.conn = clickhouse.OpenDB(c.buildClickHouseOptions(protocolConfig))
opts, err := c.buildClickHouseOptions(protocolConfig)
if err != nil {
failures = append(failures, fmt.Sprintf("第%d次 TLS 配置失败(protocol=%s): %v", idx+1, protocol.String(), err))
logger.Warnf("ClickHouse TLS 配置失败:第%d组/%d 协议=%s 地址=%s:%d SSL=%t 原因=%v",
idx+1, len(attempts), clickHouseProtocolName(protocol), protocolConfig.Host, protocolConfig.Port, protocolConfig.UseSSL, err)
continue
}
c.conn = clickhouse.OpenDB(opts)
if err := c.Ping(); err != nil {
failureMessage := clickHouseAttemptFailureMessage(protocol, err)
failures = append(failures, fmt.Sprintf("第%d次连接验证失败(protocol=%s): %s", idx+1, protocol.String(), failureMessage))

View File

@@ -139,7 +139,7 @@ func (d *DirosDB) getDSN(config connection.ConnectionConfig) (string, error) {
protocol = netName
}
return buildMySQLCompatibleDSN(config, protocol, address, database), nil
return buildMySQLCompatibleDSN(config, protocol, address, database)
}
func resolveDirosCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {

View File

@@ -55,6 +55,78 @@ func TestPostgresDSN_SSLModeRequireWhenEnabled(t *testing.T) {
}
}
func TestPostgresDSN_AppendsSSLPathParams(t *testing.T) {
p := &PostgresDB{}
cfg := connection.ConnectionConfig{
Type: "postgres",
Host: "127.0.0.1",
Port: 5432,
User: "user",
Password: "pass",
Database: "db",
UseSSL: true,
SSLMode: "required",
SSLCAPath: "C:\\certs\\ca.pem",
SSLCertPath: "C:\\certs\\client-cert.pem",
SSLKeyPath: "C:\\certs\\client-key.pem",
}
dsn := p.getDSN(cfg)
parsed, err := url.Parse(dsn)
if err != nil {
t.Fatalf("parse postgres dsn: %v", err)
}
query := parsed.Query()
if got := query.Get("sslmode"); got != "verify-ca" {
t.Fatalf("sslmode = %q, want verify-ca", got)
}
if got := query.Get("sslrootcert"); got != cfg.SSLCAPath {
t.Fatalf("sslrootcert = %q, want %q", got, cfg.SSLCAPath)
}
if got := query.Get("sslcert"); got != cfg.SSLCertPath {
t.Fatalf("sslcert = %q, want %q", got, cfg.SSLCertPath)
}
if got := query.Get("sslkey"); got != cfg.SSLKeyPath {
t.Fatalf("sslkey = %q, want %q", got, cfg.SSLKeyPath)
}
}
func TestPostgresDSN_SkipVerifyOmitsSSLRootCert(t *testing.T) {
p := &PostgresDB{}
cfg := connection.ConnectionConfig{
Type: "postgres",
Host: "127.0.0.1",
Port: 5432,
User: "user",
Password: "pass",
Database: "db",
UseSSL: true,
SSLMode: "skip-verify",
SSLCAPath: "C:\\certs\\ca.pem",
SSLCertPath: "C:\\certs\\client-cert.pem",
SSLKeyPath: "C:\\certs\\client-key.pem",
}
dsn := p.getDSN(cfg)
parsed, err := url.Parse(dsn)
if err != nil {
t.Fatalf("parse postgres dsn: %v", err)
}
query := parsed.Query()
if got := query.Get("sslmode"); got != "require" {
t.Fatalf("sslmode = %q, want require", got)
}
if got := query.Get("sslrootcert"); got != "" {
t.Fatalf("sslrootcert should be omitted for skip-verify, got %q", got)
}
if got := query.Get("sslcert"); got != cfg.SSLCertPath {
t.Fatalf("sslcert = %q, want %q", got, cfg.SSLCertPath)
}
if got := query.Get("sslkey"); got != cfg.SSLKeyPath {
t.Fatalf("sslkey = %q, want %q", got, cfg.SSLKeyPath)
}
}
func TestPostgresDSN_MergesConnectionParams(t *testing.T) {
p := &PostgresDB{}
cfg := connection.ConnectionConfig{
@@ -109,6 +181,65 @@ func TestMySQLDSN_UsesTLSParamWhenSSLEnabled(t *testing.T) {
}
}
func TestMySQLDSN_UsesCustomTLSConfigWhenCertificatePathsAreConfigured(t *testing.T) {
m := &MySQLDB{}
cfg := connection.ConnectionConfig{
Type: "mysql",
Host: "127.0.0.1",
Port: 3306,
User: "root",
Password: "pass",
Database: "db",
UseSSL: true,
SSLMode: "required",
SSLCAPath: "../../third_party/highgo-pq/certs/root.crt",
SSLCertPath: "../../third_party/highgo-pq/certs/postgresql.crt",
SSLKeyPath: "../../third_party/highgo-pq/certs/postgresql.key",
}
dsn, err := m.getDSN(cfg)
if err != nil {
t.Fatalf("getDSN failed: %v", err)
}
if strings.Contains(dsn, "tls=true") {
t.Fatalf("dsn 应使用自定义 TLS 配置名而不是 tls=true%s", dsn)
}
if !strings.Contains(dsn, "tls=gonavi-") {
t.Fatalf("dsn 缺少自定义 TLS 配置名:%s", dsn)
}
if strings.Contains(dsn, "allowFallbackToPlaintext=true") {
t.Fatalf("required 模式不应启用明文回退:%s", dsn)
}
}
func TestMySQLDSN_PreservesPreferredFallbackWithCustomTLSConfig(t *testing.T) {
m := &MySQLDB{}
cfg := connection.ConnectionConfig{
Type: "mysql",
Host: "127.0.0.1",
Port: 3306,
User: "root",
Password: "pass",
Database: "db",
UseSSL: true,
SSLMode: "preferred",
SSLCAPath: "../../third_party/highgo-pq/certs/root.crt",
SSLCertPath: "../../third_party/highgo-pq/certs/postgresql.crt",
SSLKeyPath: "../../third_party/highgo-pq/certs/postgresql.key",
}
dsn, err := m.getDSN(cfg)
if err != nil {
t.Fatalf("getDSN failed: %v", err)
}
if !strings.Contains(dsn, "tls=gonavi-") {
t.Fatalf("dsn 缺少自定义 TLS 配置名:%s", dsn)
}
if !strings.Contains(dsn, "allowFallbackToPlaintext=true") {
t.Fatalf("preferred 自定义 TLS 配置应保留明文回退:%s", dsn)
}
}
func TestOracleDSN_EscapesUserAndPassword(t *testing.T) {
o := &OracleDB{}
cfg := connection.ConnectionConfig{
@@ -380,6 +511,30 @@ func TestSQLServerDSN_MergesConnectionParams(t *testing.T) {
}
}
func TestSQLServerDSN_AppendsCertificateParam(t *testing.T) {
s := &SqlServerDB{}
cfg := connection.ConnectionConfig{
Type: "sqlserver",
Host: "127.0.0.1",
Port: 1433,
User: "sa",
Password: "pass",
Database: "master",
UseSSL: true,
SSLMode: "required",
SSLCAPath: "C:\\certs\\sqlserver-ca.pem",
}
dsn := s.getDSN(cfg)
parsed, err := url.Parse(dsn)
if err != nil {
t.Fatalf("parse sqlserver dsn: %v", err)
}
if got := parsed.Query().Get("certificate"); got != cfg.SSLCAPath {
t.Fatalf("certificate = %q, want %q", got, cfg.SSLCAPath)
}
}
func TestClickHouseOptions_UsesStructuredTimeoutAndAuth(t *testing.T) {
c := &ClickHouseDB{}
cfg := normalizeClickHouseConfig(connection.ConnectionConfig{
@@ -392,7 +547,10 @@ func TestClickHouseOptions_UsesStructuredTimeoutAndAuth(t *testing.T) {
Timeout: 15,
})
opts := c.buildClickHouseOptions(cfg)
opts, err := c.buildClickHouseOptions(cfg)
if err != nil {
t.Fatalf("buildClickHouseOptions failed: %v", err)
}
if opts == nil {
t.Fatal("options 为空")
}
@@ -438,7 +596,10 @@ func TestClickHouseOptions_MergesConnectionParamsIntoOptionsAndSettings(t *testi
ConnectionParams: "max_execution_time=60&compress=lz4&read_timeout=10s",
})
opts := c.buildClickHouseOptions(cfg)
opts, err := c.buildClickHouseOptions(cfg)
if err != nil {
t.Fatalf("buildClickHouseOptions failed: %v", err)
}
if opts == nil {
t.Fatal("options 为空")
}
@@ -465,7 +626,10 @@ func TestClickHouseOptions_ReadTimeoutUsesLargerConfiguredTimeout(t *testing.T)
Timeout: 900,
})
opts := c.buildClickHouseOptions(cfg)
opts, err := c.buildClickHouseOptions(cfg)
if err != nil {
t.Fatalf("buildClickHouseOptions failed: %v", err)
}
if opts == nil {
t.Fatal("options 为空")
}

View File

@@ -43,6 +43,7 @@ func (h *HighGoDB) getDSN(config connection.ConnectionConfig) string {
u.User = url.UserPassword(config.User, config.Password)
q := url.Values{}
q.Set("sslmode", resolvePostgresSSLMode(config))
applyPostgresSSLPathParams(q, config)
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
mergeConnectionParamsFromConfigWithAllowlist(q, config, highGoConnectionParamNames, "postgres", "postgresql", "highgo")
u.RawQuery = q.Encode()

View File

@@ -70,10 +70,11 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
params.Set("password", config.Password)
params.Set("dbname", config.Database)
params.Set("sslmode", resolvePostgresSSLMode(config))
applyPostgresSSLPathParams(params, config)
params.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
mergeConnectionParamsFromConfigWithAllowlist(params, config, kingbaseConnectionParamNames, "kingbase")
preferred := []string{"host", "port", "user", "password", "dbname", "sslmode", "connect_timeout"}
preferred := []string{"host", "port", "user", "password", "dbname", "sslmode", "sslrootcert", "sslcert", "sslkey", "connect_timeout"}
seen := make(map[string]struct{}, len(params))
parts := make([]string, 0, len(params))
for _, key := range preferred {

View File

@@ -36,7 +36,7 @@ func (m *MariaDB) getDSN(config connection.ConnectionConfig) (string, error) {
protocol = netName
}
return buildMySQLCompatibleDSN(config, protocol, address, database), nil
return buildMySQLCompatibleDSN(config, protocol, address, database)
}
func (m *MariaDB) Connect(config connection.ConnectionConfig) error {

View File

@@ -4,7 +4,6 @@ package db
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/url"
@@ -414,12 +413,15 @@ func (m *MongoDB) Connect(config connection.ConnectionConfig) error {
}
uri := m.getURI(attemptConfig)
clientOpts := options.Client().ApplyURI(uri)
tlsEnabled, tlsInsecure := resolveMongoTLSSettings(attemptConfig)
if tlsEnabled {
clientOpts.SetTLSConfig(&tls.Config{
MinVersion: tls.VersionTLS12,
InsecureSkipVerify: tlsInsecure,
})
tlsConfig, tlsErr := resolveGenericTLSConfig(attemptConfig)
if tlsErr != nil {
detail := fmt.Sprintf("%s %sTLS 配置失败: %v", sslLabel, authLabel, tlsErr)
errorDetails = append(errorDetails, detail)
logger.Warnf("MongoDB TLS 配置失败:%d/%d 模式=%s 凭据=%s 错误=%v", attemptNo, totalAttempts, sslLabel, authLabel, tlsErr)
continue
}
if tlsConfig != nil {
clientOpts.SetTLSConfig(tlsConfig)
}
if attemptConfig.UseProxy {
clientOpts.SetDialer(&mongoProxyDialer{proxyConfig: attemptConfig.Proxy})

View File

@@ -4,7 +4,6 @@ package db
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/url"
@@ -415,12 +414,15 @@ func (m *MongoDBV1) Connect(config connection.ConnectionConfig) error {
}
uri := m.getURI(attemptConfig)
clientOpts := options.Client().ApplyURI(uri)
tlsEnabled, tlsInsecure := resolveMongoTLSSettings(attemptConfig)
if tlsEnabled {
clientOpts.SetTLSConfig(&tls.Config{
MinVersion: tls.VersionTLS12,
InsecureSkipVerify: tlsInsecure,
})
tlsConfig, tlsErr := resolveGenericTLSConfig(attemptConfig)
if tlsErr != nil {
detail := fmt.Sprintf("%s %sTLS 配置失败: %v", sslLabel, authLabel, tlsErr)
errorDetails = append(errorDetails, detail)
logger.Warnf("MongoDB TLS 配置失败:%d/%d 模式=%s 凭据=%s 错误=%v", attemptNo, totalAttempts, sslLabel, authLabel, tlsErr)
continue
}
if tlsConfig != nil {
clientOpts.SetTLSConfig(tlsConfig)
}
if attemptConfig.UseProxy {
clientOpts.SetDialer(&mongoProxyDialer{proxyConfig: attemptConfig.Proxy})

View File

@@ -17,7 +17,7 @@ import (
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
_ "github.com/go-sql-driver/mysql"
mysql "github.com/go-sql-driver/mysql"
)
type MySQLDB struct {
@@ -115,19 +115,19 @@ var mysqlSupportedDriverParamNames = map[string]string{
// OceanBase Oracle 租户 MySQL wire 路径用它注入 OBClient 私有 capability attribute
// 普通 mysql/mariadb 用户也能在此声明 program_name 等元数据。
"connectionattributes": "connectionAttributes",
"interpolateparams": "interpolateParams",
"loc": "loc",
"maxallowedpacket": "maxAllowedPacket",
"multistatements": "multiStatements",
"parsetime": "parseTime",
"readtimeout": "readTimeout",
"rejectreadonly": "rejectReadOnly",
"serverpubkey": "serverPubKey",
"sql_mode": "sql_mode",
"timetruncate": "timeTruncate",
"timeout": "timeout",
"tls": "tls",
"writetimeout": "writeTimeout",
"interpolateparams": "interpolateParams",
"loc": "loc",
"maxallowedpacket": "maxAllowedPacket",
"multistatements": "multiStatements",
"parsetime": "parseTime",
"readtimeout": "readTimeout",
"rejectreadonly": "rejectReadOnly",
"serverpubkey": "serverPubKey",
"sql_mode": "sql_mode",
"timetruncate": "timeTruncate",
"timeout": "timeout",
"tls": "tls",
"writetimeout": "writeTimeout",
}
var mysqlBoolDriverParamNames = map[string]struct{}{
@@ -274,15 +274,40 @@ func mergeMySQLConnectionParams(params url.Values, values url.Values) {
}
}
func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, address, database string) string {
func resolveMySQLTLSParam(config connection.ConnectionConfig) (string, bool, error) {
mode := resolveMySQLTLSMode(config)
if mode == "false" || !hasTLSCertificatePaths(config) {
return mode, false, nil
}
tlsConfig, err := resolveGenericTLSConfig(config)
if err != nil {
return "", false, err
}
if tlsConfig == nil {
return mode, false, nil
}
name := mysqlTLSConfigName(config)
if err := mysql.RegisterTLSConfig(name, tlsConfig); err != nil && !strings.Contains(strings.ToLower(err.Error()), "already registered") {
return "", false, fmt.Errorf("注册 MySQL TLS 证书配置失败:%w", err)
}
return name, normalizeSSLModeValue(config.SSLMode) == sslModePreferred, nil
}
func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, address, database string) (string, error) {
timeout := getConnectTimeoutSeconds(config)
tlsMode := resolveMySQLTLSMode(config)
tlsMode, allowFallbackToPlaintext, err := resolveMySQLTLSParam(config)
if err != nil {
return "", err
}
params := url.Values{}
params.Set("charset", "utf8mb4")
params.Set("parseTime", "True")
params.Set("loc", "Local")
params.Set("timeout", fmt.Sprintf("%ds", timeout))
params.Set("tls", tlsMode)
if allowFallbackToPlaintext {
params.Set("allowFallbackToPlaintext", "true")
}
params.Set("multiStatements", "true")
if parsed, ok := parseMySQLCompatibleURI(config.URI, "mysql", "doris", "diros", "oceanbase"); ok {
mergeMySQLConnectionParams(params, parsed.Query())
@@ -291,7 +316,7 @@ func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, addre
return fmt.Sprintf(
"%s:%s@%s(%s)/%s?%s",
config.User, config.Password, protocol, address, database, params.Encode(),
)
), nil
}
func parseHostPortWithDefault(raw string, defaultPort int) (string, int, bool) {
@@ -502,7 +527,7 @@ func (m *MySQLDB) getDSN(config connection.ConnectionConfig) (string, error) {
protocol = netName
}
return buildMySQLCompatibleDSN(config, protocol, address, database), nil
return buildMySQLCompatibleDSN(config, protocol, address, database)
}
func resolveMySQLCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {

View File

@@ -185,7 +185,7 @@ func (o *OceanBaseDB) getDSN(config connection.ConnectionConfig) (string, error)
protocol = netName
}
return buildMySQLCompatibleDSN(config, protocol, address, database), nil
return buildMySQLCompatibleDSN(config, protocol, address, database)
}
func normalizeOceanBaseProtocol(raw string) string {

View File

@@ -63,6 +63,7 @@ func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
u.User = url.UserPassword(config.User, config.Password)
q := url.Values{}
q.Set("sslmode", resolvePostgresSSLMode(config))
applyPostgresSSLPathParams(q, config)
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
mergeConnectionParamsFromConfigWithAllowlist(q, config, postgresConnectionParamNames, "postgres", "postgresql", "opengauss")
u.RawQuery = q.Encode()

View File

@@ -50,6 +50,9 @@ func (s *SqlServerDB) getDSN(config connection.ConnectionConfig) string {
encrypt, trustServerCertificate := resolveSQLServerTLSSettings(config)
q.Set("encrypt", encrypt)
q.Set("trustservercertificate", trustServerCertificate)
if strings.TrimSpace(config.SSLCAPath) != "" {
q.Set("certificate", strings.TrimSpace(config.SSLCAPath))
}
mergeConnectionParamsFromConfigWithAllowlist(q, config, sqlServerConnectionParamNames, "sqlserver")
u.RawQuery = q.Encode()

View File

@@ -1,10 +1,13 @@
package db
import (
"crypto/sha256"
"crypto/tls"
"encoding/hex"
"strings"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/tlsconfig"
)
const (
@@ -61,15 +64,37 @@ func resolveMySQLTLSMode(config connection.ConnectionConfig) string {
}
}
func hasTLSCertificatePaths(config connection.ConnectionConfig) bool {
return strings.TrimSpace(config.SSLCAPath) != "" ||
strings.TrimSpace(config.SSLCertPath) != "" ||
strings.TrimSpace(config.SSLKeyPath) != ""
}
func mysqlTLSConfigName(config connection.ConnectionConfig) string {
sum := sha256.Sum256([]byte(strings.Join([]string{
normalizedSSLMode(config),
strings.TrimSpace(config.SSLCAPath),
strings.TrimSpace(config.SSLCertPath),
strings.TrimSpace(config.SSLKeyPath),
}, "\x00")))
return "gonavi-" + hex.EncodeToString(sum[:8])
}
func resolvePostgresSSLMode(config connection.ConnectionConfig) string {
switch normalizedSSLMode(config) {
case sslModeDisable:
return "disable"
case sslModeRequired:
if strings.TrimSpace(config.SSLCAPath) != "" {
return "verify-ca"
}
return "require"
case sslModeSkipVerify:
return "require"
default:
if strings.TrimSpace(config.SSLCAPath) != "" {
return "verify-ca"
}
return "require"
}
}
@@ -87,17 +112,47 @@ func resolveSQLServerTLSSettings(config connection.ConnectionConfig) (encrypt st
}
}
func resolveGenericTLSConfig(config connection.ConnectionConfig) *tls.Config {
func applyPostgresSSLPathParams(params interface{ Set(string, string) }, config connection.ConnectionConfig) {
mode := normalizedSSLMode(config)
if mode != sslModeDisable && mode != sslModeSkipVerify && strings.TrimSpace(config.SSLCAPath) != "" {
params.Set("sslrootcert", strings.TrimSpace(config.SSLCAPath))
}
if mode != sslModeDisable && strings.TrimSpace(config.SSLCertPath) != "" {
params.Set("sslcert", strings.TrimSpace(config.SSLCertPath))
}
if mode != sslModeDisable && strings.TrimSpace(config.SSLKeyPath) != "" {
params.Set("sslkey", strings.TrimSpace(config.SSLKeyPath))
}
}
func resolveGenericTLSConfig(config connection.ConnectionConfig) (*tls.Config, error) {
switch normalizedSSLMode(config) {
case sslModeDisable:
return nil
return nil, nil
case sslModeRequired:
return &tls.Config{MinVersion: tls.VersionTLS12}
return tlsconfig.BuildClientConfig(tlsconfig.ClientConfigOptions{
Enabled: true,
CAPath: config.SSLCAPath,
CertPath: config.SSLCertPath,
KeyPath: config.SSLKeyPath,
})
case sslModeSkipVerify:
return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true}
return tlsconfig.BuildClientConfig(tlsconfig.ClientConfigOptions{
Enabled: true,
InsecureSkipVerify: true,
CAPath: config.SSLCAPath,
CertPath: config.SSLCertPath,
KeyPath: config.SSLKeyPath,
})
default:
// Preferred: 先尝试 TLS为提升兼容性默认跳过证书校验失败时由调用方按需回退明文。
return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true}
return tlsconfig.BuildClientConfig(tlsconfig.ClientConfigOptions{
Enabled: true,
InsecureSkipVerify: true,
CAPath: config.SSLCAPath,
CertPath: config.SSLCertPath,
KeyPath: config.SSLKeyPath,
})
}
}

View File

@@ -214,7 +214,7 @@ func (s *StarRocksDB) getDSN(config connection.ConnectionConfig) (string, error)
protocol = netName
}
return buildMySQLCompatibleDSN(config, protocol, address, database), nil
return buildMySQLCompatibleDSN(config, protocol, address, database)
}
func resolveStarRocksCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {

View File

@@ -42,6 +42,7 @@ func (v *VastbaseDB) getDSN(config connection.ConnectionConfig) string {
u.User = url.UserPassword(config.User, config.Password)
q := url.Values{}
q.Set("sslmode", resolvePostgresSSLMode(config))
applyPostgresSSLPathParams(q, config)
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
mergeConnectionParamsFromConfigWithAllowlist(q, config, postgresConnectionParamNames, "postgres", "postgresql", "vastbase")
u.RawQuery = q.Encode()

View File

@@ -277,7 +277,10 @@ func (r *RedisClientImpl) Connect(config connection.ConnectionConfig) error {
var failures []string
for idx, attempt := range attempts {
var tlsConfig *tls.Config
if cfg := resolveRedisTLSConfig(attempt); cfg != nil {
if cfg, err := resolveRedisTLSConfig(attempt); err != nil {
failures = append(failures, fmt.Sprintf("第%d次 TLS 配置失败: %v", idx+1, err))
continue
} else if cfg != nil {
if host, _, err := net.SplitHostPort(seedAddrs[0]); err == nil && host != "" {
cfg.ServerName = host
}
@@ -332,7 +335,10 @@ func (r *RedisClientImpl) Connect(config connection.ConnectionConfig) error {
var failures []string
for idx, attempt := range attempts {
var tlsConfig *tls.Config
if cfg := resolveRedisTLSConfig(attempt); cfg != nil {
if cfg, err := resolveRedisTLSConfig(attempt); err != nil {
failures = append(failures, fmt.Sprintf("第%d次 TLS 配置失败: %v", idx+1, err))
continue
} else if cfg != nil {
if host, _, err := net.SplitHostPort(addr); err == nil && host != "" {
cfg.ServerName = host
}

View File

@@ -5,6 +5,7 @@ import (
"strings"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/tlsconfig"
)
func normalizeRedisSSLMode(raw string) string {
@@ -41,15 +42,32 @@ func withRedisSSLDisabled(config connection.ConnectionConfig) connection.Connect
return next
}
func resolveRedisTLSConfig(config connection.ConnectionConfig) *tls.Config {
func resolveRedisTLSConfig(config connection.ConnectionConfig) (*tls.Config, error) {
switch redisSSLMode(config) {
case "disable":
return nil
return nil, nil
case "required":
return &tls.Config{MinVersion: tls.VersionTLS12}
return tlsconfig.BuildClientConfig(tlsconfig.ClientConfigOptions{
Enabled: true,
CAPath: config.SSLCAPath,
CertPath: config.SSLCertPath,
KeyPath: config.SSLKeyPath,
})
case "skip-verify":
return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true}
return tlsconfig.BuildClientConfig(tlsconfig.ClientConfigOptions{
Enabled: true,
InsecureSkipVerify: true,
CAPath: config.SSLCAPath,
CertPath: config.SSLCertPath,
KeyPath: config.SSLKeyPath,
})
default:
return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true}
return tlsconfig.BuildClientConfig(tlsconfig.ClientConfigOptions{
Enabled: true,
InsecureSkipVerify: true,
CAPath: config.SSLCAPath,
CertPath: config.SSLCertPath,
KeyPath: config.SSLKeyPath,
})
}
}

View File

@@ -0,0 +1,62 @@
package tlsconfig
import (
"crypto/tls"
"crypto/x509"
"fmt"
"os"
"strings"
)
type ClientConfigOptions struct {
Enabled bool
InsecureSkipVerify bool
CAPath string
CertPath string
KeyPath string
}
func BuildClientConfig(options ClientConfigOptions) (*tls.Config, error) {
if !options.Enabled {
return nil, nil
}
cfg := &tls.Config{
MinVersion: tls.VersionTLS12,
InsecureSkipVerify: options.InsecureSkipVerify,
}
caPath := strings.TrimSpace(options.CAPath)
if caPath != "" {
pemBytes, err := os.ReadFile(caPath)
if err != nil {
return nil, fmt.Errorf("读取 TLS CA 证书失败(%s%w", caPath, err)
}
pool := x509.NewCertPool()
if ok := pool.AppendCertsFromPEM(pemBytes); !ok {
certs, err := x509.ParseCertificates(pemBytes)
if err != nil || len(certs) == 0 {
return nil, fmt.Errorf("TLS CA 证书不是有效的 PEM/DER 文件:%s", caPath)
}
for _, cert := range certs {
pool.AddCert(cert)
}
}
cfg.RootCAs = pool
}
certPath := strings.TrimSpace(options.CertPath)
keyPath := strings.TrimSpace(options.KeyPath)
if (certPath == "") != (keyPath == "") {
return nil, fmt.Errorf("TLS 客户端证书和私钥需要同时配置")
}
if certPath != "" {
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, fmt.Errorf("加载 TLS 客户端证书失败cert=%s key=%s%w", certPath, keyPath, err)
}
cfg.Certificates = []tls.Certificate{cert}
}
return cfg, nil
}

View File

@@ -0,0 +1,126 @@
package tlsconfig
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"path/filepath"
"testing"
"time"
)
func TestBuildClientConfigLoadsCAAndClientCertificate(t *testing.T) {
dir := t.TempDir()
certPath, keyPath, _ := writeSelfSignedCertificate(t, dir)
cfg, err := BuildClientConfig(ClientConfigOptions{
Enabled: true,
CAPath: certPath,
CertPath: certPath,
KeyPath: keyPath,
})
if err != nil {
t.Fatalf("BuildClientConfig failed: %v", err)
}
if cfg == nil {
t.Fatal("config is nil")
}
if cfg.RootCAs == nil {
t.Fatal("RootCAs is nil")
}
if len(cfg.Certificates) != 1 {
t.Fatalf("Certificates length = %d, want 1", len(cfg.Certificates))
}
}
func TestBuildClientConfigLoadsDERCA(t *testing.T) {
dir := t.TempDir()
_, _, derBytes := writeSelfSignedCertificate(t, dir)
derPath := filepath.Join(dir, "ca.cer")
if err := os.WriteFile(derPath, derBytes, 0600); err != nil {
t.Fatalf("write der cert: %v", err)
}
cfg, err := BuildClientConfig(ClientConfigOptions{
Enabled: true,
CAPath: derPath,
})
if err != nil {
t.Fatalf("BuildClientConfig failed: %v", err)
}
if cfg == nil {
t.Fatal("config is nil")
}
if cfg.RootCAs == nil {
t.Fatal("RootCAs is nil")
}
}
func TestBuildClientConfigRequiresClientCertificateAndKeyTogether(t *testing.T) {
_, err := BuildClientConfig(ClientConfigOptions{
Enabled: true,
CertPath: "client.pem",
})
if err == nil {
t.Fatal("expected error")
}
}
func writeSelfSignedCertificate(t *testing.T, dir string) (string, string, []byte) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("generate key: %v", err)
}
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: "GoNavi Test",
},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IsCA: true,
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
if err != nil {
t.Fatalf("create certificate: %v", err)
}
certPath := filepath.Join(dir, "cert.pem")
certFile, err := os.Create(certPath)
if err != nil {
t.Fatalf("create cert file: %v", err)
}
if err := pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
_ = certFile.Close()
t.Fatalf("write cert: %v", err)
}
if err := certFile.Close(); err != nil {
t.Fatalf("close cert file: %v", err)
}
keyPath := filepath.Join(dir, "key.pem")
keyFile, err := os.Create(keyPath)
if err != nil {
t.Fatalf("create key file: %v", err)
}
keyBytes := x509.MarshalPKCS1PrivateKey(privateKey)
if err := pem.Encode(keyFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyBytes}); err != nil {
_ = keyFile.Close()
t.Fatalf("write key: %v", err)
}
if err := keyFile.Close(); err != nil {
t.Fatalf("close key file: %v", err)
}
return certPath, keyPath, derBytes
}