mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-06 20:03:05 +08:00
✨ feat(connection): 支持多数据源额外连接参数配置
- 前端连接表单新增额外连接参数入口,支持 URI query 格式录入与解析回填 - MySQL 兼容驱动支持 JDBC 常见参数映射,修复 UTF-8 字符集与 serverTimezone 兼容问题 - 扩展 Oracle、PostgreSQL 兼容、SQL Server、ClickHouse、MongoDB、达梦、TDengine 参数合并 - 按不同驱动通道处理 DSN、URI、Options 与 Settings,避免统一透传导致连接异常 - 修复编辑已保存连接时解析无认证 URI 会清空已有账号密码的问题 - 补充连接参数透传、缓存隔离、DSN 合并与 URI 回填回归测试
This commit is contained in:
@@ -61,6 +61,7 @@ import {
|
||||
} from "../utils/connectionModalPresentation";
|
||||
import { resolveConnectionSecretDraft } from "../utils/connectionSecretDraft";
|
||||
import { getCustomConnectionDsnValidationMessage } from "../utils/customConnectionDsn";
|
||||
import { mergeParsedUriValuesForForm } from "../utils/connectionUriMerge";
|
||||
import { CUSTOM_CONNECTION_DRIVER_HELP } from "../utils/driverImportGuidance";
|
||||
import {
|
||||
applyNoAutoCapAttributes,
|
||||
@@ -97,6 +98,7 @@ type ChoiceCardOption = {
|
||||
};
|
||||
type ClickHouseProtocolChoice = "auto" | "http" | "native";
|
||||
const MAX_URI_LENGTH = 4096;
|
||||
const MAX_CONNECTION_PARAMS_LENGTH = 4096;
|
||||
const MAX_URI_HOSTS = 32;
|
||||
const MAX_TIMEOUT_SECONDS = 3600;
|
||||
const CONNECTION_MODAL_WIDTH = 960;
|
||||
@@ -232,6 +234,26 @@ const supportsSSLForType = (type: string) =>
|
||||
const isFileDatabaseType = (type: string) =>
|
||||
type === "sqlite" || type === "duckdb";
|
||||
|
||||
const isMySQLCompatibleType = (type: string) =>
|
||||
type === "mysql" ||
|
||||
type === "mariadb" ||
|
||||
type === "doris" ||
|
||||
type === "diros" ||
|
||||
type === "sphinx";
|
||||
|
||||
const supportsConnectionParamsForType = (type: string) =>
|
||||
isMySQLCompatibleType(type) ||
|
||||
type === "postgres" ||
|
||||
type === "kingbase" ||
|
||||
type === "highgo" ||
|
||||
type === "vastbase" ||
|
||||
type === "oracle" ||
|
||||
type === "sqlserver" ||
|
||||
type === "clickhouse" ||
|
||||
type === "mongodb" ||
|
||||
type === "dameng" ||
|
||||
type === "tdengine";
|
||||
|
||||
type DriverStatusSnapshot = {
|
||||
type: string;
|
||||
name: string;
|
||||
@@ -355,12 +377,8 @@ const ConnectionModal: React.FC<{
|
||||
}),
|
||||
[jvmAllowedModes, jvmPreferredMode],
|
||||
);
|
||||
const isMySQLLike =
|
||||
dbType === "mysql" ||
|
||||
dbType === "mariadb" ||
|
||||
dbType === "doris" ||
|
||||
dbType === "diros" ||
|
||||
dbType === "sphinx";
|
||||
const isMySQLLike = isMySQLCompatibleType(dbType);
|
||||
const supportsConnectionParams = supportsConnectionParamsForType(dbType);
|
||||
const isSSLType = supportsSSLForType(dbType);
|
||||
const sslHintText = isMySQLLike
|
||||
? "当 MySQL/MariaDB/Doris/Sphinx 开启安全传输策略时,请启用 SSL;本地自签证书场景可先用 Preferred 或 Skip Verify。"
|
||||
@@ -1047,6 +1065,44 @@ const ConnectionModal: React.FC<{
|
||||
return text === "1" || text === "true" || text === "yes" || text === "on";
|
||||
};
|
||||
|
||||
const normalizeConnectionParamsText = (raw: unknown) => {
|
||||
let text = String(raw || "").trim();
|
||||
if (!text) return "";
|
||||
const queryIndex = text.indexOf("?");
|
||||
if (queryIndex >= 0) {
|
||||
text = text.slice(queryIndex + 1);
|
||||
}
|
||||
const hashIndex = text.indexOf("#");
|
||||
if (hashIndex >= 0) {
|
||||
text = text.slice(0, hashIndex);
|
||||
}
|
||||
return text.replace(/^[?&]+/, "").trim().slice(0, MAX_CONNECTION_PARAMS_LENGTH);
|
||||
};
|
||||
|
||||
const serializeConnectionParams = (params: URLSearchParams) => {
|
||||
const cloned = new URLSearchParams();
|
||||
params.forEach((value, key) => {
|
||||
if (String(key || "").trim()) {
|
||||
cloned.append(key, value);
|
||||
}
|
||||
});
|
||||
return cloned.toString().slice(0, MAX_CONNECTION_PARAMS_LENGTH);
|
||||
};
|
||||
|
||||
const mergeConnectionParams = (
|
||||
params: URLSearchParams,
|
||||
rawParams: unknown,
|
||||
) => {
|
||||
const text = normalizeConnectionParamsText(rawParams);
|
||||
if (!text) return;
|
||||
const extra = new URLSearchParams(text);
|
||||
extra.forEach((value, key) => {
|
||||
if (String(key || "").trim()) {
|
||||
params.set(key, value);
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
const normalizeFileDbPath = (rawPath: string): string => {
|
||||
let pathText = String(rawPath || "").trim();
|
||||
if (!pathText) {
|
||||
@@ -1199,6 +1255,7 @@ const ConnectionModal: React.FC<{
|
||||
clickHouseProtocol: "http",
|
||||
useSSL: isHttps,
|
||||
sslMode: isHttps ? (skipVerify ? "skip-verify" : "required") : "disable",
|
||||
connectionParams: serializeConnectionParams(parsed.params),
|
||||
};
|
||||
};
|
||||
|
||||
@@ -1214,15 +1271,11 @@ const ConnectionModal: React.FC<{
|
||||
return null;
|
||||
}
|
||||
|
||||
if (
|
||||
type === "mysql" ||
|
||||
type === "mariadb" ||
|
||||
type === "diros" ||
|
||||
type === "sphinx"
|
||||
) {
|
||||
if (isMySQLCompatibleType(type)) {
|
||||
const mysqlDefaultPort = getDefaultPortByType(type);
|
||||
const parsed =
|
||||
parseMultiHostUri(trimmedUri, "mysql") ||
|
||||
parseMultiHostUri(trimmedUri, "jdbc:mysql") ||
|
||||
parseMultiHostUri(trimmedUri, "diros") ||
|
||||
parseMultiHostUri(trimmedUri, "doris");
|
||||
if (!parsed) {
|
||||
@@ -1246,7 +1299,9 @@ const ConnectionModal: React.FC<{
|
||||
const topology = String(
|
||||
parsed.params.get("topology") || "",
|
||||
).toLowerCase();
|
||||
const tlsValue = String(parsed.params.get("tls") || "")
|
||||
const tlsValue = String(
|
||||
parsed.params.get("tls") || parsed.params.get("useSSL") || "",
|
||||
)
|
||||
.trim()
|
||||
.toLowerCase();
|
||||
const sslMode =
|
||||
@@ -1268,6 +1323,7 @@ const ConnectionModal: React.FC<{
|
||||
mysqlTopology:
|
||||
hostList.length > 1 || topology === "replica" ? "replica" : "single",
|
||||
mysqlReplicaHosts: hostList.slice(1),
|
||||
connectionParams: serializeConnectionParams(parsed.params),
|
||||
timeout:
|
||||
Number.isFinite(timeoutValue) && timeoutValue > 0
|
||||
? Math.min(3600, Math.trunc(timeoutValue))
|
||||
@@ -1414,6 +1470,7 @@ const ConnectionModal: React.FC<{
|
||||
mongoAuthSource: parsed.params.get("authSource") || "",
|
||||
mongoReadPreference: parsed.params.get("readPreference") || "primary",
|
||||
mongoAuthMechanism: parsed.params.get("authMechanism") || "",
|
||||
connectionParams: serializeConnectionParams(parsed.params),
|
||||
timeout:
|
||||
Number.isFinite(timeoutMs) && timeoutMs > 0
|
||||
? Math.min(MAX_TIMEOUT_SECONDS, Math.ceil(timeoutMs / 1000))
|
||||
@@ -1450,6 +1507,9 @@ const ConnectionModal: React.FC<{
|
||||
password: parsed.password,
|
||||
database: parsed.database,
|
||||
};
|
||||
if (supportsConnectionParamsForType(type)) {
|
||||
parsedValues.connectionParams = serializeConnectionParams(parsed.params);
|
||||
}
|
||||
|
||||
if (supportsSSLForType(type)) {
|
||||
const normalizeBool = (raw: unknown) => {
|
||||
@@ -1619,12 +1679,7 @@ const ConnectionModal: React.FC<{
|
||||
});
|
||||
|
||||
const getUriPlaceholder = () => {
|
||||
if (
|
||||
dbType === "mysql" ||
|
||||
dbType === "mariadb" ||
|
||||
dbType === "diros" ||
|
||||
dbType === "sphinx"
|
||||
) {
|
||||
if (isMySQLCompatibleType(dbType)) {
|
||||
const defaultPort = getDefaultPortByType(dbType);
|
||||
const scheme = dbType === "diros" ? "doris" : "mysql";
|
||||
return `${scheme}://user:pass@127.0.0.1:${defaultPort},127.0.0.2:${defaultPort}/db_name?topology=replica`;
|
||||
@@ -1649,6 +1704,33 @@ const ConnectionModal: React.FC<{
|
||||
return "例如: postgres://user:pass@127.0.0.1:5432/db_name";
|
||||
};
|
||||
|
||||
const getConnectionParamsPlaceholder = () => {
|
||||
if (isMySQLCompatibleType(dbType)) {
|
||||
return "useUnicode=true&characterEncoding=utf8&autoReconnect=true&useSSL=false";
|
||||
}
|
||||
switch (dbType) {
|
||||
case "postgres":
|
||||
case "kingbase":
|
||||
case "highgo":
|
||||
case "vastbase":
|
||||
return "application_name=GoNavi&statement_timeout=30000";
|
||||
case "oracle":
|
||||
return "PREFETCH_ROWS=5000&TRACE FILE=/tmp/go-ora.trc";
|
||||
case "sqlserver":
|
||||
return "app name=GoNavi&packet size=32767";
|
||||
case "clickhouse":
|
||||
return "max_execution_time=60&compress=lz4";
|
||||
case "mongodb":
|
||||
return "retryWrites=true&readPreference=secondaryPreferred";
|
||||
case "dameng":
|
||||
return "schema=SYSDBA&escapeProcess=true";
|
||||
case "tdengine":
|
||||
return "timezone=Asia%2FShanghai";
|
||||
default:
|
||||
return "key=value&another=value";
|
||||
}
|
||||
};
|
||||
|
||||
const buildUriFromValues = (values: any) => {
|
||||
const type = String(values.type || "")
|
||||
.trim()
|
||||
@@ -1664,12 +1746,7 @@ const ConnectionModal: React.FC<{
|
||||
? `${encodeURIComponent(user)}${password ? `:${encodeURIComponent(password)}` : ""}@`
|
||||
: "";
|
||||
|
||||
if (
|
||||
type === "mysql" ||
|
||||
type === "mariadb" ||
|
||||
type === "diros" ||
|
||||
type === "sphinx"
|
||||
) {
|
||||
if (isMySQLCompatibleType(type)) {
|
||||
const primary = toAddress(host, port, defaultPort);
|
||||
const replicas =
|
||||
values.mysqlTopology === "replica"
|
||||
@@ -1695,6 +1772,7 @@ const ConnectionModal: React.FC<{
|
||||
if (Number.isFinite(timeout) && timeout > 0) {
|
||||
params.set("timeout", String(timeout));
|
||||
}
|
||||
mergeConnectionParams(params, values.connectionParams);
|
||||
const dbPath = database ? `/${encodeURIComponent(database)}` : "/";
|
||||
const query = params.toString();
|
||||
const scheme = type === "diros" ? "doris" : "mysql";
|
||||
@@ -1797,6 +1875,7 @@ const ConnectionModal: React.FC<{
|
||||
params.set("connectTimeoutMS", String(timeout * 1000));
|
||||
params.set("serverSelectionTimeoutMS", String(timeout * 1000));
|
||||
}
|
||||
mergeConnectionParams(params, values.connectionParams);
|
||||
const dbPath = database ? `/${encodeURIComponent(database)}` : "/";
|
||||
const query = params.toString();
|
||||
return `${scheme}://${encodedAuth}${hosts.join(",")}${dbPath}${query ? `?${query}` : ""}`;
|
||||
@@ -1876,6 +1955,9 @@ const ConnectionModal: React.FC<{
|
||||
if (type === "clickhouse" && clickHouseProtocol !== "auto") {
|
||||
params.set("protocol", clickHouseProtocol);
|
||||
}
|
||||
if (supportsConnectionParamsForType(type)) {
|
||||
mergeConnectionParams(params, values.connectionParams);
|
||||
}
|
||||
const query = params.toString();
|
||||
return `${scheme}://${encodedAuth}${toAddress(host, port, defaultPort)}${dbPath}${query ? `?${query}` : ""}`;
|
||||
};
|
||||
@@ -1909,7 +1991,13 @@ const ConnectionModal: React.FC<{
|
||||
});
|
||||
return;
|
||||
}
|
||||
form.setFieldsValue({ ...parsedValues, uri: uriText });
|
||||
form.setFieldsValue(
|
||||
mergeParsedUriValuesForForm(
|
||||
form.getFieldsValue(true),
|
||||
parsedValues,
|
||||
uriText,
|
||||
),
|
||||
);
|
||||
if (testResult) {
|
||||
setTestResult(null);
|
||||
}
|
||||
@@ -2082,6 +2170,11 @@ const ConnectionModal: React.FC<{
|
||||
password: config.password,
|
||||
database: config.database,
|
||||
uri: config.uri || "",
|
||||
connectionParams:
|
||||
config.connectionParams ||
|
||||
(config.uri
|
||||
? parseUriToValues(config.uri, configType)?.connectionParams || ""
|
||||
: ""),
|
||||
clickHouseProtocol:
|
||||
configType === "clickhouse"
|
||||
? normalizeClickHouseProtocolValue(config.clickHouseProtocol)
|
||||
@@ -2294,11 +2387,7 @@ const ConnectionModal: React.FC<{
|
||||
forceClear: !config.useHttpTunnel,
|
||||
});
|
||||
const mysqlReplicaEnabled =
|
||||
(config.type === "mysql" ||
|
||||
config.type === "mariadb" ||
|
||||
config.type === "diros" ||
|
||||
config.type === "sphinx") &&
|
||||
config.topology === "replica";
|
||||
isMySQLCompatibleType(config.type) && config.topology === "replica";
|
||||
const mysqlReplicaDraft = resolveConnectionSecretDraft({
|
||||
hasSecret: initialValues?.hasMySQLReplicaPassword,
|
||||
valueInput: config.mysqlReplicaPassword,
|
||||
@@ -2528,10 +2617,7 @@ const ConnectionModal: React.FC<{
|
||||
}
|
||||
if (
|
||||
clearSecrets.mysqlReplicaPassword &&
|
||||
(values.type === "mysql" ||
|
||||
values.type === "mariadb" ||
|
||||
values.type === "diros" ||
|
||||
values.type === "sphinx") &&
|
||||
isMySQLCompatibleType(values.type) &&
|
||||
values.mysqlTopology === "replica" &&
|
||||
String(values.mysqlReplicaPassword ?? "") === ""
|
||||
) {
|
||||
@@ -2983,12 +3069,7 @@ const ConnectionModal: React.FC<{
|
||||
const savePassword =
|
||||
type === "mongodb" ? mergedValues.savePassword !== false : true;
|
||||
|
||||
if (
|
||||
type === "mysql" ||
|
||||
type === "mariadb" ||
|
||||
type === "diros" ||
|
||||
type === "sphinx"
|
||||
) {
|
||||
if (isMySQLCompatibleType(type)) {
|
||||
const replicas =
|
||||
mergedValues.mysqlTopology === "replica"
|
||||
? normalizeAddressList(mergedValues.mysqlReplicaHosts, defaultPort)
|
||||
@@ -3150,6 +3231,9 @@ const ConnectionModal: React.FC<{
|
||||
httpTunnel: httpTunnelConfig,
|
||||
driver: mergedValues.driver,
|
||||
dsn: mergedValues.dsn,
|
||||
connectionParams: supportsConnectionParamsForType(type)
|
||||
? normalizeConnectionParamsText(mergedValues.connectionParams)
|
||||
: "",
|
||||
timeout: Number(mergedValues.timeout || 30),
|
||||
redisDB: Number.isFinite(Number(mergedValues.redisDB))
|
||||
? Math.max(0, Math.min(15, Math.trunc(Number(mergedValues.redisDB))))
|
||||
@@ -3226,6 +3310,7 @@ const ConnectionModal: React.FC<{
|
||||
httpTunnelPassword: "",
|
||||
timeout: 30,
|
||||
uri: "",
|
||||
connectionParams: "",
|
||||
includeDatabases: undefined,
|
||||
includeRedisDatabases: undefined,
|
||||
mysqlTopology: "single",
|
||||
@@ -3303,6 +3388,7 @@ const ConnectionModal: React.FC<{
|
||||
mongoReplicaUser: "",
|
||||
mongoReplicaPassword: "",
|
||||
redisDB: 0,
|
||||
connectionParams: "",
|
||||
});
|
||||
} else if (type !== "custom") {
|
||||
const defaultUser =
|
||||
@@ -3340,6 +3426,7 @@ const ConnectionModal: React.FC<{
|
||||
mongoReplicaUser: "",
|
||||
mongoReplicaPassword: "",
|
||||
redisDB: 0,
|
||||
connectionParams: "",
|
||||
});
|
||||
}
|
||||
|
||||
@@ -3749,6 +3836,19 @@ const ConnectionModal: React.FC<{
|
||||
placeholder={getUriPlaceholder()}
|
||||
/>
|
||||
</Form.Item>
|
||||
{supportsConnectionParams && (
|
||||
<Form.Item
|
||||
name="connectionParams"
|
||||
label="额外连接参数"
|
||||
help="按当前数据源驱动支持的 URI/DSN query 格式填写;认证密码请使用上方密码字段。"
|
||||
>
|
||||
<Input.TextArea
|
||||
{...noAutoCapInputProps}
|
||||
rows={2}
|
||||
placeholder={getConnectionParamsPlaceholder()}
|
||||
/>
|
||||
</Form.Item>
|
||||
)}
|
||||
<Space
|
||||
size={8}
|
||||
style={{ marginBottom: uriFeedback ? 12 : 16 }}
|
||||
@@ -5926,6 +6026,7 @@ const ConnectionModal: React.FC<{
|
||||
httpTunnelPort: 8080,
|
||||
timeout: 30,
|
||||
uri: "",
|
||||
connectionParams: "",
|
||||
mysqlTopology: "single",
|
||||
redisTopology: "single",
|
||||
mongoTopology: "single",
|
||||
@@ -5971,7 +6072,11 @@ const ConnectionModal: React.FC<{
|
||||
setTestResult(null);
|
||||
setTestErrorLogOpen(false);
|
||||
}
|
||||
if (changed.uri !== undefined || changed.type !== undefined) {
|
||||
if (
|
||||
changed.uri !== undefined ||
|
||||
changed.connectionParams !== undefined ||
|
||||
changed.type !== undefined
|
||||
) {
|
||||
setUriFeedback(null);
|
||||
}
|
||||
if (changed.useSSL !== undefined) {
|
||||
|
||||
@@ -490,6 +490,10 @@ const sanitizeConnectionConfig = (value: unknown): ConnectionConfig => {
|
||||
useHttpTunnel,
|
||||
httpTunnel,
|
||||
uri: toTrimmedString(raw.uri).slice(0, MAX_URI_LENGTH),
|
||||
connectionParams: toTrimmedString(raw.connectionParams).slice(
|
||||
0,
|
||||
MAX_URI_LENGTH,
|
||||
),
|
||||
hosts: sanitizeAddressList(raw.hosts),
|
||||
topology:
|
||||
raw.topology === "replica"
|
||||
|
||||
@@ -294,6 +294,7 @@ export interface ConnectionConfig {
|
||||
httpTunnel?: HTTPTunnelConfig;
|
||||
driver?: string;
|
||||
dsn?: string;
|
||||
connectionParams?: string;
|
||||
timeout?: number;
|
||||
redisDB?: number; // Redis database index (0-15)
|
||||
uri?: string; // Connection URI for copy/paste
|
||||
|
||||
@@ -52,6 +52,19 @@ describe('buildRpcConnectionConfig', () => {
|
||||
expect(result.clickHouseProtocol).toBe('http');
|
||||
});
|
||||
|
||||
it('preserves extra connection params for RPC calls', () => {
|
||||
const result = buildRpcConnectionConfig({
|
||||
id: 'conn-mysql',
|
||||
type: 'mysql',
|
||||
host: 'db.local',
|
||||
port: 3306,
|
||||
user: 'root',
|
||||
connectionParams: 'characterEncoding=utf8&useSSL=false',
|
||||
} as any);
|
||||
|
||||
expect(result.connectionParams).toBe('characterEncoding=utf8&useSSL=false');
|
||||
});
|
||||
|
||||
it('fills default nested config blocks needed by RPC calls', () => {
|
||||
const result = buildRpcConnectionConfig({
|
||||
id: 'conn-redis',
|
||||
|
||||
77
frontend/src/utils/connectionUriMerge.test.ts
Normal file
77
frontend/src/utils/connectionUriMerge.test.ts
Normal file
@@ -0,0 +1,77 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
|
||||
import { mergeParsedUriValuesForForm } from "./connectionUriMerge";
|
||||
|
||||
describe("mergeParsedUriValuesForForm", () => {
|
||||
it("keeps saved credentials when parsed URI has no auth section", () => {
|
||||
const result = mergeParsedUriValuesForForm(
|
||||
{
|
||||
user: "root",
|
||||
password: "saved-password",
|
||||
host: "192.168.1.10",
|
||||
port: 3306,
|
||||
database: "old_db",
|
||||
connectionParams: "application_name=GoNavi",
|
||||
timeout: 30,
|
||||
},
|
||||
{
|
||||
host: "192.168.1.240",
|
||||
port: 3306,
|
||||
user: "",
|
||||
password: "",
|
||||
database: "mkefu_location_dev_local",
|
||||
connectionParams: "",
|
||||
timeout: undefined,
|
||||
useSSL: false,
|
||||
},
|
||||
"jdbc:mysql://192.168.1.240:3306/mkefu_location_dev_local?characterEncoding=UTF-8",
|
||||
);
|
||||
|
||||
expect(result).toMatchObject({
|
||||
uri: "jdbc:mysql://192.168.1.240:3306/mkefu_location_dev_local?characterEncoding=UTF-8",
|
||||
host: "192.168.1.240",
|
||||
port: 3306,
|
||||
database: "mkefu_location_dev_local",
|
||||
useSSL: false,
|
||||
});
|
||||
expect(result).not.toHaveProperty("user");
|
||||
expect(result).not.toHaveProperty("password");
|
||||
expect(result).not.toHaveProperty("connectionParams");
|
||||
expect(result).not.toHaveProperty("timeout");
|
||||
});
|
||||
|
||||
it("allows URI credentials to replace existing credentials when provided", () => {
|
||||
const result = mergeParsedUriValuesForForm(
|
||||
{
|
||||
user: "root",
|
||||
password: "old-password",
|
||||
},
|
||||
{
|
||||
user: "uri_user",
|
||||
password: "uri-password",
|
||||
},
|
||||
"mysql://uri_user:uri-password@127.0.0.1:3306/app",
|
||||
);
|
||||
|
||||
expect(result).toMatchObject({
|
||||
user: "uri_user",
|
||||
password: "uri-password",
|
||||
});
|
||||
});
|
||||
|
||||
it("keeps existing database when URI omits a database path", () => {
|
||||
const result = mergeParsedUriValuesForForm(
|
||||
{
|
||||
database: "saved_db",
|
||||
},
|
||||
{
|
||||
host: "127.0.0.1",
|
||||
database: "",
|
||||
},
|
||||
"mysql://127.0.0.1:3306",
|
||||
);
|
||||
|
||||
expect(result.database).toBeUndefined();
|
||||
expect(result.host).toBe("127.0.0.1");
|
||||
});
|
||||
});
|
||||
36
frontend/src/utils/connectionUriMerge.ts
Normal file
36
frontend/src/utils/connectionUriMerge.ts
Normal file
@@ -0,0 +1,36 @@
|
||||
const EMPTY_PRESERVED_URI_FIELDS = new Set([
|
||||
"user",
|
||||
"password",
|
||||
"database",
|
||||
"connectionParams",
|
||||
]);
|
||||
|
||||
const isEmptyParsedValue = (value: unknown): boolean =>
|
||||
value === undefined ||
|
||||
value === null ||
|
||||
value === "" ||
|
||||
(Array.isArray(value) && value.length === 0);
|
||||
|
||||
export const mergeParsedUriValuesForForm = (
|
||||
currentValues: Record<string, unknown>,
|
||||
parsedValues: Record<string, unknown>,
|
||||
uriText: string,
|
||||
): Record<string, unknown> => {
|
||||
const nextValues: Record<string, unknown> = { uri: uriText };
|
||||
|
||||
Object.entries(parsedValues).forEach(([key, value]) => {
|
||||
if (value === undefined) {
|
||||
return;
|
||||
}
|
||||
if (
|
||||
EMPTY_PRESERVED_URI_FIELDS.has(key) &&
|
||||
isEmptyParsedValue(value) &&
|
||||
!isEmptyParsedValue(currentValues[key])
|
||||
) {
|
||||
return;
|
||||
}
|
||||
nextValues[key] = value;
|
||||
});
|
||||
|
||||
return nextValues;
|
||||
};
|
||||
@@ -667,6 +667,7 @@ export namespace connection {
|
||||
httpTunnel?: HTTPTunnelConfig;
|
||||
driver?: string;
|
||||
dsn?: string;
|
||||
connectionParams?: string;
|
||||
timeout?: number;
|
||||
redisDB?: number;
|
||||
uri?: string;
|
||||
@@ -710,6 +711,7 @@ export namespace connection {
|
||||
this.httpTunnel = this.convertValues(source["httpTunnel"], HTTPTunnelConfig);
|
||||
this.driver = source["driver"];
|
||||
this.dsn = source["dsn"];
|
||||
this.connectionParams = source["connectionParams"];
|
||||
this.timeout = source["timeout"];
|
||||
this.redisDB = source["redisDB"];
|
||||
this.uri = source["uri"];
|
||||
|
||||
@@ -209,6 +209,7 @@ func normalizeCacheKeyConfig(config connection.ConnectionConfig) connection.Conn
|
||||
normalized.User = ""
|
||||
normalized.Password = ""
|
||||
normalized.URI = ""
|
||||
normalized.ConnectionParams = ""
|
||||
normalized.Hosts = nil
|
||||
normalized.Topology = ""
|
||||
normalized.MySQLReplicaUser = ""
|
||||
@@ -450,6 +451,9 @@ func formatConnSummary(config connection.ConnectionConfig) string {
|
||||
if strings.TrimSpace(config.URI) != "" {
|
||||
b.WriteString(fmt.Sprintf(" URI=已配置(长度=%d)", len(config.URI)))
|
||||
}
|
||||
if strings.TrimSpace(config.ConnectionParams) != "" {
|
||||
b.WriteString(fmt.Sprintf(" 连接参数=已配置(长度=%d)", len(config.ConnectionParams)))
|
||||
}
|
||||
if strings.TrimSpace(config.MySQLReplicaUser) != "" {
|
||||
b.WriteString(" MySQL从库凭据=已配置")
|
||||
}
|
||||
|
||||
@@ -81,6 +81,26 @@ func TestGetCacheKey_KeepDatabaseIsolation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCacheKey_KeepConnectionParamsIsolation(t *testing.T) {
|
||||
base := connection.ConnectionConfig{
|
||||
Type: "mysql",
|
||||
Host: "127.0.0.1",
|
||||
Port: 3306,
|
||||
User: "root",
|
||||
Password: "root",
|
||||
Database: "app",
|
||||
ConnectionParams: "charset=utf8",
|
||||
}
|
||||
modified := base
|
||||
modified.ConnectionParams = "charset=utf8mb4"
|
||||
|
||||
left := getCacheKey(base)
|
||||
right := getCacheKey(modified)
|
||||
if left == right {
|
||||
t.Fatalf("expected different cache key for different connection params")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCacheKey_KeepClickHouseProtocolIsolation(t *testing.T) {
|
||||
base := connection.ConnectionConfig{
|
||||
Type: "clickhouse",
|
||||
|
||||
@@ -99,6 +99,7 @@ type ConnectionConfig struct {
|
||||
HTTPTunnel HTTPTunnelConfig `json:"httpTunnel,omitempty"`
|
||||
Driver string `json:"driver,omitempty"` // For custom connection
|
||||
DSN string `json:"dsn,omitempty"` // For custom connection
|
||||
ConnectionParams string `json:"connectionParams,omitempty"` // Extra URI query parameters for built-in drivers
|
||||
Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds (default: 30)
|
||||
RedisDB int `json:"redisDB,omitempty"` // Redis database index (0-15)
|
||||
URI string `json:"uri,omitempty"` // Connection URI for copy/paste
|
||||
|
||||
@@ -106,6 +106,12 @@ func applyClickHouseEndpointURI(config connection.ConnectionConfig, uriText stri
|
||||
if queryProtocol := normalizeClickHouseProtocol(parsed.Query().Get("protocol")); queryProtocol != clickHouseProtocolAuto {
|
||||
config.ClickHouseProtocol = queryProtocol
|
||||
}
|
||||
if parsed.RawQuery != "" {
|
||||
params := url.Values{}
|
||||
mergeConnectionParamValues(params, parsed.Query())
|
||||
mergeConnectionParamValues(params, connectionParamsFromText(config.ConnectionParams))
|
||||
config.ConnectionParams = params.Encode()
|
||||
}
|
||||
endpointProtocol := normalizeClickHouseProtocol(config.ClickHouseProtocol)
|
||||
if isClickHouseHTTPURLScheme(scheme) && endpointProtocol != clickHouseProtocolNative {
|
||||
config.ClickHouseProtocol = clickHouseProtocolHTTP
|
||||
@@ -184,9 +190,148 @@ func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig
|
||||
if tlsConfig := resolveGenericTLSConfig(config); tlsConfig != nil {
|
||||
opts.TLS = tlsConfig
|
||||
}
|
||||
applyClickHouseConnectionParams(opts, config)
|
||||
return opts
|
||||
}
|
||||
|
||||
func parseClickHouseDurationParam(raw string) (time.Duration, bool) {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
return 0, false
|
||||
}
|
||||
if n, err := strconv.Atoi(text); err == nil && n >= 0 {
|
||||
return time.Duration(n) * time.Second, true
|
||||
}
|
||||
duration, err := time.ParseDuration(text)
|
||||
return duration, err == nil
|
||||
}
|
||||
|
||||
func parseClickHouseIntParam(raw string) (int, bool) {
|
||||
n, err := strconv.Atoi(strings.TrimSpace(raw))
|
||||
return n, err == nil
|
||||
}
|
||||
|
||||
func clickHouseSettingValue(raw string) any {
|
||||
text := strings.TrimSpace(raw)
|
||||
switch strings.ToLower(text) {
|
||||
case "true", "yes", "on":
|
||||
return int(1)
|
||||
case "false", "no", "off":
|
||||
return int(0)
|
||||
}
|
||||
if n, err := strconv.Atoi(text); err == nil {
|
||||
return n
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func applyClickHouseCompressionParam(opts *clickhouse.Options, raw string) {
|
||||
value := strings.ToLower(strings.TrimSpace(raw))
|
||||
if value == "" || value == "false" || value == "0" || value == "none" {
|
||||
opts.Compression = &clickhouse.Compression{Method: clickhouse.CompressionNone}
|
||||
return
|
||||
}
|
||||
if opts.Compression == nil {
|
||||
opts.Compression = &clickhouse.Compression{Level: 3}
|
||||
}
|
||||
switch value {
|
||||
case "true", "1", "lz4":
|
||||
opts.Compression.Method = clickhouse.CompressionLZ4
|
||||
case "zstd":
|
||||
opts.Compression.Method = clickhouse.CompressionZSTD
|
||||
case "lz4hc":
|
||||
opts.Compression.Method = clickhouse.CompressionLZ4HC
|
||||
case "gzip":
|
||||
opts.Compression.Method = clickhouse.CompressionGZIP
|
||||
case "deflate":
|
||||
opts.Compression.Method = clickhouse.CompressionDeflate
|
||||
case "br", "brotli":
|
||||
opts.Compression.Method = clickhouse.CompressionBrotli
|
||||
}
|
||||
}
|
||||
|
||||
func applyClickHouseConnectionParams(opts *clickhouse.Options, config connection.ConnectionConfig) {
|
||||
params := url.Values{}
|
||||
mergeConnectionParamsFromConfig(params, config, "clickhouse", "http", "https")
|
||||
if len(params) == 0 {
|
||||
return
|
||||
}
|
||||
if opts.Settings == nil {
|
||||
opts.Settings = clickhouse.Settings{}
|
||||
}
|
||||
keys := make([]string, 0, len(params))
|
||||
for key := range params {
|
||||
if strings.TrimSpace(key) != "" {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
sort.Strings(keys)
|
||||
for _, key := range keys {
|
||||
values := params[key]
|
||||
if len(values) == 0 {
|
||||
continue
|
||||
}
|
||||
value := values[len(values)-1]
|
||||
switch strings.ToLower(strings.TrimSpace(key)) {
|
||||
case "protocol", "secure", "skip_verify", "username", "password", "database":
|
||||
continue
|
||||
case "dial_timeout":
|
||||
if duration, ok := parseClickHouseDurationParam(value); ok {
|
||||
opts.DialTimeout = duration
|
||||
}
|
||||
case "read_timeout":
|
||||
if duration, ok := parseClickHouseDurationParam(value); ok {
|
||||
opts.ReadTimeout = duration
|
||||
}
|
||||
case "compress":
|
||||
applyClickHouseCompressionParam(opts, value)
|
||||
case "compress_level":
|
||||
if level, ok := parseClickHouseIntParam(value); ok {
|
||||
if opts.Compression == nil {
|
||||
opts.Compression = &clickhouse.Compression{Method: clickhouse.CompressionNone}
|
||||
}
|
||||
opts.Compression.Level = level
|
||||
}
|
||||
case "max_open_conns":
|
||||
if n, ok := parseClickHouseIntParam(value); ok {
|
||||
opts.MaxOpenConns = n
|
||||
}
|
||||
case "max_idle_conns":
|
||||
if n, ok := parseClickHouseIntParam(value); ok {
|
||||
opts.MaxIdleConns = n
|
||||
}
|
||||
case "max_compression_buffer":
|
||||
if n, ok := parseClickHouseIntParam(value); ok {
|
||||
opts.MaxCompressionBuffer = n
|
||||
}
|
||||
case "block_buffer_size":
|
||||
if n, ok := parseClickHouseIntParam(value); ok && n > 0 && n <= 255 {
|
||||
opts.BlockBufferSize = uint8(n)
|
||||
}
|
||||
case "http_path":
|
||||
path := strings.TrimSpace(value)
|
||||
if path != "" && !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
opts.HttpUrlPath = path
|
||||
case "connection_open_strategy":
|
||||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||
case "in_order":
|
||||
opts.ConnOpenStrategy = clickhouse.ConnOpenInOrder
|
||||
case "round_robin":
|
||||
opts.ConnOpenStrategy = clickhouse.ConnOpenRoundRobin
|
||||
case "random":
|
||||
opts.ConnOpenStrategy = clickhouse.ConnOpenRandom
|
||||
}
|
||||
default:
|
||||
opts.Settings[key] = clickHouseSettingValue(value)
|
||||
}
|
||||
}
|
||||
if len(opts.Settings) == 0 {
|
||||
opts.Settings = nil
|
||||
}
|
||||
}
|
||||
|
||||
func detectClickHouseProtocol(config connection.ConnectionConfig) clickhouse.Protocol {
|
||||
switch normalizeClickHouseProtocol(config.ClickHouseProtocol) {
|
||||
case clickHouseProtocolHTTP:
|
||||
|
||||
117
internal/db/connection_params.go
Normal file
117
internal/db/connection_params.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
func parseConnectionURI(raw string, allowedSchemes ...string) (*url.URL, bool) {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
return nil, false
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(text), "jdbc:") {
|
||||
text = strings.TrimSpace(text[len("jdbc:"):])
|
||||
}
|
||||
parsed, err := url.Parse(text)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
|
||||
for _, allowed := range allowedSchemes {
|
||||
if scheme == strings.ToLower(strings.TrimSpace(allowed)) {
|
||||
return parsed, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func connectionParamsFromText(raw string) url.Values {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
if queryIndex := strings.Index(text, "?"); queryIndex >= 0 {
|
||||
text = text[queryIndex+1:]
|
||||
}
|
||||
if hashIndex := strings.Index(text, "#"); hashIndex >= 0 {
|
||||
text = text[:hashIndex]
|
||||
}
|
||||
text = strings.TrimLeft(strings.TrimSpace(text), "?&")
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
values, err := url.ParseQuery(text)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
func connectionParamsFromURI(raw string, allowedSchemes ...string) url.Values {
|
||||
parsed, ok := parseConnectionURI(raw, allowedSchemes...)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return parsed.Query()
|
||||
}
|
||||
|
||||
func mergeConnectionParamValues(params url.Values, values url.Values) {
|
||||
if len(values) == 0 {
|
||||
return
|
||||
}
|
||||
keys := make([]string, 0, len(values))
|
||||
for key := range values {
|
||||
if strings.TrimSpace(key) != "" {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
sort.Strings(keys)
|
||||
for _, key := range keys {
|
||||
for _, value := range values[key] {
|
||||
params.Set(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func mergeConnectionParamsFromConfig(params url.Values, config connection.ConnectionConfig, allowedSchemes ...string) {
|
||||
mergeConnectionParamValues(params, connectionParamsFromURI(config.URI, allowedSchemes...))
|
||||
mergeConnectionParamValues(params, connectionParamsFromText(config.ConnectionParams))
|
||||
}
|
||||
|
||||
func mergeConnectionParamsIntoRawURI(raw string, connectionParams string, allowedSchemes ...string) string {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
return text
|
||||
}
|
||||
parsed, ok := parseConnectionURI(text, allowedSchemes...)
|
||||
if !ok {
|
||||
return text
|
||||
}
|
||||
params := parsed.Query()
|
||||
mergeConnectionParamValues(params, connectionParamsFromText(connectionParams))
|
||||
parsed.RawQuery = params.Encode()
|
||||
return parsed.String()
|
||||
}
|
||||
|
||||
func isSafeConnectionParamKey(key string) bool {
|
||||
text := strings.TrimSpace(key)
|
||||
if text == "" {
|
||||
return false
|
||||
}
|
||||
for _, r := range text {
|
||||
if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' {
|
||||
continue
|
||||
}
|
||||
switch r {
|
||||
case '_', '-', '.', ' ':
|
||||
continue
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -48,6 +48,7 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
|
||||
// 达梦驱动要求:密码包含特殊字符时,password 需 PathEscape,并添加 escapeProcess=true 让驱动解码。
|
||||
q.Set("escapeProcess", "true")
|
||||
}
|
||||
mergeConnectionParamsFromConfig(q, config, "dm", "dameng")
|
||||
|
||||
dsn := fmt.Sprintf("dm://%s:%s@%s", config.User, escapedPassword, address)
|
||||
encoded := q.Encode()
|
||||
|
||||
@@ -5,7 +5,6 @@ package db
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
@@ -40,15 +39,8 @@ func applyDirosURI(config connection.ConnectionConfig) connection.ConnectionConf
|
||||
return config
|
||||
}
|
||||
|
||||
lowerURI := strings.ToLower(uriText)
|
||||
if !strings.HasPrefix(lowerURI, "diros://") &&
|
||||
!strings.HasPrefix(lowerURI, "doris://") &&
|
||||
!strings.HasPrefix(lowerURI, "mysql://") {
|
||||
return config
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(uriText)
|
||||
if err != nil {
|
||||
parsed, ok := parseMySQLCompatibleURI(uriText, "diros", "doris", "mysql")
|
||||
if !ok {
|
||||
return config
|
||||
}
|
||||
|
||||
@@ -147,13 +139,7 @@ func (d *DirosDB) getDSN(config connection.ConnectionConfig) (string, error) {
|
||||
protocol = netName
|
||||
}
|
||||
|
||||
timeout := getConnectTimeoutSeconds(config)
|
||||
tlsMode := resolveMySQLTLSMode(config)
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s&multiStatements=true",
|
||||
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode),
|
||||
), nil
|
||||
return buildMySQLCompatibleDSN(config, protocol, address, database), nil
|
||||
}
|
||||
|
||||
func resolveDirosCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {
|
||||
|
||||
@@ -3,11 +3,14 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
|
||||
clickhouse "github.com/ClickHouse/clickhouse-go/v2"
|
||||
)
|
||||
|
||||
func TestPostgresDSN_EscapesPassword(t *testing.T) {
|
||||
@@ -52,6 +55,32 @@ func TestPostgresDSN_SSLModeRequireWhenEnabled(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresDSN_MergesConnectionParams(t *testing.T) {
|
||||
p := &PostgresDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "postgres",
|
||||
Host: "127.0.0.1",
|
||||
Port: 5432,
|
||||
User: "user",
|
||||
Password: "pass",
|
||||
Database: "db",
|
||||
ConnectionParams: "application_name=GoNavi&connect_timeout=9",
|
||||
}
|
||||
|
||||
dsn := p.getDSN(cfg)
|
||||
parsed, err := url.Parse(dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("parse postgres dsn: %v", err)
|
||||
}
|
||||
query := parsed.Query()
|
||||
if got := query.Get("application_name"); got != "GoNavi" {
|
||||
t.Fatalf("application_name = %q, want GoNavi", got)
|
||||
}
|
||||
if got := query.Get("connect_timeout"); got != "9" {
|
||||
t.Fatalf("connect_timeout = %q, want 9", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQLDSN_UsesTLSParamWhenSSLEnabled(t *testing.T) {
|
||||
m := &MySQLDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
@@ -65,7 +94,10 @@ func TestMySQLDSN_UsesTLSParamWhenSSLEnabled(t *testing.T) {
|
||||
SSLMode: "required",
|
||||
}
|
||||
|
||||
dsn := m.getDSN(cfg)
|
||||
dsn, err := m.getDSN(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("getDSN failed: %v", err)
|
||||
}
|
||||
if !strings.Contains(dsn, "tls=true") {
|
||||
t.Fatalf("dsn 缺少 tls=true 参数:%s", dsn)
|
||||
}
|
||||
@@ -161,6 +193,27 @@ func TestKingbaseDSN_QuotesPasswordWithSpaces(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestKingbaseDSN_MergesConnectionParams(t *testing.T) {
|
||||
k := &KingbaseDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "kingbase",
|
||||
Host: "127.0.0.1",
|
||||
Port: 54321,
|
||||
User: "system",
|
||||
Password: "pass",
|
||||
Database: "TEST",
|
||||
ConnectionParams: "application_name=GoNavi&connect_timeout=12",
|
||||
}
|
||||
|
||||
dsn := k.getDSN(cfg)
|
||||
if !strings.Contains(dsn, "application_name=GoNavi") {
|
||||
t.Fatalf("dsn 缺少 application_name:%s", dsn)
|
||||
}
|
||||
if !strings.Contains(dsn, "connect_timeout=12") {
|
||||
t.Fatalf("dsn 缺少自定义 connect_timeout:%s", dsn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTDengineDSN_UsesWebSocketFormat(t *testing.T) {
|
||||
td := &TDengineDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
@@ -197,6 +250,24 @@ func TestTDengineDSN_UsesSecureWebSocketWhenSSLEnabled(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTDengineDSN_MergesConnectionParams(t *testing.T) {
|
||||
td := &TDengineDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "tdengine",
|
||||
Host: "127.0.0.1",
|
||||
Port: 6041,
|
||||
User: "root",
|
||||
Password: "taosdata",
|
||||
Database: "power",
|
||||
ConnectionParams: "timezone=Asia%2FShanghai&protocol=wss",
|
||||
}
|
||||
|
||||
dsn := td.getDSN(cfg)
|
||||
if !strings.Contains(dsn, "?timezone=Asia%2FShanghai") {
|
||||
t.Fatalf("tdengine dsn 缺少自定义参数或错误透传 protocol:%s", dsn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLServerDSN_EncryptMapping(t *testing.T) {
|
||||
s := &SqlServerDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
@@ -219,6 +290,32 @@ func TestSQLServerDSN_EncryptMapping(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLServerDSN_MergesConnectionParams(t *testing.T) {
|
||||
s := &SqlServerDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "sqlserver",
|
||||
Host: "127.0.0.1",
|
||||
Port: 1433,
|
||||
User: "sa",
|
||||
Password: "pass",
|
||||
Database: "master",
|
||||
ConnectionParams: "app name=GoNavi&packet size=32767",
|
||||
}
|
||||
|
||||
dsn := s.getDSN(cfg)
|
||||
parsed, err := url.Parse(dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("parse sqlserver dsn: %v", err)
|
||||
}
|
||||
query := parsed.Query()
|
||||
if got := query.Get("app name"); got != "GoNavi" {
|
||||
t.Fatalf("app name = %q, want GoNavi", got)
|
||||
}
|
||||
if got := query.Get("packet size"); got != "32767" {
|
||||
t.Fatalf("packet size = %q, want 32767", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClickHouseOptions_UsesStructuredTimeoutAndAuth(t *testing.T) {
|
||||
c := &ClickHouseDB{}
|
||||
cfg := normalizeClickHouseConfig(connection.ConnectionConfig{
|
||||
@@ -264,6 +361,34 @@ func TestClickHouseOptions_UsesStructuredTimeoutAndAuth(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClickHouseOptions_MergesConnectionParamsIntoOptionsAndSettings(t *testing.T) {
|
||||
c := &ClickHouseDB{}
|
||||
cfg := normalizeClickHouseConfig(connection.ConnectionConfig{
|
||||
Type: "clickhouse",
|
||||
Host: "127.0.0.1",
|
||||
Port: 9000,
|
||||
User: "default",
|
||||
Password: "secret",
|
||||
Database: "analytics",
|
||||
Timeout: 15,
|
||||
ConnectionParams: "max_execution_time=60&compress=lz4&read_timeout=10s",
|
||||
})
|
||||
|
||||
opts := c.buildClickHouseOptions(cfg)
|
||||
if opts == nil {
|
||||
t.Fatal("options 为空")
|
||||
}
|
||||
if opts.ReadTimeout != 10*time.Second {
|
||||
t.Fatalf("read timeout 不符合预期:%s", opts.ReadTimeout)
|
||||
}
|
||||
if opts.Compression == nil || opts.Compression.Method != clickhouse.CompressionLZ4 {
|
||||
t.Fatalf("compression 不符合预期:%v", opts.Compression)
|
||||
}
|
||||
if got := opts.Settings["max_execution_time"]; got != 60 {
|
||||
t.Fatalf("max_execution_time = %#v, want 60", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClickHouseOptions_ReadTimeoutUsesLargerConfiguredTimeout(t *testing.T) {
|
||||
c := &ClickHouseDB{}
|
||||
cfg := normalizeClickHouseConfig(connection.ConnectionConfig{
|
||||
|
||||
@@ -44,6 +44,7 @@ func (h *HighGoDB) getDSN(config connection.ConnectionConfig) string {
|
||||
q := url.Values{}
|
||||
q.Set("sslmode", resolvePostgresSSLMode(config))
|
||||
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
|
||||
mergeConnectionParamsFromConfig(q, config, "postgres", "postgresql", "highgo")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String()
|
||||
|
||||
@@ -7,7 +7,9 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -62,21 +64,42 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
|
||||
// Kingbase DSN usually similar to Postgres:
|
||||
// host=localhost port=54321 user=system password=... dbname=TEST sslmode=disable
|
||||
|
||||
address := config.Host
|
||||
port := config.Port
|
||||
params := url.Values{}
|
||||
params.Set("host", config.Host)
|
||||
params.Set("port", strconv.Itoa(config.Port))
|
||||
params.Set("user", config.User)
|
||||
params.Set("password", config.Password)
|
||||
params.Set("dbname", config.Database)
|
||||
params.Set("sslmode", resolvePostgresSSLMode(config))
|
||||
params.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
|
||||
mergeConnectionParamsFromConfig(params, config, "kingbase")
|
||||
|
||||
// Construct DSN
|
||||
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s connect_timeout=%d",
|
||||
quoteConnValue(address),
|
||||
port,
|
||||
quoteConnValue(config.User),
|
||||
quoteConnValue(config.Password),
|
||||
quoteConnValue(config.Database),
|
||||
quoteConnValue(resolvePostgresSSLMode(config)),
|
||||
getConnectTimeoutSeconds(config),
|
||||
)
|
||||
preferred := []string{"host", "port", "user", "password", "dbname", "sslmode", "connect_timeout"}
|
||||
seen := make(map[string]struct{}, len(params))
|
||||
parts := make([]string, 0, len(params))
|
||||
for _, key := range preferred {
|
||||
if values, ok := params[key]; ok && len(values) > 0 {
|
||||
parts = append(parts, fmt.Sprintf("%s=%s", key, quoteConnValue(values[len(values)-1])))
|
||||
seen[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
extraKeys := make([]string, 0, len(params))
|
||||
for key := range params {
|
||||
if _, ok := seen[key]; ok || !isSafeConnectionParamKey(key) {
|
||||
continue
|
||||
}
|
||||
extraKeys = append(extraKeys, key)
|
||||
}
|
||||
sort.Strings(extraKeys)
|
||||
for _, key := range extraKeys {
|
||||
values := params[key]
|
||||
if len(values) == 0 {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("%s=%s", key, quoteConnValue(values[len(values)-1])))
|
||||
}
|
||||
|
||||
return dsn
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -37,17 +36,12 @@ func (m *MariaDB) getDSN(config connection.ConnectionConfig) (string, error) {
|
||||
protocol = netName
|
||||
}
|
||||
|
||||
timeout := getConnectTimeoutSeconds(config)
|
||||
tlsMode := resolveMySQLTLSMode(config)
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s&multiStatements=true",
|
||||
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode),
|
||||
), nil
|
||||
return buildMySQLCompatibleDSN(config, protocol, address, database), nil
|
||||
}
|
||||
|
||||
func (m *MariaDB) Connect(config connection.ConnectionConfig) error {
|
||||
dsn, err := m.getDSN(config)
|
||||
runConfig := applyMySQLURI(config)
|
||||
dsn, err := m.getDSN(runConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -192,7 +192,7 @@ func applyMongoURI(config connection.ConnectionConfig) connection.ConnectionConf
|
||||
|
||||
func (m *MongoDB) getURI(config connection.ConnectionConfig) string {
|
||||
if strings.TrimSpace(config.URI) != "" {
|
||||
return strings.TrimSpace(config.URI)
|
||||
return mergeConnectionParamsIntoRawURI(config.URI, config.ConnectionParams, "mongodb", "mongodb+srv")
|
||||
}
|
||||
|
||||
seeds := collectMongoSeeds(config)
|
||||
@@ -257,6 +257,7 @@ func (m *MongoDB) getURI(config connection.ConnectionConfig) string {
|
||||
if authMechanism := strings.TrimSpace(config.MongoAuthMechanism); authMechanism != "" && !noAuth {
|
||||
params.Set("authMechanism", authMechanism)
|
||||
}
|
||||
mergeConnectionParamValues(params, connectionParamsFromText(config.ConnectionParams))
|
||||
|
||||
// 单机模式且未指定副本集名称时,启用 directConnection 避免驱动自动跟随副本集成员发现
|
||||
if strings.TrimSpace(config.Topology) != "replica" && strings.TrimSpace(config.ReplicaSet) == "" && !config.MongoSRV {
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
@@ -37,3 +38,30 @@ func TestApplyMongoURI_ExplicitHostsDoesNotAdoptURIHosts(t *testing.T) {
|
||||
t.Fatalf("expected explicit hosts to stay untouched, got %v", got.Hosts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMongoURI_MergesConnectionParams(t *testing.T) {
|
||||
uri := (&MongoDB{}).getURI(connection.ConnectionConfig{
|
||||
Host: "mongo.local",
|
||||
Port: 27017,
|
||||
Database: "app",
|
||||
ConnectionParams: "retryWrites=true&readPreference=secondaryPreferred",
|
||||
})
|
||||
|
||||
if !strings.Contains(uri, "retryWrites=true") {
|
||||
t.Fatalf("uri 缺少 retryWrites 参数:%s", uri)
|
||||
}
|
||||
if !strings.Contains(uri, "readPreference=secondaryPreferred") {
|
||||
t.Fatalf("uri 缺少 readPreference 参数:%s", uri)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMongoURI_MergesConnectionParamsIntoExistingURI(t *testing.T) {
|
||||
uri := (&MongoDB{}).getURI(connection.ConnectionConfig{
|
||||
URI: "mongodb://mongo.local:27017/app?authSource=admin",
|
||||
ConnectionParams: "retryWrites=true",
|
||||
})
|
||||
|
||||
if !strings.Contains(uri, "authSource=admin") || !strings.Contains(uri, "retryWrites=true") {
|
||||
t.Fatalf("uri 未合并已有 URI query 与额外参数:%s", uri)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -193,7 +193,7 @@ func applyMongoURI(config connection.ConnectionConfig) connection.ConnectionConf
|
||||
|
||||
func (m *MongoDBV1) getURI(config connection.ConnectionConfig) string {
|
||||
if strings.TrimSpace(config.URI) != "" {
|
||||
return strings.TrimSpace(config.URI)
|
||||
return mergeConnectionParamsIntoRawURI(config.URI, config.ConnectionParams, "mongodb", "mongodb+srv")
|
||||
}
|
||||
|
||||
seeds := collectMongoSeeds(config)
|
||||
@@ -258,6 +258,7 @@ func (m *MongoDBV1) getURI(config connection.ConnectionConfig) string {
|
||||
if authMechanism := strings.TrimSpace(config.MongoAuthMechanism); authMechanism != "" && !noAuth {
|
||||
params.Set("authMechanism", authMechanism)
|
||||
}
|
||||
mergeConnectionParamValues(params, connectionParamsFromText(config.ConnectionParams))
|
||||
|
||||
// 单机模式且未指定副本集名称时,启用 directConnection 避免驱动自动跟随副本集成员发现
|
||||
if strings.TrimSpace(config.Topology) != "replica" && strings.TrimSpace(config.ReplicaSet) == "" && !config.MongoSRV {
|
||||
|
||||
153
internal/db/mysql_connection_params_test.go
Normal file
153
internal/db/mysql_connection_params_test.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
func parseMySQLDSNQueryForTest(t *testing.T, dsn string) url.Values {
|
||||
t.Helper()
|
||||
parts := strings.SplitN(dsn, "?", 2)
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("dsn missing query: %s", dsn)
|
||||
}
|
||||
values, err := url.ParseQuery(parts[1])
|
||||
if err != nil {
|
||||
t.Fatalf("parse dsn query: %v", err)
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
func TestMySQLDSN_MergesConnectionParamsWithDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := &MySQLDB{}
|
||||
dsn, err := m.getDSN(connection.ConnectionConfig{
|
||||
Host: "db.local",
|
||||
Port: 3306,
|
||||
User: "root",
|
||||
Password: "secret",
|
||||
Database: "app",
|
||||
Timeout: 30,
|
||||
ConnectionParams: "charset=utf8&readTimeout=10&columnsWithAlias=true",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("getDSN failed: %v", err)
|
||||
}
|
||||
|
||||
query := parseMySQLDSNQueryForTest(t, dsn)
|
||||
if got := query.Get("charset"); got != "utf8" {
|
||||
t.Fatalf("charset should be overridden by connectionParams, got=%q", got)
|
||||
}
|
||||
if got := query.Get("readTimeout"); got != "10s" {
|
||||
t.Fatalf("numeric readTimeout should be converted to duration, got=%q", got)
|
||||
}
|
||||
if got := query.Get("columnsWithAlias"); got != "true" {
|
||||
t.Fatalf("driver-specific parameter should be preserved, got=%q", got)
|
||||
}
|
||||
if got := query.Get("multiStatements"); got != "true" {
|
||||
t.Fatalf("default multiStatements should remain enabled, got=%q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQLDSN_MapsCommonJDBCParamsWithoutLeakingUnsupportedKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := &MySQLDB{}
|
||||
dsn, err := m.getDSN(connection.ConnectionConfig{
|
||||
Host: "192.168.1.1",
|
||||
Port: 3306,
|
||||
User: "root",
|
||||
Database: "app",
|
||||
ConnectionParams: "useUnicode=true&characterEncoding=utf8&autoReconnect=true&" +
|
||||
"useSSL=false&verifyServerCertificate=false&useOldAliasMetadataBehavior=true",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("getDSN failed: %v", err)
|
||||
}
|
||||
|
||||
query := parseMySQLDSNQueryForTest(t, dsn)
|
||||
if got := query.Get("charset"); got != "utf8" {
|
||||
t.Fatalf("characterEncoding should map to charset, got=%q", got)
|
||||
}
|
||||
if got := query.Get("tls"); got != "false" {
|
||||
t.Fatalf("useSSL=false should map to tls=false, got=%q", got)
|
||||
}
|
||||
for _, forbidden := range []string{
|
||||
"useUnicode",
|
||||
"characterEncoding",
|
||||
"autoReconnect",
|
||||
"useSSL",
|
||||
"verifyServerCertificate",
|
||||
"useOldAliasMetadataBehavior",
|
||||
} {
|
||||
if _, exists := query[forbidden]; exists {
|
||||
t.Fatalf("JDBC-only parameter %s should not be passed to Go MySQL driver: %v", forbidden, query)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQLDSN_MapsJDBCUTF8EncodingToMySQLCharset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := &MySQLDB{}
|
||||
dsn, err := m.getDSN(connection.ConnectionConfig{
|
||||
Host: "192.168.1.240",
|
||||
Port: 3306,
|
||||
User: "root",
|
||||
Database: "mkefu_location_dev_local",
|
||||
URI: "jdbc:mysql://192.168.1.240:3306/mkefu_location_dev_local?" +
|
||||
"useUnicode=true&characterEncoding=UTF-8&serverTimezone=GMT%2B8",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("getDSN failed: %v", err)
|
||||
}
|
||||
|
||||
query := parseMySQLDSNQueryForTest(t, dsn)
|
||||
if got := query.Get("charset"); got != "utf8mb4" {
|
||||
t.Fatalf("JDBC characterEncoding=UTF-8 should map to MySQL charset utf8mb4, got=%q", got)
|
||||
}
|
||||
if got := query.Get("characterEncoding"); got != "" {
|
||||
t.Fatalf("JDBC characterEncoding should not be passed to Go MySQL driver, got=%q", got)
|
||||
}
|
||||
if got := query.Get("serverTimezone"); got != "" {
|
||||
t.Fatalf("JDBC serverTimezone should not be passed to Go MySQL driver, got=%q", got)
|
||||
}
|
||||
if got := query.Get("loc"); got != "Asia%2FShanghai" && got != "Asia/Shanghai" {
|
||||
t.Fatalf("serverTimezone=GMT+8 should map to loc=Asia/Shanghai, got=%q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQLDSN_URIParamsAndExplicitParamsPrecedence(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := &MySQLDB{}
|
||||
dsn, err := m.getDSN(connection.ConnectionConfig{
|
||||
Host: "db.local",
|
||||
Port: 3306,
|
||||
User: "root",
|
||||
Database: "app",
|
||||
URI: "jdbc:mysql://db.local:3306/app?characterEncoding=utf8&timeout=15&topology=replica&useSSL=false",
|
||||
ConnectionParams: "charset=utf8mb4&timeout=5s&socketTimeout=45000",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("getDSN failed: %v", err)
|
||||
}
|
||||
|
||||
query := parseMySQLDSNQueryForTest(t, dsn)
|
||||
if got := query.Get("charset"); got != "utf8mb4" {
|
||||
t.Fatalf("connectionParams should override URI charset, got=%q", got)
|
||||
}
|
||||
if got := query.Get("timeout"); got != "5s" {
|
||||
t.Fatalf("connectionParams should override URI timeout, got=%q", got)
|
||||
}
|
||||
if got := query.Get("readTimeout"); got != "45s" {
|
||||
t.Fatalf("socketTimeout should map to readTimeout duration, got=%q", got)
|
||||
}
|
||||
if _, exists := query["topology"]; exists {
|
||||
t.Fatalf("internal topology parameter should not be passed to driver: %v", query)
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -26,6 +27,183 @@ type MySQLDB struct {
|
||||
|
||||
const defaultMySQLPort = 3306
|
||||
|
||||
func parseMySQLCompatibleURI(raw string, allowedSchemes ...string) (*url.URL, bool) {
|
||||
return parseConnectionURI(raw, allowedSchemes...)
|
||||
}
|
||||
|
||||
func mysqlConnectionParamsFromText(raw string) url.Values {
|
||||
return connectionParamsFromText(raw)
|
||||
}
|
||||
|
||||
func parseMySQLBoolParam(raw string) (bool, bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case "1", "true", "yes", "on":
|
||||
return true, true
|
||||
case "0", "false", "no", "off":
|
||||
return false, true
|
||||
default:
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeMySQLDurationParam(raw string, unit time.Duration) string {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
return text
|
||||
}
|
||||
if n, err := strconv.Atoi(text); err == nil && n >= 0 {
|
||||
return (time.Duration(n) * unit).String()
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func normalizeMySQLCharsetParam(raw string) string {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
lower := strings.ToLower(text)
|
||||
switch lower {
|
||||
case "utf-8", "utf_8", "unicode":
|
||||
return "utf8mb4"
|
||||
case "utf8", "utf8mb4", "latin1", "gbk", "gb2312", "gb18030", "big5", "sjis", "cp932":
|
||||
return lower
|
||||
case "iso-8859-1", "iso8859-1", "iso88591":
|
||||
return "latin1"
|
||||
default:
|
||||
return text
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeMySQLServerTimezoneParam(raw string) (string, bool) {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
return "", false
|
||||
}
|
||||
compact := strings.ToUpper(strings.ReplaceAll(text, " ", ""))
|
||||
switch compact {
|
||||
case "LOCAL":
|
||||
return "Local", true
|
||||
case "UTC", "Z", "GMT", "GMT+0", "GMT-0", "GMT+00", "GMT-00", "GMT+00:00", "GMT-00:00",
|
||||
"UTC+0", "UTC-0", "UTC+00", "UTC-00", "UTC+00:00", "UTC-00:00":
|
||||
return "UTC", true
|
||||
case "GMT+8", "GMT+08", "GMT+08:00", "UTC+8", "UTC+08", "UTC+08:00",
|
||||
"ASIA/SHANGHAI", "PRC", "CTT":
|
||||
return "Asia/Shanghai", true
|
||||
}
|
||||
if strings.Contains(text, "/") {
|
||||
if _, err := time.LoadLocation(text); err == nil {
|
||||
return text, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func mergeMySQLConnectionParam(params url.Values, key string, value string) {
|
||||
name := strings.TrimSpace(key)
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
lowerName := strings.ToLower(name)
|
||||
switch lowerName {
|
||||
case "topology":
|
||||
return
|
||||
case "useunicode", "autoreconnect", "useoldaliasmetadatabehavior":
|
||||
return
|
||||
case "charset":
|
||||
if charset := normalizeMySQLCharsetParam(value); charset != "" {
|
||||
params.Set("charset", charset)
|
||||
}
|
||||
return
|
||||
case "characterencoding":
|
||||
if charset := normalizeMySQLCharsetParam(value); charset != "" {
|
||||
params.Set("charset", charset)
|
||||
}
|
||||
return
|
||||
case "servertimezone":
|
||||
if loc, ok := normalizeMySQLServerTimezoneParam(value); ok {
|
||||
params.Set("loc", loc)
|
||||
}
|
||||
return
|
||||
case "usessl":
|
||||
if enabled, ok := parseMySQLBoolParam(value); ok {
|
||||
if enabled {
|
||||
params.Set("tls", "true")
|
||||
} else {
|
||||
params.Set("tls", "false")
|
||||
}
|
||||
}
|
||||
return
|
||||
case "verifyservercertificate":
|
||||
if verified, ok := parseMySQLBoolParam(value); ok && !verified && params.Get("tls") != "false" {
|
||||
params.Set("tls", "skip-verify")
|
||||
}
|
||||
return
|
||||
case "trustservercertificate":
|
||||
if trusted, ok := parseMySQLBoolParam(value); ok && trusted && params.Get("tls") != "false" {
|
||||
params.Set("tls", "skip-verify")
|
||||
}
|
||||
return
|
||||
case "connecttimeout":
|
||||
params.Set("timeout", normalizeMySQLDurationParam(value, time.Millisecond))
|
||||
return
|
||||
case "sockettimeout":
|
||||
params.Set("readTimeout", normalizeMySQLDurationParam(value, time.Millisecond))
|
||||
return
|
||||
case "timeout", "readtimeout", "writetimeout":
|
||||
params.Set(name, normalizeMySQLDurationParam(value, time.Second))
|
||||
return
|
||||
default:
|
||||
params.Set(name, value)
|
||||
}
|
||||
}
|
||||
|
||||
func mergeMySQLConnectionParams(params url.Values, values url.Values) {
|
||||
keys := make([]string, 0, len(values))
|
||||
for key := range values {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
for _, key := range keys {
|
||||
lowerName := strings.ToLower(strings.TrimSpace(key))
|
||||
if lowerName == "verifyservercertificate" || lowerName == "trustservercertificate" {
|
||||
continue
|
||||
}
|
||||
for _, value := range values[key] {
|
||||
mergeMySQLConnectionParam(params, key, value)
|
||||
}
|
||||
}
|
||||
for _, key := range keys {
|
||||
lowerName := strings.ToLower(strings.TrimSpace(key))
|
||||
if lowerName != "verifyservercertificate" && lowerName != "trustservercertificate" {
|
||||
continue
|
||||
}
|
||||
for _, value := range values[key] {
|
||||
mergeMySQLConnectionParam(params, key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, address, database string) string {
|
||||
timeout := getConnectTimeoutSeconds(config)
|
||||
tlsMode := resolveMySQLTLSMode(config)
|
||||
params := url.Values{}
|
||||
params.Set("charset", "utf8mb4")
|
||||
params.Set("parseTime", "True")
|
||||
params.Set("loc", "Local")
|
||||
params.Set("timeout", fmt.Sprintf("%ds", timeout))
|
||||
params.Set("tls", tlsMode)
|
||||
params.Set("multiStatements", "true")
|
||||
if parsed, ok := parseMySQLCompatibleURI(config.URI, "mysql", "doris", "diros"); ok {
|
||||
mergeMySQLConnectionParams(params, parsed.Query())
|
||||
}
|
||||
mergeMySQLConnectionParams(params, mysqlConnectionParamsFromText(config.ConnectionParams))
|
||||
return fmt.Sprintf(
|
||||
"%s:%s@%s(%s)/%s?%s",
|
||||
config.User, config.Password, protocol, address, database, params.Encode(),
|
||||
)
|
||||
}
|
||||
|
||||
func parseHostPortWithDefault(raw string, defaultPort int) (string, int, bool) {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
@@ -135,13 +313,8 @@ func applyMySQLURI(config connection.ConnectionConfig) connection.ConnectionConf
|
||||
if uriText == "" {
|
||||
return config
|
||||
}
|
||||
lowerURI := strings.ToLower(uriText)
|
||||
if !strings.HasPrefix(lowerURI, "mysql://") {
|
||||
return config
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(uriText)
|
||||
if err != nil {
|
||||
parsed, ok := parseMySQLCompatibleURI(uriText, "mysql")
|
||||
if !ok {
|
||||
return config
|
||||
}
|
||||
|
||||
@@ -239,13 +412,7 @@ func (m *MySQLDB) getDSN(config connection.ConnectionConfig) (string, error) {
|
||||
protocol = netName
|
||||
}
|
||||
|
||||
timeout := getConnectTimeoutSeconds(config)
|
||||
tlsMode := resolveMySQLTLSMode(config)
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s&multiStatements=true",
|
||||
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode),
|
||||
), nil
|
||||
return buildMySQLCompatibleDSN(config, protocol, address, database), nil
|
||||
}
|
||||
|
||||
func resolveMySQLCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {
|
||||
|
||||
@@ -30,3 +30,28 @@ func TestOracleGetDSNIncludesQueryPerformanceOptions(t *testing.T) {
|
||||
t.Fatalf("LOB FETCH = %q, want POST", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOracleGetDSNMergesConnectionParams(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dsn := (&OracleDB{}).getDSN(connection.ConnectionConfig{
|
||||
Host: "db.example.com",
|
||||
Port: 1521,
|
||||
User: "scott",
|
||||
Password: "tiger",
|
||||
Database: "ORCLPDB1",
|
||||
ConnectionParams: "PREFETCH_ROWS=5000&TRACE FILE=/tmp/go-ora.trc",
|
||||
})
|
||||
|
||||
parsed, err := url.Parse(dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("解析 Oracle DSN 失败: %v", err)
|
||||
}
|
||||
query := parsed.Query()
|
||||
if got := query.Get("PREFETCH_ROWS"); got != "5000" {
|
||||
t.Fatalf("PREFETCH_ROWS = %q, want 5000", got)
|
||||
}
|
||||
if got := query.Get("TRACE FILE"); got != "/tmp/go-ora.trc" {
|
||||
t.Fatalf("TRACE FILE = %q, want /tmp/go-ora.trc", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,6 +48,7 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
|
||||
q.Set("PREFETCH_ROWS", "10000")
|
||||
// LOB 数据延迟加载,避免大 LOB 列影响普通查询性能
|
||||
q.Set("LOB FETCH", "POST")
|
||||
mergeConnectionParamsFromConfig(q, config, "oracle")
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
u.RawQuery = encoded
|
||||
}
|
||||
|
||||
@@ -64,6 +64,7 @@ func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
|
||||
q := url.Values{}
|
||||
q.Set("sslmode", resolvePostgresSSLMode(config))
|
||||
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
|
||||
mergeConnectionParamsFromConfig(q, config, "postgres", "postgresql")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String()
|
||||
|
||||
@@ -50,6 +50,7 @@ func (s *SqlServerDB) getDSN(config connection.ConnectionConfig) string {
|
||||
encrypt, trustServerCertificate := resolveSQLServerTLSSettings(config)
|
||||
q.Set("encrypt", encrypt)
|
||||
q.Set("TrustServerCertificate", trustServerCertificate)
|
||||
mergeConnectionParamsFromConfig(q, config, "sqlserver")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String()
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -42,7 +43,16 @@ func (t *TDengineDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
|
||||
netType := resolveTDengineNet(config)
|
||||
return fmt.Sprintf("%s:%s@%s(%s)%s", user, pass, netType, net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), path)
|
||||
params := url.Values{}
|
||||
mergeConnectionParamsFromConfig(params, config, "taos", "taosws", "tdengine")
|
||||
params.Del("protocol")
|
||||
params.Del("skip_verify")
|
||||
query := params.Encode()
|
||||
dsn := fmt.Sprintf("%s:%s@%s(%s)%s", user, pass, netType, net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), path)
|
||||
if query == "" {
|
||||
return dsn
|
||||
}
|
||||
return dsn + "?" + query
|
||||
}
|
||||
|
||||
func (t *TDengineDB) Connect(config connection.ConnectionConfig) error {
|
||||
|
||||
@@ -43,6 +43,7 @@ func (v *VastbaseDB) getDSN(config connection.ConnectionConfig) string {
|
||||
q := url.Values{}
|
||||
q.Set("sslmode", resolvePostgresSSLMode(config))
|
||||
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
|
||||
mergeConnectionParamsFromConfig(q, config, "postgres", "postgresql", "vastbase")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String()
|
||||
|
||||
Reference in New Issue
Block a user