mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-22 08:50:17 +08:00
✨ feat(connection): 支持连接 SSL 证书文件配置
- 新增 CA 证书、客户端证书和私钥路径配置 - 为 MySQL、PostgreSQL、ClickHouse、MongoDB、Redis 等连接接入 TLS 证书 - 修正 SSL 模式下证书校验、明文回退和 DER 证书兼容问题 - 补充证书路径保存、RPC 传递和 DSN 生成回归测试 Refs #463
This commit is contained in:
@@ -92,6 +92,7 @@ import {
|
||||
TestConnection,
|
||||
RedisConnect,
|
||||
SelectDatabaseFile,
|
||||
SelectCertificateFile,
|
||||
SelectSSHKeyFile,
|
||||
TestJVMConnection,
|
||||
} from "../../wailsjs/go/app/App";
|
||||
@@ -283,6 +284,69 @@ const supportsSSLForType = (type: string) =>
|
||||
.toLowerCase(),
|
||||
);
|
||||
|
||||
const sslCAPathSupportedTypes = new Set([
|
||||
"mysql",
|
||||
"mariadb",
|
||||
"oceanbase",
|
||||
"diros",
|
||||
"starrocks",
|
||||
"sphinx",
|
||||
"clickhouse",
|
||||
"postgres",
|
||||
"sqlserver",
|
||||
"kingbase",
|
||||
"highgo",
|
||||
"vastbase",
|
||||
"opengauss",
|
||||
"mongodb",
|
||||
"redis",
|
||||
]);
|
||||
|
||||
const sslClientCertificateSupportedTypes = new Set([
|
||||
"mysql",
|
||||
"mariadb",
|
||||
"oceanbase",
|
||||
"diros",
|
||||
"starrocks",
|
||||
"sphinx",
|
||||
"dameng",
|
||||
"clickhouse",
|
||||
"postgres",
|
||||
"kingbase",
|
||||
"highgo",
|
||||
"vastbase",
|
||||
"opengauss",
|
||||
"mongodb",
|
||||
"redis",
|
||||
]);
|
||||
|
||||
const supportsSSLCAPathForType = (type: string) =>
|
||||
sslCAPathSupportedTypes.has(
|
||||
String(type || "")
|
||||
.trim()
|
||||
.toLowerCase(),
|
||||
);
|
||||
|
||||
const supportsSSLClientCertificateForType = (type: string) =>
|
||||
sslClientCertificateSupportedTypes.has(
|
||||
String(type || "")
|
||||
.trim()
|
||||
.toLowerCase(),
|
||||
);
|
||||
|
||||
const isPostgresCompatibleSSLType = (type: string) =>
|
||||
[
|
||||
"postgres",
|
||||
"kingbase",
|
||||
"highgo",
|
||||
"vastbase",
|
||||
"opengauss",
|
||||
].includes(
|
||||
String(type || "")
|
||||
.trim()
|
||||
.toLowerCase(),
|
||||
);
|
||||
|
||||
const isFileDatabaseType = (type: string) =>
|
||||
type === "sqlite" || type === "duckdb";
|
||||
|
||||
@@ -394,6 +458,9 @@ const ConnectionModal: React.FC<{
|
||||
const [driverStatusLoaded, setDriverStatusLoaded] = useState(false);
|
||||
const [selectingDbFile, setSelectingDbFile] = useState(false);
|
||||
const [selectingSSHKey, setSelectingSSHKey] = useState(false);
|
||||
const [selectingCertificateField, setSelectingCertificateField] = useState<
|
||||
"sslCAPath" | "sslCertPath" | "sslKeyPath" | null
|
||||
>(null);
|
||||
const [clearSecrets, setClearSecrets] = useState<ConnectionSecretClearState>(
|
||||
createEmptyConnectionSecretClearState,
|
||||
);
|
||||
@@ -445,17 +512,24 @@ const ConnectionModal: React.FC<{
|
||||
const isMySQLLike = isMySQLCompatibleType(dbType) && !isOceanBaseOracle;
|
||||
const supportsConnectionParams = supportsConnectionParamsForType(dbType);
|
||||
const isSSLType = supportsSSLForType(dbType);
|
||||
const supportsSSLCAPath = supportsSSLCAPathForType(dbType);
|
||||
const supportsSSLClientCertificate =
|
||||
supportsSSLClientCertificateForType(dbType);
|
||||
const sslHintText = isMySQLLike
|
||||
? "当 MySQL/MariaDB/Doris/Sphinx 开启安全传输策略时,请启用 SSL;本地自签证书场景可先用 Preferred 或 Skip Verify。"
|
||||
? "MySQL 兼容数据源支持 CA 证书、客户端证书与私钥;本地自签证书场景可先用 Preferred 或 Skip Verify。"
|
||||
: isOceanBaseOracle
|
||||
? "OceanBase Oracle 租户使用 Oracle 协议连接,SSL 参数按 Oracle 驱动规则传递。"
|
||||
? "OceanBase Oracle 租户使用 Oracle 协议连接;如需 Wallet,请在高级参数中配置 Oracle 驱动参数。"
|
||||
: dbType === "dameng"
|
||||
? "达梦驱动启用 SSL 需要客户端证书与私钥路径(sslCertPath / sslKeyPath)。"
|
||||
: dbType === "sqlserver"
|
||||
? "SQL Server 推荐在生产环境使用 Required,并关闭 TrustServerCertificate。"
|
||||
? "SQL Server 可配置服务端证书/CA 文件;生产环境建议使用 Required,并关闭 TrustServerCertificate。"
|
||||
: dbType === "mongodb"
|
||||
? "MongoDB 可通过 TLS 保护连接,证书校验异常时可先用 Skip Verify 验证连通性。"
|
||||
: "建议优先使用 Required;仅在测试环境或自签证书场景使用 Skip Verify。";
|
||||
? "MongoDB 支持 CA 证书、客户端证书与私钥;证书校验异常时可先用 Skip Verify 验证连通性。"
|
||||
: dbType === "oracle"
|
||||
? "Oracle PEM 证书请优先使用 Wallet 并在高级参数中配置 WALLET;这里仅控制 SSL 开关与校验策略。"
|
||||
: dbType === "tdengine"
|
||||
? "TDengine 当前仅配置 WSS 与校验策略;证书文件请通过服务端信任链处理。"
|
||||
: "支持的驱动可配置 CA 证书、客户端证书与私钥;仅在测试环境或自签证书场景使用 Skip Verify。";
|
||||
|
||||
const getSectionBg = (darkHex: string) => {
|
||||
if (!darkMode) {
|
||||
@@ -1339,10 +1413,102 @@ const ConnectionModal: React.FC<{
|
||||
clickHouseProtocol: "http",
|
||||
useSSL: isHttps,
|
||||
sslMode: isHttps ? (skipVerify ? "skip-verify" : "required") : "disable",
|
||||
...extractSSLPathValuesFromParams(parsed.params, "clickhouse"),
|
||||
connectionParams: serializeConnectionParams(parsed.params),
|
||||
};
|
||||
};
|
||||
|
||||
const firstConnectionParamValue = (
|
||||
params: URLSearchParams,
|
||||
names: string[],
|
||||
): string => {
|
||||
for (const name of names) {
|
||||
const value = String(params.get(name) || "").trim();
|
||||
if (value) return value;
|
||||
}
|
||||
return "";
|
||||
};
|
||||
|
||||
const extractSSLPathValuesFromParams = (
|
||||
params: URLSearchParams,
|
||||
type: string,
|
||||
): Record<string, string> => {
|
||||
const caPath = firstConnectionParamValue(params, [
|
||||
"sslCAPath",
|
||||
"ssl_ca_path",
|
||||
"sslrootcert",
|
||||
"sslRootCert",
|
||||
"tlsCAFile",
|
||||
"caFile",
|
||||
"certificate",
|
||||
"servercertificate",
|
||||
"serverCertificate",
|
||||
]);
|
||||
const certPath = firstConnectionParamValue(params, [
|
||||
"sslCertPath",
|
||||
"ssl_cert_path",
|
||||
"SSL_CERT_PATH",
|
||||
"sslcert",
|
||||
"sslCert",
|
||||
"tlsCertificateFile",
|
||||
]);
|
||||
const keyPath = firstConnectionParamValue(params, [
|
||||
"sslKeyPath",
|
||||
"ssl_key_path",
|
||||
"SSL_KEY_PATH",
|
||||
"sslkey",
|
||||
"sslKey",
|
||||
"tlsKeyFile",
|
||||
]);
|
||||
return {
|
||||
...(supportsSSLCAPathForType(type) && caPath ? { sslCAPath: caPath } : {}),
|
||||
...(supportsSSLClientCertificateForType(type) && certPath ? { sslCertPath: certPath } : {}),
|
||||
...(supportsSSLClientCertificateForType(type) && keyPath ? { sslKeyPath: keyPath } : {}),
|
||||
};
|
||||
};
|
||||
|
||||
const appendSSLPathParamsForUri = (
|
||||
params: URLSearchParams,
|
||||
type: string,
|
||||
values: Record<string, any>,
|
||||
) => {
|
||||
const caPath = String(values.sslCAPath || "").trim();
|
||||
const certPath = String(values.sslCertPath || "").trim();
|
||||
const keyPath = String(values.sslKeyPath || "").trim();
|
||||
const mode = String(values.sslMode || "preferred")
|
||||
.trim()
|
||||
.toLowerCase();
|
||||
if (supportsSSLCAPathForType(type) && caPath) {
|
||||
if (isPostgresCompatibleSSLType(type)) {
|
||||
if (mode !== "skip-verify" && mode !== "disable") {
|
||||
params.set("sslrootcert", caPath);
|
||||
}
|
||||
} else if (type === "sqlserver") {
|
||||
params.set("certificate", caPath);
|
||||
} else {
|
||||
params.set("sslCAPath", caPath);
|
||||
}
|
||||
}
|
||||
if (supportsSSLClientCertificateForType(type) && certPath) {
|
||||
if (type === "dameng") {
|
||||
params.set("SSL_CERT_PATH", certPath);
|
||||
} else if (isPostgresCompatibleSSLType(type)) {
|
||||
params.set("sslcert", certPath);
|
||||
} else {
|
||||
params.set("sslCertPath", certPath);
|
||||
}
|
||||
}
|
||||
if (supportsSSLClientCertificateForType(type) && keyPath) {
|
||||
if (type === "dameng") {
|
||||
params.set("SSL_KEY_PATH", keyPath);
|
||||
} else if (isPostgresCompatibleSSLType(type)) {
|
||||
params.set("sslkey", keyPath);
|
||||
} else {
|
||||
params.set("sslKeyPath", keyPath);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const parseUriToValues = (
|
||||
uriText: string,
|
||||
type: string,
|
||||
@@ -1419,6 +1585,7 @@ const ConnectionModal: React.FC<{
|
||||
database: parsed.database || "",
|
||||
useSSL: sslMode !== "disable",
|
||||
sslMode,
|
||||
...extractSSLPathValuesFromParams(parsed.params, type),
|
||||
oceanBaseProtocol: parsedOceanBaseProtocol,
|
||||
mysqlTopology:
|
||||
parsedOceanBaseProtocol === "oracle"
|
||||
@@ -1491,6 +1658,7 @@ const ConnectionModal: React.FC<{
|
||||
? "skip-verify"
|
||||
: "required"
|
||||
: "disable",
|
||||
...extractSSLPathValuesFromParams(parsed.params, type),
|
||||
redisTopology:
|
||||
hostList.length > 1 || topologyParam === "cluster"
|
||||
? "cluster"
|
||||
@@ -1564,6 +1732,7 @@ const ConnectionModal: React.FC<{
|
||||
? "skip-verify"
|
||||
: "required"
|
||||
: "disable",
|
||||
...extractSSLPathValuesFromParams(parsed.params, type),
|
||||
mongoTopology:
|
||||
hostList.length > 1 || !!parsed.params.get("replicaSet")
|
||||
? "replica"
|
||||
@@ -1616,6 +1785,7 @@ const ConnectionModal: React.FC<{
|
||||
}
|
||||
|
||||
if (supportsSSLForType(type)) {
|
||||
Object.assign(parsedValues, extractSSLPathValuesFromParams(parsed.params, type));
|
||||
const normalizeBool = (raw: unknown) => {
|
||||
const text = String(raw ?? "")
|
||||
.trim()
|
||||
@@ -1891,6 +2061,7 @@ const ConnectionModal: React.FC<{
|
||||
params.set("tls", "preferred");
|
||||
}
|
||||
}
|
||||
appendSSLPathParamsForUri(params, type, values);
|
||||
if (Number.isFinite(timeout) && timeout > 0) {
|
||||
params.set("timeout", String(timeout));
|
||||
}
|
||||
@@ -1939,6 +2110,7 @@ const ConnectionModal: React.FC<{
|
||||
params.set("skip_verify", "true");
|
||||
}
|
||||
}
|
||||
appendSSLPathParamsForUri(params, type, values);
|
||||
const query = params.toString();
|
||||
const scheme = values.useSSL ? "rediss" : "redis";
|
||||
return `${scheme}://${redisAuth}${hosts.join(",")}${dbPath}${query ? `?${query}` : ""}`;
|
||||
@@ -1997,6 +2169,7 @@ const ConnectionModal: React.FC<{
|
||||
params.delete("tlsInsecure");
|
||||
}
|
||||
}
|
||||
appendSSLPathParamsForUri(params, type, values);
|
||||
if (Number.isFinite(timeout) && timeout > 0) {
|
||||
params.set("connectTimeoutMS", String(timeout * 1000));
|
||||
params.set("serverSelectionTimeoutMS", String(timeout * 1000));
|
||||
@@ -2025,20 +2198,23 @@ const ConnectionModal: React.FC<{
|
||||
const mode = String(values.sslMode || "preferred")
|
||||
.trim()
|
||||
.toLowerCase();
|
||||
if (
|
||||
type === "postgres" ||
|
||||
type === "kingbase" ||
|
||||
type === "highgo" ||
|
||||
type === "vastbase" ||
|
||||
type === "opengauss"
|
||||
) {
|
||||
params.set("sslmode", "require");
|
||||
if (isPostgresCompatibleSSLType(type)) {
|
||||
params.set(
|
||||
"sslmode",
|
||||
mode === "skip-verify"
|
||||
? "require"
|
||||
: String(values.sslCAPath || "").trim()
|
||||
? "verify-ca"
|
||||
: "require",
|
||||
);
|
||||
appendSSLPathParamsForUri(params, type, values);
|
||||
} else if (type === "sqlserver") {
|
||||
params.set("encrypt", "true");
|
||||
params.set(
|
||||
"TrustServerCertificate",
|
||||
mode === "skip-verify" || mode === "preferred" ? "true" : "false",
|
||||
);
|
||||
appendSSLPathParamsForUri(params, type, values);
|
||||
} else if (type === "clickhouse") {
|
||||
if (clickHouseProtocol === "http") {
|
||||
if (mode === "skip-verify" || mode === "preferred") {
|
||||
@@ -2050,11 +2226,9 @@ const ConnectionModal: React.FC<{
|
||||
params.set("skip_verify", "true");
|
||||
}
|
||||
}
|
||||
appendSSLPathParamsForUri(params, type, values);
|
||||
} else if (type === "dameng") {
|
||||
const certPath = String(values.sslCertPath || "").trim();
|
||||
const keyPath = String(values.sslKeyPath || "").trim();
|
||||
if (certPath) params.set("SSL_CERT_PATH", certPath);
|
||||
if (keyPath) params.set("SSL_KEY_PATH", keyPath);
|
||||
appendSSLPathParamsForUri(params, type, values);
|
||||
} else if (type === "oracle") {
|
||||
params.set("SSL", "TRUE");
|
||||
params.set("SSL VERIFY", mode === "required" ? "TRUE" : "FALSE");
|
||||
@@ -2065,13 +2239,7 @@ const ConnectionModal: React.FC<{
|
||||
}
|
||||
}
|
||||
} else if (supportsSSLForType(type)) {
|
||||
if (
|
||||
type === "postgres" ||
|
||||
type === "kingbase" ||
|
||||
type === "highgo" ||
|
||||
type === "vastbase" ||
|
||||
type === "opengauss"
|
||||
) {
|
||||
if (isPostgresCompatibleSSLType(type)) {
|
||||
params.set("sslmode", "disable");
|
||||
} else if (type === "sqlserver") {
|
||||
params.set("encrypt", "disable");
|
||||
@@ -2182,6 +2350,34 @@ const ConnectionModal: React.FC<{
|
||||
}
|
||||
};
|
||||
|
||||
const handleSelectCertificateFile = async (
|
||||
fieldName: "sslCAPath" | "sslCertPath" | "sslKeyPath",
|
||||
certKind: "ca" | "client-cert" | "client-key",
|
||||
) => {
|
||||
if (selectingCertificateField) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
setSelectingCertificateField(fieldName);
|
||||
const currentPath = String(form.getFieldValue(fieldName) || "").trim();
|
||||
const res = await SelectCertificateFile(currentPath, certKind);
|
||||
if (res?.success) {
|
||||
const data = res.data || {};
|
||||
const selectedPath =
|
||||
typeof data === "string" ? data : String(data.path || "").trim();
|
||||
if (selectedPath) {
|
||||
form.setFieldValue(fieldName, selectedPath);
|
||||
}
|
||||
} else if (res?.message !== "已取消") {
|
||||
message.error(`选择证书文件失败: ${res?.message || "未知错误"}`);
|
||||
}
|
||||
} catch (e: any) {
|
||||
message.error(`选择证书文件失败: ${e?.message || String(e)}`);
|
||||
} finally {
|
||||
setSelectingCertificateField(null);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSelectDatabaseFile = async () => {
|
||||
if (selectingDbFile) {
|
||||
return;
|
||||
@@ -2317,6 +2513,7 @@ const ConnectionModal: React.FC<{
|
||||
includeRedisDatabases: initialValues.includeRedisDatabases,
|
||||
useSSL: !!config.useSSL,
|
||||
sslMode: config.sslMode || "preferred",
|
||||
sslCAPath: config.sslCAPath || "",
|
||||
sslCertPath: config.sslCertPath || "",
|
||||
sslKeyPath: config.sslKeyPath || "",
|
||||
useSSH: config.useSSH,
|
||||
@@ -3166,6 +3363,9 @@ const ConnectionModal: React.FC<{
|
||||
? "disable"
|
||||
: "preferred";
|
||||
const effectiveUseSSL = sslCapableType && !!mergedValues.useSSL;
|
||||
const sslCAPath = sslCapableType
|
||||
? String(mergedValues.sslCAPath || "").trim()
|
||||
: "";
|
||||
const sslCertPath = sslCapableType
|
||||
? String(mergedValues.sslCertPath || "").trim()
|
||||
: "";
|
||||
@@ -3175,6 +3375,9 @@ const ConnectionModal: React.FC<{
|
||||
if (type === "dameng" && effectiveUseSSL && (!sslCertPath || !sslKeyPath)) {
|
||||
throw new Error("达梦启用 SSL 时必须填写证书路径与私钥路径");
|
||||
}
|
||||
if (effectiveUseSSL && supportsSSLClientCertificateForType(type) && (!!sslCertPath !== !!sslKeyPath)) {
|
||||
throw new Error("TLS 客户端证书与私钥路径需要同时填写");
|
||||
}
|
||||
|
||||
let primaryHost = "localhost";
|
||||
let primaryPort = defaultPort;
|
||||
@@ -3369,6 +3572,7 @@ const ConnectionModal: React.FC<{
|
||||
database: mergedValues.database || "",
|
||||
useSSL: effectiveUseSSL,
|
||||
sslMode: effectiveUseSSL ? sslMode : "disable",
|
||||
sslCAPath: sslCAPath,
|
||||
sslCertPath: sslCertPath,
|
||||
sslKeyPath: sslKeyPath,
|
||||
useSSH: !!mergedValues.useSSH,
|
||||
@@ -3438,6 +3642,7 @@ const ConnectionModal: React.FC<{
|
||||
database: "",
|
||||
useSSL: false,
|
||||
sslMode: undefined,
|
||||
sslCAPath: undefined,
|
||||
sslCertPath: undefined,
|
||||
sslKeyPath: undefined,
|
||||
useSSH: false,
|
||||
@@ -3501,6 +3706,7 @@ const ConnectionModal: React.FC<{
|
||||
database: "",
|
||||
useSSL: false,
|
||||
sslMode: "preferred",
|
||||
sslCAPath: "",
|
||||
sslCertPath: "",
|
||||
sslKeyPath: "",
|
||||
useSSH: false,
|
||||
@@ -3551,6 +3757,7 @@ const ConnectionModal: React.FC<{
|
||||
port: defaultPort,
|
||||
useSSL: sslCapableType ? false : undefined,
|
||||
sslMode: sslCapableType ? "preferred" : undefined,
|
||||
sslCAPath: sslCapableType ? "" : undefined,
|
||||
sslCertPath: sslCapableType ? "" : undefined,
|
||||
sslKeyPath: sslCapableType ? "" : undefined,
|
||||
useHttpTunnel: false,
|
||||
@@ -5556,41 +5763,84 @@ const ConnectionModal: React.FC<{
|
||||
],
|
||||
})}
|
||||
</div>
|
||||
{dbType === "dameng" && (
|
||||
<>
|
||||
<Form.Item
|
||||
name="sslCertPath"
|
||||
label="客户端证书路径 (SSL_CERT_PATH)"
|
||||
rules={[
|
||||
{
|
||||
required: true,
|
||||
message: "达梦 SSL 需要证书路径",
|
||||
},
|
||||
]}
|
||||
style={{ marginBottom: 8 }}
|
||||
>
|
||||
<Input
|
||||
{...noAutoCapInputProps}
|
||||
placeholder="例如: C:\certs\client-cert.pem"
|
||||
/>
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
name="sslKeyPath"
|
||||
label="客户端私钥路径 (SSL_KEY_PATH)"
|
||||
rules={[
|
||||
{
|
||||
required: true,
|
||||
message: "达梦 SSL 需要私钥路径",
|
||||
},
|
||||
]}
|
||||
style={{ marginBottom: 8 }}
|
||||
>
|
||||
<Input
|
||||
{...noAutoCapInputProps}
|
||||
placeholder="例如: C:\certs\client-key.pem"
|
||||
/>
|
||||
</Form.Item>
|
||||
</>
|
||||
{(supportsSSLCAPath || supportsSSLClientCertificate) && (
|
||||
<div style={{ display: "grid", gap: 8, marginBottom: 12 }}>
|
||||
{supportsSSLCAPath && (
|
||||
<Form.Item
|
||||
label={dbType === "sqlserver" ? "服务端证书/CA 路径" : "CA 证书路径"}
|
||||
style={{ marginBottom: 0 }}
|
||||
>
|
||||
<Space.Compact style={{ width: "100%" }}>
|
||||
<Form.Item name="sslCAPath" noStyle>
|
||||
<Input
|
||||
{...noAutoCapInputProps}
|
||||
placeholder="例如: C:\certs\ca.pem"
|
||||
/>
|
||||
</Form.Item>
|
||||
<Button
|
||||
onClick={() => handleSelectCertificateFile("sslCAPath", "ca")}
|
||||
loading={selectingCertificateField === "sslCAPath"}
|
||||
>
|
||||
浏览...
|
||||
</Button>
|
||||
</Space.Compact>
|
||||
</Form.Item>
|
||||
)}
|
||||
{supportsSSLClientCertificate && (
|
||||
<>
|
||||
<Form.Item
|
||||
label={dbType === "dameng" ? "客户端证书路径 (SSL_CERT_PATH)" : "客户端证书路径"}
|
||||
rules={[
|
||||
{
|
||||
required: dbType === "dameng",
|
||||
message: "达梦 SSL 需要证书路径",
|
||||
},
|
||||
]}
|
||||
style={{ marginBottom: 0 }}
|
||||
>
|
||||
<Space.Compact style={{ width: "100%" }}>
|
||||
<Form.Item name="sslCertPath" noStyle>
|
||||
<Input
|
||||
{...noAutoCapInputProps}
|
||||
placeholder="例如: C:\certs\client-cert.pem"
|
||||
/>
|
||||
</Form.Item>
|
||||
<Button
|
||||
onClick={() => handleSelectCertificateFile("sslCertPath", "client-cert")}
|
||||
loading={selectingCertificateField === "sslCertPath"}
|
||||
>
|
||||
浏览...
|
||||
</Button>
|
||||
</Space.Compact>
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
label={dbType === "dameng" ? "客户端私钥路径 (SSL_KEY_PATH)" : "客户端私钥路径"}
|
||||
rules={[
|
||||
{
|
||||
required: dbType === "dameng",
|
||||
message: "达梦 SSL 需要私钥路径",
|
||||
},
|
||||
]}
|
||||
style={{ marginBottom: 0 }}
|
||||
>
|
||||
<Space.Compact style={{ width: "100%" }}>
|
||||
<Form.Item name="sslKeyPath" noStyle>
|
||||
<Input
|
||||
{...noAutoCapInputProps}
|
||||
placeholder="例如: C:\certs\client-key.pem"
|
||||
/>
|
||||
</Form.Item>
|
||||
<Button
|
||||
onClick={() => handleSelectCertificateFile("sslKeyPath", "client-key")}
|
||||
loading={selectingCertificateField === "sslKeyPath"}
|
||||
>
|
||||
浏览...
|
||||
</Button>
|
||||
</Space.Compact>
|
||||
</Form.Item>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
<Text type="secondary" style={{ fontSize: 12 }}>
|
||||
{sslHintText}
|
||||
@@ -6216,6 +6466,7 @@ const ConnectionModal: React.FC<{
|
||||
user: "root",
|
||||
useSSL: false,
|
||||
sslMode: "preferred",
|
||||
sslCAPath: "",
|
||||
sslCertPath: "",
|
||||
sslKeyPath: "",
|
||||
useSSH: false,
|
||||
|
||||
@@ -262,6 +262,34 @@ describe('store appearance persistence', () => {
|
||||
expect(config?.port).toBe(9030);
|
||||
});
|
||||
|
||||
it('preserves SSL certificate paths for SSL-capable saved connections', async () => {
|
||||
const { useStore } = await importStore();
|
||||
|
||||
useStore.getState().replaceConnections([
|
||||
{
|
||||
id: 'postgres-ssl',
|
||||
name: 'Postgres SSL',
|
||||
config: {
|
||||
id: 'postgres-ssl',
|
||||
type: 'postgres',
|
||||
host: 'db.local',
|
||||
port: 5432,
|
||||
user: 'postgres',
|
||||
useSSL: true,
|
||||
sslMode: 'required',
|
||||
sslCAPath: 'C:/certs/ca.pem',
|
||||
sslCertPath: 'C:/certs/client-cert.pem',
|
||||
sslKeyPath: 'C:/certs/client-key.pem',
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
const config = useStore.getState().connections[0]?.config;
|
||||
expect(config?.sslCAPath).toBe('C:/certs/ca.pem');
|
||||
expect(config?.sslCertPath).toBe('C:/certs/client-cert.pem');
|
||||
expect(config?.sslKeyPath).toBe('C:/certs/client-key.pem');
|
||||
});
|
||||
|
||||
it('normalizes OceanBase protocol override when replacing saved connections', async () => {
|
||||
const { useStore } = await importStore();
|
||||
|
||||
|
||||
@@ -529,6 +529,7 @@ const sanitizeConnectionConfig = (value: unknown): ConnectionConfig => {
|
||||
database: toTrimmedString(raw.database),
|
||||
useSSL: sslCapable ? !!raw.useSSL : false,
|
||||
sslMode: sslCapable ? sslMode : "disable",
|
||||
sslCAPath: sslCapable ? toTrimmedString(raw.sslCAPath) : "",
|
||||
sslCertPath: sslCapable ? toTrimmedString(raw.sslCertPath) : "",
|
||||
sslKeyPath: sslCapable ? toTrimmedString(raw.sslKeyPath) : "",
|
||||
useSSH: !!raw.useSSH,
|
||||
|
||||
@@ -284,6 +284,7 @@ export interface ConnectionConfig {
|
||||
database?: string;
|
||||
useSSL?: boolean;
|
||||
sslMode?: "preferred" | "required" | "skip-verify" | "disable";
|
||||
sslCAPath?: string;
|
||||
sslCertPath?: string;
|
||||
sslKeyPath?: string;
|
||||
useSSH?: boolean;
|
||||
|
||||
@@ -148,6 +148,27 @@ describe('buildRpcConnectionConfig', () => {
|
||||
expect(result.connectionParams).toBe('characterEncoding=utf8&useSSL=false');
|
||||
});
|
||||
|
||||
it('preserves SSL certificate path fields for RPC calls', () => {
|
||||
const result = buildRpcConnectionConfig({
|
||||
id: 'conn-postgres-ssl',
|
||||
type: 'postgres',
|
||||
host: 'db.local',
|
||||
port: 5432,
|
||||
user: 'postgres',
|
||||
useSSL: true,
|
||||
sslMode: 'required',
|
||||
sslCAPath: 'C:/certs/ca.pem',
|
||||
sslCertPath: 'C:/certs/client-cert.pem',
|
||||
sslKeyPath: 'C:/certs/client-key.pem',
|
||||
} as any);
|
||||
|
||||
expect(result.useSSL).toBe(true);
|
||||
expect(result.sslMode).toBe('required');
|
||||
expect(result.sslCAPath).toBe('C:/certs/ca.pem');
|
||||
expect(result.sslCertPath).toBe('C:/certs/client-cert.pem');
|
||||
expect(result.sslKeyPath).toBe('C:/certs/client-key.pem');
|
||||
});
|
||||
|
||||
it('fills default nested config blocks needed by RPC calls', () => {
|
||||
const result = buildRpcConnectionConfig({
|
||||
id: 'conn-redis',
|
||||
|
||||
4
frontend/wailsjs/go/app/App.d.ts
vendored
4
frontend/wailsjs/go/app/App.d.ts
vendored
@@ -242,6 +242,8 @@ export function RenameTable(arg1:connection.ConnectionConfig,arg2:string,arg3:st
|
||||
|
||||
export function RenameView(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function ResetWebViewZoom():Promise<connection.QueryResult>;
|
||||
|
||||
export function ResolveDriverDownloadDirectory(arg1:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function ResolveDriverPackageDownloadURL(arg1:string,arg2:string):Promise<connection.QueryResult>;
|
||||
@@ -256,6 +258,8 @@ export function SaveConnection(arg1:connection.SavedConnectionInput):Promise<con
|
||||
|
||||
export function SaveGlobalProxy(arg1:connection.SaveGlobalProxyInput):Promise<connection.GlobalProxyView>;
|
||||
|
||||
export function SelectCertificateFile(arg1:string,arg2:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function SelectDataRootDirectory(arg1:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function SelectDatabaseFile(arg1:string,arg2:string):Promise<connection.QueryResult>;
|
||||
|
||||
@@ -474,6 +474,10 @@ export function RenameView(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['RenameView'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function ResetWebViewZoom() {
|
||||
return window['go']['app']['App']['ResetWebViewZoom']();
|
||||
}
|
||||
|
||||
export function ResolveDriverDownloadDirectory(arg1) {
|
||||
return window['go']['app']['App']['ResolveDriverDownloadDirectory'](arg1);
|
||||
}
|
||||
@@ -502,6 +506,10 @@ export function SaveGlobalProxy(arg1) {
|
||||
return window['go']['app']['App']['SaveGlobalProxy'](arg1);
|
||||
}
|
||||
|
||||
export function SelectCertificateFile(arg1, arg2) {
|
||||
return window['go']['app']['App']['SelectCertificateFile'](arg1, arg2);
|
||||
}
|
||||
|
||||
export function SelectDataRootDirectory(arg1) {
|
||||
return window['go']['app']['App']['SelectDataRootDirectory'](arg1);
|
||||
}
|
||||
|
||||
@@ -673,6 +673,7 @@ export namespace connection {
|
||||
database: string;
|
||||
useSSL?: boolean;
|
||||
sslMode?: string;
|
||||
sslCAPath?: string;
|
||||
sslCertPath?: string;
|
||||
sslKeyPath?: string;
|
||||
useSSH: boolean;
|
||||
@@ -718,6 +719,7 @@ export namespace connection {
|
||||
this.database = source["database"];
|
||||
this.useSSL = source["useSSL"];
|
||||
this.sslMode = source["sslMode"];
|
||||
this.sslCAPath = source["sslCAPath"];
|
||||
this.sslCertPath = source["sslCertPath"];
|
||||
this.sslKeyPath = source["sslKeyPath"];
|
||||
this.useSSH = source["useSSH"];
|
||||
|
||||
@@ -553,6 +553,61 @@ func (a *App) SelectSSHKeyFile(currentPath string) connection.QueryResult {
|
||||
return connection.QueryResult{Success: true, Data: map[string]interface{}{"path": selection}}
|
||||
}
|
||||
|
||||
func (a *App) SelectCertificateFile(currentPath string, certKind string) connection.QueryResult {
|
||||
defaultDir := strings.TrimSpace(currentPath)
|
||||
if defaultDir == "" {
|
||||
if home, err := os.UserHomeDir(); err == nil {
|
||||
defaultDir = home
|
||||
}
|
||||
}
|
||||
if filepath.Ext(defaultDir) != "" {
|
||||
defaultDir = filepath.Dir(defaultDir)
|
||||
}
|
||||
if defaultDir != "" && !filepath.IsAbs(defaultDir) {
|
||||
if abs, err := filepath.Abs(defaultDir); err == nil {
|
||||
defaultDir = abs
|
||||
}
|
||||
}
|
||||
|
||||
kind := strings.ToLower(strings.TrimSpace(certKind))
|
||||
title := "选择 TLS 证书文件"
|
||||
displayName := "证书文件"
|
||||
switch kind {
|
||||
case "ca":
|
||||
title = "选择 CA/服务端证书文件"
|
||||
case "client-cert":
|
||||
title = "选择客户端证书文件"
|
||||
case "client-key":
|
||||
title = "选择客户端私钥文件"
|
||||
displayName = "私钥文件"
|
||||
}
|
||||
|
||||
selection, err := runtime.OpenFileDialog(a.ctx, runtime.OpenDialogOptions{
|
||||
Title: title,
|
||||
DefaultDirectory: defaultDir,
|
||||
Filters: []runtime.FileFilter{
|
||||
{
|
||||
DisplayName: displayName,
|
||||
Pattern: "*.pem;*.crt;*.cer;*.cert;*.key",
|
||||
},
|
||||
{
|
||||
DisplayName: "所有文件",
|
||||
Pattern: "*",
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
if strings.TrimSpace(selection) == "" {
|
||||
return connection.QueryResult{Success: false, Message: "已取消"}
|
||||
}
|
||||
if abs, err := filepath.Abs(selection); err == nil {
|
||||
selection = abs
|
||||
}
|
||||
return connection.QueryResult{Success: true, Data: map[string]interface{}{"path": selection}}
|
||||
}
|
||||
|
||||
func (a *App) SelectDatabaseFile(currentPath string, driverType string) connection.QueryResult {
|
||||
defaultDir := strings.TrimSpace(currentPath)
|
||||
if defaultDir == "" {
|
||||
|
||||
@@ -89,6 +89,7 @@ type ConnectionConfig struct {
|
||||
Database string `json:"database"`
|
||||
UseSSL bool `json:"useSSL,omitempty"` // MySQL-like SSL/TLS switch
|
||||
SSLMode string `json:"sslMode,omitempty"` // preferred | required | skip-verify | disable
|
||||
SSLCAPath string `json:"sslCAPath,omitempty"` // TLS root CA / server certificate path
|
||||
SSLCertPath string `json:"sslCertPath,omitempty"` // TLS client certificate path (e.g., Dameng)
|
||||
SSLKeyPath string `json:"sslKeyPath,omitempty"` // TLS client private key path (e.g., Dameng)
|
||||
UseSSH bool `json:"useSSH"`
|
||||
|
||||
@@ -167,7 +167,7 @@ func defaultClickHousePortForScheme(scheme string) int {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig) *clickhouse.Options {
|
||||
func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig) (*clickhouse.Options, error) {
|
||||
connectTimeout := getConnectTimeout(config)
|
||||
readTimeout := connectTimeout
|
||||
if readTimeout < minClickHouseReadTimeout {
|
||||
@@ -187,11 +187,15 @@ func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig
|
||||
DialTimeout: connectTimeout,
|
||||
ReadTimeout: readTimeout,
|
||||
}
|
||||
if tlsConfig := resolveGenericTLSConfig(config); tlsConfig != nil {
|
||||
tlsConfig, err := resolveGenericTLSConfig(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tlsConfig != nil {
|
||||
opts.TLS = tlsConfig
|
||||
}
|
||||
applyClickHouseConnectionParams(opts, config)
|
||||
return opts
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func parseClickHouseDurationParam(raw string) (time.Duration, bool) {
|
||||
@@ -549,7 +553,14 @@ func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error {
|
||||
protocolConfig := withClickHouseProtocol(attempt, protocol)
|
||||
logger.Infof("ClickHouse 连接尝试:第%d组/%d 协议=%s 地址=%s:%d SSL=%t",
|
||||
idx+1, len(attempts), clickHouseProtocolName(protocol), protocolConfig.Host, protocolConfig.Port, protocolConfig.UseSSL)
|
||||
c.conn = clickhouse.OpenDB(c.buildClickHouseOptions(protocolConfig))
|
||||
opts, err := c.buildClickHouseOptions(protocolConfig)
|
||||
if err != nil {
|
||||
failures = append(failures, fmt.Sprintf("第%d次 TLS 配置失败(protocol=%s): %v", idx+1, protocol.String(), err))
|
||||
logger.Warnf("ClickHouse TLS 配置失败:第%d组/%d 协议=%s 地址=%s:%d SSL=%t 原因=%v",
|
||||
idx+1, len(attempts), clickHouseProtocolName(protocol), protocolConfig.Host, protocolConfig.Port, protocolConfig.UseSSL, err)
|
||||
continue
|
||||
}
|
||||
c.conn = clickhouse.OpenDB(opts)
|
||||
if err := c.Ping(); err != nil {
|
||||
failureMessage := clickHouseAttemptFailureMessage(protocol, err)
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接验证失败(protocol=%s): %s", idx+1, protocol.String(), failureMessage))
|
||||
|
||||
@@ -139,7 +139,7 @@ func (d *DirosDB) getDSN(config connection.ConnectionConfig) (string, error) {
|
||||
protocol = netName
|
||||
}
|
||||
|
||||
return buildMySQLCompatibleDSN(config, protocol, address, database), nil
|
||||
return buildMySQLCompatibleDSN(config, protocol, address, database)
|
||||
}
|
||||
|
||||
func resolveDirosCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {
|
||||
|
||||
@@ -55,6 +55,78 @@ func TestPostgresDSN_SSLModeRequireWhenEnabled(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresDSN_AppendsSSLPathParams(t *testing.T) {
|
||||
p := &PostgresDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "postgres",
|
||||
Host: "127.0.0.1",
|
||||
Port: 5432,
|
||||
User: "user",
|
||||
Password: "pass",
|
||||
Database: "db",
|
||||
UseSSL: true,
|
||||
SSLMode: "required",
|
||||
SSLCAPath: "C:\\certs\\ca.pem",
|
||||
SSLCertPath: "C:\\certs\\client-cert.pem",
|
||||
SSLKeyPath: "C:\\certs\\client-key.pem",
|
||||
}
|
||||
|
||||
dsn := p.getDSN(cfg)
|
||||
parsed, err := url.Parse(dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("parse postgres dsn: %v", err)
|
||||
}
|
||||
query := parsed.Query()
|
||||
if got := query.Get("sslmode"); got != "verify-ca" {
|
||||
t.Fatalf("sslmode = %q, want verify-ca", got)
|
||||
}
|
||||
if got := query.Get("sslrootcert"); got != cfg.SSLCAPath {
|
||||
t.Fatalf("sslrootcert = %q, want %q", got, cfg.SSLCAPath)
|
||||
}
|
||||
if got := query.Get("sslcert"); got != cfg.SSLCertPath {
|
||||
t.Fatalf("sslcert = %q, want %q", got, cfg.SSLCertPath)
|
||||
}
|
||||
if got := query.Get("sslkey"); got != cfg.SSLKeyPath {
|
||||
t.Fatalf("sslkey = %q, want %q", got, cfg.SSLKeyPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresDSN_SkipVerifyOmitsSSLRootCert(t *testing.T) {
|
||||
p := &PostgresDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "postgres",
|
||||
Host: "127.0.0.1",
|
||||
Port: 5432,
|
||||
User: "user",
|
||||
Password: "pass",
|
||||
Database: "db",
|
||||
UseSSL: true,
|
||||
SSLMode: "skip-verify",
|
||||
SSLCAPath: "C:\\certs\\ca.pem",
|
||||
SSLCertPath: "C:\\certs\\client-cert.pem",
|
||||
SSLKeyPath: "C:\\certs\\client-key.pem",
|
||||
}
|
||||
|
||||
dsn := p.getDSN(cfg)
|
||||
parsed, err := url.Parse(dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("parse postgres dsn: %v", err)
|
||||
}
|
||||
query := parsed.Query()
|
||||
if got := query.Get("sslmode"); got != "require" {
|
||||
t.Fatalf("sslmode = %q, want require", got)
|
||||
}
|
||||
if got := query.Get("sslrootcert"); got != "" {
|
||||
t.Fatalf("sslrootcert should be omitted for skip-verify, got %q", got)
|
||||
}
|
||||
if got := query.Get("sslcert"); got != cfg.SSLCertPath {
|
||||
t.Fatalf("sslcert = %q, want %q", got, cfg.SSLCertPath)
|
||||
}
|
||||
if got := query.Get("sslkey"); got != cfg.SSLKeyPath {
|
||||
t.Fatalf("sslkey = %q, want %q", got, cfg.SSLKeyPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresDSN_MergesConnectionParams(t *testing.T) {
|
||||
p := &PostgresDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
@@ -109,6 +181,65 @@ func TestMySQLDSN_UsesTLSParamWhenSSLEnabled(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQLDSN_UsesCustomTLSConfigWhenCertificatePathsAreConfigured(t *testing.T) {
|
||||
m := &MySQLDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "mysql",
|
||||
Host: "127.0.0.1",
|
||||
Port: 3306,
|
||||
User: "root",
|
||||
Password: "pass",
|
||||
Database: "db",
|
||||
UseSSL: true,
|
||||
SSLMode: "required",
|
||||
SSLCAPath: "../../third_party/highgo-pq/certs/root.crt",
|
||||
SSLCertPath: "../../third_party/highgo-pq/certs/postgresql.crt",
|
||||
SSLKeyPath: "../../third_party/highgo-pq/certs/postgresql.key",
|
||||
}
|
||||
|
||||
dsn, err := m.getDSN(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("getDSN failed: %v", err)
|
||||
}
|
||||
if strings.Contains(dsn, "tls=true") {
|
||||
t.Fatalf("dsn 应使用自定义 TLS 配置名而不是 tls=true:%s", dsn)
|
||||
}
|
||||
if !strings.Contains(dsn, "tls=gonavi-") {
|
||||
t.Fatalf("dsn 缺少自定义 TLS 配置名:%s", dsn)
|
||||
}
|
||||
if strings.Contains(dsn, "allowFallbackToPlaintext=true") {
|
||||
t.Fatalf("required 模式不应启用明文回退:%s", dsn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQLDSN_PreservesPreferredFallbackWithCustomTLSConfig(t *testing.T) {
|
||||
m := &MySQLDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "mysql",
|
||||
Host: "127.0.0.1",
|
||||
Port: 3306,
|
||||
User: "root",
|
||||
Password: "pass",
|
||||
Database: "db",
|
||||
UseSSL: true,
|
||||
SSLMode: "preferred",
|
||||
SSLCAPath: "../../third_party/highgo-pq/certs/root.crt",
|
||||
SSLCertPath: "../../third_party/highgo-pq/certs/postgresql.crt",
|
||||
SSLKeyPath: "../../third_party/highgo-pq/certs/postgresql.key",
|
||||
}
|
||||
|
||||
dsn, err := m.getDSN(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("getDSN failed: %v", err)
|
||||
}
|
||||
if !strings.Contains(dsn, "tls=gonavi-") {
|
||||
t.Fatalf("dsn 缺少自定义 TLS 配置名:%s", dsn)
|
||||
}
|
||||
if !strings.Contains(dsn, "allowFallbackToPlaintext=true") {
|
||||
t.Fatalf("preferred 自定义 TLS 配置应保留明文回退:%s", dsn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOracleDSN_EscapesUserAndPassword(t *testing.T) {
|
||||
o := &OracleDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
@@ -380,6 +511,30 @@ func TestSQLServerDSN_MergesConnectionParams(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLServerDSN_AppendsCertificateParam(t *testing.T) {
|
||||
s := &SqlServerDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "sqlserver",
|
||||
Host: "127.0.0.1",
|
||||
Port: 1433,
|
||||
User: "sa",
|
||||
Password: "pass",
|
||||
Database: "master",
|
||||
UseSSL: true,
|
||||
SSLMode: "required",
|
||||
SSLCAPath: "C:\\certs\\sqlserver-ca.pem",
|
||||
}
|
||||
|
||||
dsn := s.getDSN(cfg)
|
||||
parsed, err := url.Parse(dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("parse sqlserver dsn: %v", err)
|
||||
}
|
||||
if got := parsed.Query().Get("certificate"); got != cfg.SSLCAPath {
|
||||
t.Fatalf("certificate = %q, want %q", got, cfg.SSLCAPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClickHouseOptions_UsesStructuredTimeoutAndAuth(t *testing.T) {
|
||||
c := &ClickHouseDB{}
|
||||
cfg := normalizeClickHouseConfig(connection.ConnectionConfig{
|
||||
@@ -392,7 +547,10 @@ func TestClickHouseOptions_UsesStructuredTimeoutAndAuth(t *testing.T) {
|
||||
Timeout: 15,
|
||||
})
|
||||
|
||||
opts := c.buildClickHouseOptions(cfg)
|
||||
opts, err := c.buildClickHouseOptions(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("buildClickHouseOptions failed: %v", err)
|
||||
}
|
||||
if opts == nil {
|
||||
t.Fatal("options 为空")
|
||||
}
|
||||
@@ -438,7 +596,10 @@ func TestClickHouseOptions_MergesConnectionParamsIntoOptionsAndSettings(t *testi
|
||||
ConnectionParams: "max_execution_time=60&compress=lz4&read_timeout=10s",
|
||||
})
|
||||
|
||||
opts := c.buildClickHouseOptions(cfg)
|
||||
opts, err := c.buildClickHouseOptions(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("buildClickHouseOptions failed: %v", err)
|
||||
}
|
||||
if opts == nil {
|
||||
t.Fatal("options 为空")
|
||||
}
|
||||
@@ -465,7 +626,10 @@ func TestClickHouseOptions_ReadTimeoutUsesLargerConfiguredTimeout(t *testing.T)
|
||||
Timeout: 900,
|
||||
})
|
||||
|
||||
opts := c.buildClickHouseOptions(cfg)
|
||||
opts, err := c.buildClickHouseOptions(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("buildClickHouseOptions failed: %v", err)
|
||||
}
|
||||
if opts == nil {
|
||||
t.Fatal("options 为空")
|
||||
}
|
||||
|
||||
@@ -43,6 +43,7 @@ func (h *HighGoDB) getDSN(config connection.ConnectionConfig) string {
|
||||
u.User = url.UserPassword(config.User, config.Password)
|
||||
q := url.Values{}
|
||||
q.Set("sslmode", resolvePostgresSSLMode(config))
|
||||
applyPostgresSSLPathParams(q, config)
|
||||
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
|
||||
mergeConnectionParamsFromConfigWithAllowlist(q, config, highGoConnectionParamNames, "postgres", "postgresql", "highgo")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
@@ -70,10 +70,11 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
|
||||
params.Set("password", config.Password)
|
||||
params.Set("dbname", config.Database)
|
||||
params.Set("sslmode", resolvePostgresSSLMode(config))
|
||||
applyPostgresSSLPathParams(params, config)
|
||||
params.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
|
||||
mergeConnectionParamsFromConfigWithAllowlist(params, config, kingbaseConnectionParamNames, "kingbase")
|
||||
|
||||
preferred := []string{"host", "port", "user", "password", "dbname", "sslmode", "connect_timeout"}
|
||||
preferred := []string{"host", "port", "user", "password", "dbname", "sslmode", "sslrootcert", "sslcert", "sslkey", "connect_timeout"}
|
||||
seen := make(map[string]struct{}, len(params))
|
||||
parts := make([]string, 0, len(params))
|
||||
for _, key := range preferred {
|
||||
|
||||
@@ -36,7 +36,7 @@ func (m *MariaDB) getDSN(config connection.ConnectionConfig) (string, error) {
|
||||
protocol = netName
|
||||
}
|
||||
|
||||
return buildMySQLCompatibleDSN(config, protocol, address, database), nil
|
||||
return buildMySQLCompatibleDSN(config, protocol, address, database)
|
||||
}
|
||||
|
||||
func (m *MariaDB) Connect(config connection.ConnectionConfig) error {
|
||||
|
||||
@@ -4,7 +4,6 @@ package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
@@ -414,12 +413,15 @@ func (m *MongoDB) Connect(config connection.ConnectionConfig) error {
|
||||
}
|
||||
uri := m.getURI(attemptConfig)
|
||||
clientOpts := options.Client().ApplyURI(uri)
|
||||
tlsEnabled, tlsInsecure := resolveMongoTLSSettings(attemptConfig)
|
||||
if tlsEnabled {
|
||||
clientOpts.SetTLSConfig(&tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
InsecureSkipVerify: tlsInsecure,
|
||||
})
|
||||
tlsConfig, tlsErr := resolveGenericTLSConfig(attemptConfig)
|
||||
if tlsErr != nil {
|
||||
detail := fmt.Sprintf("%s %sTLS 配置失败: %v", sslLabel, authLabel, tlsErr)
|
||||
errorDetails = append(errorDetails, detail)
|
||||
logger.Warnf("MongoDB TLS 配置失败:%d/%d 模式=%s 凭据=%s 错误=%v", attemptNo, totalAttempts, sslLabel, authLabel, tlsErr)
|
||||
continue
|
||||
}
|
||||
if tlsConfig != nil {
|
||||
clientOpts.SetTLSConfig(tlsConfig)
|
||||
}
|
||||
if attemptConfig.UseProxy {
|
||||
clientOpts.SetDialer(&mongoProxyDialer{proxyConfig: attemptConfig.Proxy})
|
||||
|
||||
@@ -4,7 +4,6 @@ package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
@@ -415,12 +414,15 @@ func (m *MongoDBV1) Connect(config connection.ConnectionConfig) error {
|
||||
}
|
||||
uri := m.getURI(attemptConfig)
|
||||
clientOpts := options.Client().ApplyURI(uri)
|
||||
tlsEnabled, tlsInsecure := resolveMongoTLSSettings(attemptConfig)
|
||||
if tlsEnabled {
|
||||
clientOpts.SetTLSConfig(&tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
InsecureSkipVerify: tlsInsecure,
|
||||
})
|
||||
tlsConfig, tlsErr := resolveGenericTLSConfig(attemptConfig)
|
||||
if tlsErr != nil {
|
||||
detail := fmt.Sprintf("%s %sTLS 配置失败: %v", sslLabel, authLabel, tlsErr)
|
||||
errorDetails = append(errorDetails, detail)
|
||||
logger.Warnf("MongoDB TLS 配置失败:%d/%d 模式=%s 凭据=%s 错误=%v", attemptNo, totalAttempts, sslLabel, authLabel, tlsErr)
|
||||
continue
|
||||
}
|
||||
if tlsConfig != nil {
|
||||
clientOpts.SetTLSConfig(tlsConfig)
|
||||
}
|
||||
if attemptConfig.UseProxy {
|
||||
clientOpts.SetDialer(&mongoProxyDialer{proxyConfig: attemptConfig.Proxy})
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
mysql "github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
type MySQLDB struct {
|
||||
@@ -115,19 +115,19 @@ var mysqlSupportedDriverParamNames = map[string]string{
|
||||
// OceanBase Oracle 租户 MySQL wire 路径用它注入 OBClient 私有 capability attribute;
|
||||
// 普通 mysql/mariadb 用户也能在此声明 program_name 等元数据。
|
||||
"connectionattributes": "connectionAttributes",
|
||||
"interpolateparams": "interpolateParams",
|
||||
"loc": "loc",
|
||||
"maxallowedpacket": "maxAllowedPacket",
|
||||
"multistatements": "multiStatements",
|
||||
"parsetime": "parseTime",
|
||||
"readtimeout": "readTimeout",
|
||||
"rejectreadonly": "rejectReadOnly",
|
||||
"serverpubkey": "serverPubKey",
|
||||
"sql_mode": "sql_mode",
|
||||
"timetruncate": "timeTruncate",
|
||||
"timeout": "timeout",
|
||||
"tls": "tls",
|
||||
"writetimeout": "writeTimeout",
|
||||
"interpolateparams": "interpolateParams",
|
||||
"loc": "loc",
|
||||
"maxallowedpacket": "maxAllowedPacket",
|
||||
"multistatements": "multiStatements",
|
||||
"parsetime": "parseTime",
|
||||
"readtimeout": "readTimeout",
|
||||
"rejectreadonly": "rejectReadOnly",
|
||||
"serverpubkey": "serverPubKey",
|
||||
"sql_mode": "sql_mode",
|
||||
"timetruncate": "timeTruncate",
|
||||
"timeout": "timeout",
|
||||
"tls": "tls",
|
||||
"writetimeout": "writeTimeout",
|
||||
}
|
||||
|
||||
var mysqlBoolDriverParamNames = map[string]struct{}{
|
||||
@@ -274,15 +274,40 @@ func mergeMySQLConnectionParams(params url.Values, values url.Values) {
|
||||
}
|
||||
}
|
||||
|
||||
func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, address, database string) string {
|
||||
func resolveMySQLTLSParam(config connection.ConnectionConfig) (string, bool, error) {
|
||||
mode := resolveMySQLTLSMode(config)
|
||||
if mode == "false" || !hasTLSCertificatePaths(config) {
|
||||
return mode, false, nil
|
||||
}
|
||||
tlsConfig, err := resolveGenericTLSConfig(config)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
if tlsConfig == nil {
|
||||
return mode, false, nil
|
||||
}
|
||||
name := mysqlTLSConfigName(config)
|
||||
if err := mysql.RegisterTLSConfig(name, tlsConfig); err != nil && !strings.Contains(strings.ToLower(err.Error()), "already registered") {
|
||||
return "", false, fmt.Errorf("注册 MySQL TLS 证书配置失败:%w", err)
|
||||
}
|
||||
return name, normalizeSSLModeValue(config.SSLMode) == sslModePreferred, nil
|
||||
}
|
||||
|
||||
func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, address, database string) (string, error) {
|
||||
timeout := getConnectTimeoutSeconds(config)
|
||||
tlsMode := resolveMySQLTLSMode(config)
|
||||
tlsMode, allowFallbackToPlaintext, err := resolveMySQLTLSParam(config)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
params := url.Values{}
|
||||
params.Set("charset", "utf8mb4")
|
||||
params.Set("parseTime", "True")
|
||||
params.Set("loc", "Local")
|
||||
params.Set("timeout", fmt.Sprintf("%ds", timeout))
|
||||
params.Set("tls", tlsMode)
|
||||
if allowFallbackToPlaintext {
|
||||
params.Set("allowFallbackToPlaintext", "true")
|
||||
}
|
||||
params.Set("multiStatements", "true")
|
||||
if parsed, ok := parseMySQLCompatibleURI(config.URI, "mysql", "doris", "diros", "oceanbase"); ok {
|
||||
mergeMySQLConnectionParams(params, parsed.Query())
|
||||
@@ -291,7 +316,7 @@ func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, addre
|
||||
return fmt.Sprintf(
|
||||
"%s:%s@%s(%s)/%s?%s",
|
||||
config.User, config.Password, protocol, address, database, params.Encode(),
|
||||
)
|
||||
), nil
|
||||
}
|
||||
|
||||
func parseHostPortWithDefault(raw string, defaultPort int) (string, int, bool) {
|
||||
@@ -502,7 +527,7 @@ func (m *MySQLDB) getDSN(config connection.ConnectionConfig) (string, error) {
|
||||
protocol = netName
|
||||
}
|
||||
|
||||
return buildMySQLCompatibleDSN(config, protocol, address, database), nil
|
||||
return buildMySQLCompatibleDSN(config, protocol, address, database)
|
||||
}
|
||||
|
||||
func resolveMySQLCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {
|
||||
|
||||
@@ -185,7 +185,7 @@ func (o *OceanBaseDB) getDSN(config connection.ConnectionConfig) (string, error)
|
||||
protocol = netName
|
||||
}
|
||||
|
||||
return buildMySQLCompatibleDSN(config, protocol, address, database), nil
|
||||
return buildMySQLCompatibleDSN(config, protocol, address, database)
|
||||
}
|
||||
|
||||
func normalizeOceanBaseProtocol(raw string) string {
|
||||
|
||||
@@ -63,6 +63,7 @@ func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
|
||||
u.User = url.UserPassword(config.User, config.Password)
|
||||
q := url.Values{}
|
||||
q.Set("sslmode", resolvePostgresSSLMode(config))
|
||||
applyPostgresSSLPathParams(q, config)
|
||||
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
|
||||
mergeConnectionParamsFromConfigWithAllowlist(q, config, postgresConnectionParamNames, "postgres", "postgresql", "opengauss")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
@@ -50,6 +50,9 @@ func (s *SqlServerDB) getDSN(config connection.ConnectionConfig) string {
|
||||
encrypt, trustServerCertificate := resolveSQLServerTLSSettings(config)
|
||||
q.Set("encrypt", encrypt)
|
||||
q.Set("trustservercertificate", trustServerCertificate)
|
||||
if strings.TrimSpace(config.SSLCAPath) != "" {
|
||||
q.Set("certificate", strings.TrimSpace(config.SSLCAPath))
|
||||
}
|
||||
mergeConnectionParamsFromConfigWithAllowlist(q, config, sqlServerConnectionParamNames, "sqlserver")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/tlsconfig"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -61,15 +64,37 @@ func resolveMySQLTLSMode(config connection.ConnectionConfig) string {
|
||||
}
|
||||
}
|
||||
|
||||
func hasTLSCertificatePaths(config connection.ConnectionConfig) bool {
|
||||
return strings.TrimSpace(config.SSLCAPath) != "" ||
|
||||
strings.TrimSpace(config.SSLCertPath) != "" ||
|
||||
strings.TrimSpace(config.SSLKeyPath) != ""
|
||||
}
|
||||
|
||||
func mysqlTLSConfigName(config connection.ConnectionConfig) string {
|
||||
sum := sha256.Sum256([]byte(strings.Join([]string{
|
||||
normalizedSSLMode(config),
|
||||
strings.TrimSpace(config.SSLCAPath),
|
||||
strings.TrimSpace(config.SSLCertPath),
|
||||
strings.TrimSpace(config.SSLKeyPath),
|
||||
}, "\x00")))
|
||||
return "gonavi-" + hex.EncodeToString(sum[:8])
|
||||
}
|
||||
|
||||
func resolvePostgresSSLMode(config connection.ConnectionConfig) string {
|
||||
switch normalizedSSLMode(config) {
|
||||
case sslModeDisable:
|
||||
return "disable"
|
||||
case sslModeRequired:
|
||||
if strings.TrimSpace(config.SSLCAPath) != "" {
|
||||
return "verify-ca"
|
||||
}
|
||||
return "require"
|
||||
case sslModeSkipVerify:
|
||||
return "require"
|
||||
default:
|
||||
if strings.TrimSpace(config.SSLCAPath) != "" {
|
||||
return "verify-ca"
|
||||
}
|
||||
return "require"
|
||||
}
|
||||
}
|
||||
@@ -87,17 +112,47 @@ func resolveSQLServerTLSSettings(config connection.ConnectionConfig) (encrypt st
|
||||
}
|
||||
}
|
||||
|
||||
func resolveGenericTLSConfig(config connection.ConnectionConfig) *tls.Config {
|
||||
func applyPostgresSSLPathParams(params interface{ Set(string, string) }, config connection.ConnectionConfig) {
|
||||
mode := normalizedSSLMode(config)
|
||||
if mode != sslModeDisable && mode != sslModeSkipVerify && strings.TrimSpace(config.SSLCAPath) != "" {
|
||||
params.Set("sslrootcert", strings.TrimSpace(config.SSLCAPath))
|
||||
}
|
||||
if mode != sslModeDisable && strings.TrimSpace(config.SSLCertPath) != "" {
|
||||
params.Set("sslcert", strings.TrimSpace(config.SSLCertPath))
|
||||
}
|
||||
if mode != sslModeDisable && strings.TrimSpace(config.SSLKeyPath) != "" {
|
||||
params.Set("sslkey", strings.TrimSpace(config.SSLKeyPath))
|
||||
}
|
||||
}
|
||||
|
||||
func resolveGenericTLSConfig(config connection.ConnectionConfig) (*tls.Config, error) {
|
||||
switch normalizedSSLMode(config) {
|
||||
case sslModeDisable:
|
||||
return nil
|
||||
return nil, nil
|
||||
case sslModeRequired:
|
||||
return &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
return tlsconfig.BuildClientConfig(tlsconfig.ClientConfigOptions{
|
||||
Enabled: true,
|
||||
CAPath: config.SSLCAPath,
|
||||
CertPath: config.SSLCertPath,
|
||||
KeyPath: config.SSLKeyPath,
|
||||
})
|
||||
case sslModeSkipVerify:
|
||||
return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true}
|
||||
return tlsconfig.BuildClientConfig(tlsconfig.ClientConfigOptions{
|
||||
Enabled: true,
|
||||
InsecureSkipVerify: true,
|
||||
CAPath: config.SSLCAPath,
|
||||
CertPath: config.SSLCertPath,
|
||||
KeyPath: config.SSLKeyPath,
|
||||
})
|
||||
default:
|
||||
// Preferred: 先尝试 TLS(为提升兼容性默认跳过证书校验),失败时由调用方按需回退明文。
|
||||
return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true}
|
||||
return tlsconfig.BuildClientConfig(tlsconfig.ClientConfigOptions{
|
||||
Enabled: true,
|
||||
InsecureSkipVerify: true,
|
||||
CAPath: config.SSLCAPath,
|
||||
CertPath: config.SSLCertPath,
|
||||
KeyPath: config.SSLKeyPath,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -214,7 +214,7 @@ func (s *StarRocksDB) getDSN(config connection.ConnectionConfig) (string, error)
|
||||
protocol = netName
|
||||
}
|
||||
|
||||
return buildMySQLCompatibleDSN(config, protocol, address, database), nil
|
||||
return buildMySQLCompatibleDSN(config, protocol, address, database)
|
||||
}
|
||||
|
||||
func resolveStarRocksCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {
|
||||
|
||||
@@ -42,6 +42,7 @@ func (v *VastbaseDB) getDSN(config connection.ConnectionConfig) string {
|
||||
u.User = url.UserPassword(config.User, config.Password)
|
||||
q := url.Values{}
|
||||
q.Set("sslmode", resolvePostgresSSLMode(config))
|
||||
applyPostgresSSLPathParams(q, config)
|
||||
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
|
||||
mergeConnectionParamsFromConfigWithAllowlist(q, config, postgresConnectionParamNames, "postgres", "postgresql", "vastbase")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
@@ -277,7 +277,10 @@ func (r *RedisClientImpl) Connect(config connection.ConnectionConfig) error {
|
||||
var failures []string
|
||||
for idx, attempt := range attempts {
|
||||
var tlsConfig *tls.Config
|
||||
if cfg := resolveRedisTLSConfig(attempt); cfg != nil {
|
||||
if cfg, err := resolveRedisTLSConfig(attempt); err != nil {
|
||||
failures = append(failures, fmt.Sprintf("第%d次 TLS 配置失败: %v", idx+1, err))
|
||||
continue
|
||||
} else if cfg != nil {
|
||||
if host, _, err := net.SplitHostPort(seedAddrs[0]); err == nil && host != "" {
|
||||
cfg.ServerName = host
|
||||
}
|
||||
@@ -332,7 +335,10 @@ func (r *RedisClientImpl) Connect(config connection.ConnectionConfig) error {
|
||||
var failures []string
|
||||
for idx, attempt := range attempts {
|
||||
var tlsConfig *tls.Config
|
||||
if cfg := resolveRedisTLSConfig(attempt); cfg != nil {
|
||||
if cfg, err := resolveRedisTLSConfig(attempt); err != nil {
|
||||
failures = append(failures, fmt.Sprintf("第%d次 TLS 配置失败: %v", idx+1, err))
|
||||
continue
|
||||
} else if cfg != nil {
|
||||
if host, _, err := net.SplitHostPort(addr); err == nil && host != "" {
|
||||
cfg.ServerName = host
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/tlsconfig"
|
||||
)
|
||||
|
||||
func normalizeRedisSSLMode(raw string) string {
|
||||
@@ -41,15 +42,32 @@ func withRedisSSLDisabled(config connection.ConnectionConfig) connection.Connect
|
||||
return next
|
||||
}
|
||||
|
||||
func resolveRedisTLSConfig(config connection.ConnectionConfig) *tls.Config {
|
||||
func resolveRedisTLSConfig(config connection.ConnectionConfig) (*tls.Config, error) {
|
||||
switch redisSSLMode(config) {
|
||||
case "disable":
|
||||
return nil
|
||||
return nil, nil
|
||||
case "required":
|
||||
return &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
return tlsconfig.BuildClientConfig(tlsconfig.ClientConfigOptions{
|
||||
Enabled: true,
|
||||
CAPath: config.SSLCAPath,
|
||||
CertPath: config.SSLCertPath,
|
||||
KeyPath: config.SSLKeyPath,
|
||||
})
|
||||
case "skip-verify":
|
||||
return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true}
|
||||
return tlsconfig.BuildClientConfig(tlsconfig.ClientConfigOptions{
|
||||
Enabled: true,
|
||||
InsecureSkipVerify: true,
|
||||
CAPath: config.SSLCAPath,
|
||||
CertPath: config.SSLCertPath,
|
||||
KeyPath: config.SSLKeyPath,
|
||||
})
|
||||
default:
|
||||
return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true}
|
||||
return tlsconfig.BuildClientConfig(tlsconfig.ClientConfigOptions{
|
||||
Enabled: true,
|
||||
InsecureSkipVerify: true,
|
||||
CAPath: config.SSLCAPath,
|
||||
CertPath: config.SSLCertPath,
|
||||
KeyPath: config.SSLKeyPath,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
62
internal/tlsconfig/tlsconfig.go
Normal file
62
internal/tlsconfig/tlsconfig.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package tlsconfig
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ClientConfigOptions struct {
|
||||
Enabled bool
|
||||
InsecureSkipVerify bool
|
||||
CAPath string
|
||||
CertPath string
|
||||
KeyPath string
|
||||
}
|
||||
|
||||
func BuildClientConfig(options ClientConfigOptions) (*tls.Config, error) {
|
||||
if !options.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
cfg := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
InsecureSkipVerify: options.InsecureSkipVerify,
|
||||
}
|
||||
|
||||
caPath := strings.TrimSpace(options.CAPath)
|
||||
if caPath != "" {
|
||||
pemBytes, err := os.ReadFile(caPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取 TLS CA 证书失败(%s):%w", caPath, err)
|
||||
}
|
||||
pool := x509.NewCertPool()
|
||||
if ok := pool.AppendCertsFromPEM(pemBytes); !ok {
|
||||
certs, err := x509.ParseCertificates(pemBytes)
|
||||
if err != nil || len(certs) == 0 {
|
||||
return nil, fmt.Errorf("TLS CA 证书不是有效的 PEM/DER 文件:%s", caPath)
|
||||
}
|
||||
for _, cert := range certs {
|
||||
pool.AddCert(cert)
|
||||
}
|
||||
}
|
||||
cfg.RootCAs = pool
|
||||
}
|
||||
|
||||
certPath := strings.TrimSpace(options.CertPath)
|
||||
keyPath := strings.TrimSpace(options.KeyPath)
|
||||
if (certPath == "") != (keyPath == "") {
|
||||
return nil, fmt.Errorf("TLS 客户端证书和私钥需要同时配置")
|
||||
}
|
||||
if certPath != "" {
|
||||
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("加载 TLS 客户端证书失败(cert=%s key=%s):%w", certPath, keyPath, err)
|
||||
}
|
||||
cfg.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
126
internal/tlsconfig/tlsconfig_test.go
Normal file
126
internal/tlsconfig/tlsconfig_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package tlsconfig
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBuildClientConfigLoadsCAAndClientCertificate(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
certPath, keyPath, _ := writeSelfSignedCertificate(t, dir)
|
||||
|
||||
cfg, err := BuildClientConfig(ClientConfigOptions{
|
||||
Enabled: true,
|
||||
CAPath: certPath,
|
||||
CertPath: certPath,
|
||||
KeyPath: keyPath,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("BuildClientConfig failed: %v", err)
|
||||
}
|
||||
if cfg == nil {
|
||||
t.Fatal("config is nil")
|
||||
}
|
||||
if cfg.RootCAs == nil {
|
||||
t.Fatal("RootCAs is nil")
|
||||
}
|
||||
if len(cfg.Certificates) != 1 {
|
||||
t.Fatalf("Certificates length = %d, want 1", len(cfg.Certificates))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClientConfigLoadsDERCA(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
_, _, derBytes := writeSelfSignedCertificate(t, dir)
|
||||
derPath := filepath.Join(dir, "ca.cer")
|
||||
if err := os.WriteFile(derPath, derBytes, 0600); err != nil {
|
||||
t.Fatalf("write der cert: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := BuildClientConfig(ClientConfigOptions{
|
||||
Enabled: true,
|
||||
CAPath: derPath,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("BuildClientConfig failed: %v", err)
|
||||
}
|
||||
if cfg == nil {
|
||||
t.Fatal("config is nil")
|
||||
}
|
||||
if cfg.RootCAs == nil {
|
||||
t.Fatal("RootCAs is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClientConfigRequiresClientCertificateAndKeyTogether(t *testing.T) {
|
||||
_, err := BuildClientConfig(ClientConfigOptions{
|
||||
Enabled: true,
|
||||
CertPath: "client.pem",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
}
|
||||
|
||||
func writeSelfSignedCertificate(t *testing.T, dir string) (string, string, []byte) {
|
||||
t.Helper()
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("generate key: %v", err)
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
CommonName: "GoNavi Test",
|
||||
},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
|
||||
if err != nil {
|
||||
t.Fatalf("create certificate: %v", err)
|
||||
}
|
||||
|
||||
certPath := filepath.Join(dir, "cert.pem")
|
||||
certFile, err := os.Create(certPath)
|
||||
if err != nil {
|
||||
t.Fatalf("create cert file: %v", err)
|
||||
}
|
||||
if err := pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
|
||||
_ = certFile.Close()
|
||||
t.Fatalf("write cert: %v", err)
|
||||
}
|
||||
if err := certFile.Close(); err != nil {
|
||||
t.Fatalf("close cert file: %v", err)
|
||||
}
|
||||
|
||||
keyPath := filepath.Join(dir, "key.pem")
|
||||
keyFile, err := os.Create(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("create key file: %v", err)
|
||||
}
|
||||
keyBytes := x509.MarshalPKCS1PrivateKey(privateKey)
|
||||
if err := pem.Encode(keyFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyBytes}); err != nil {
|
||||
_ = keyFile.Close()
|
||||
t.Fatalf("write key: %v", err)
|
||||
}
|
||||
if err := keyFile.Close(); err != nil {
|
||||
t.Fatalf("close key file: %v", err)
|
||||
}
|
||||
|
||||
return certPath, keyPath, derBytes
|
||||
}
|
||||
Reference in New Issue
Block a user