feat(connection): 支持多数据源额外连接参数配置

- 前端连接表单新增额外连接参数入口,支持 URI query 格式录入与解析回填
- MySQL 兼容驱动支持 JDBC 常见参数映射,修复 UTF-8 字符集与 serverTimezone 兼容问题
- 扩展 Oracle、PostgreSQL 兼容、SQL Server、ClickHouse、MongoDB、达梦、TDengine 参数合并
- 按不同驱动通道处理 DSN、URI、Options 与 Settings,避免统一透传导致连接异常
- 修复编辑已保存连接时解析无认证 URI 会清空已有账号密码的问题
- 补充连接参数透传、缓存隔离、DSN 合并与 URI 回填回归测试
This commit is contained in:
Syngnat
2026-04-30 10:57:52 +08:00
parent c65e429072
commit c92959f3e8
29 changed files with 1143 additions and 99 deletions

View File

@@ -61,6 +61,7 @@ import {
} from "../utils/connectionModalPresentation";
import { resolveConnectionSecretDraft } from "../utils/connectionSecretDraft";
import { getCustomConnectionDsnValidationMessage } from "../utils/customConnectionDsn";
import { mergeParsedUriValuesForForm } from "../utils/connectionUriMerge";
import { CUSTOM_CONNECTION_DRIVER_HELP } from "../utils/driverImportGuidance";
import {
applyNoAutoCapAttributes,
@@ -97,6 +98,7 @@ type ChoiceCardOption = {
};
type ClickHouseProtocolChoice = "auto" | "http" | "native";
const MAX_URI_LENGTH = 4096;
const MAX_CONNECTION_PARAMS_LENGTH = 4096;
const MAX_URI_HOSTS = 32;
const MAX_TIMEOUT_SECONDS = 3600;
const CONNECTION_MODAL_WIDTH = 960;
@@ -232,6 +234,26 @@ const supportsSSLForType = (type: string) =>
const isFileDatabaseType = (type: string) =>
type === "sqlite" || type === "duckdb";
const isMySQLCompatibleType = (type: string) =>
type === "mysql" ||
type === "mariadb" ||
type === "doris" ||
type === "diros" ||
type === "sphinx";
const supportsConnectionParamsForType = (type: string) =>
isMySQLCompatibleType(type) ||
type === "postgres" ||
type === "kingbase" ||
type === "highgo" ||
type === "vastbase" ||
type === "oracle" ||
type === "sqlserver" ||
type === "clickhouse" ||
type === "mongodb" ||
type === "dameng" ||
type === "tdengine";
type DriverStatusSnapshot = {
type: string;
name: string;
@@ -355,12 +377,8 @@ const ConnectionModal: React.FC<{
}),
[jvmAllowedModes, jvmPreferredMode],
);
const isMySQLLike =
dbType === "mysql" ||
dbType === "mariadb" ||
dbType === "doris" ||
dbType === "diros" ||
dbType === "sphinx";
const isMySQLLike = isMySQLCompatibleType(dbType);
const supportsConnectionParams = supportsConnectionParamsForType(dbType);
const isSSLType = supportsSSLForType(dbType);
const sslHintText = isMySQLLike
? "当 MySQL/MariaDB/Doris/Sphinx 开启安全传输策略时,请启用 SSL本地自签证书场景可先用 Preferred 或 Skip Verify。"
@@ -1047,6 +1065,44 @@ const ConnectionModal: React.FC<{
return text === "1" || text === "true" || text === "yes" || text === "on";
};
const normalizeConnectionParamsText = (raw: unknown) => {
let text = String(raw || "").trim();
if (!text) return "";
const queryIndex = text.indexOf("?");
if (queryIndex >= 0) {
text = text.slice(queryIndex + 1);
}
const hashIndex = text.indexOf("#");
if (hashIndex >= 0) {
text = text.slice(0, hashIndex);
}
return text.replace(/^[?&]+/, "").trim().slice(0, MAX_CONNECTION_PARAMS_LENGTH);
};
const serializeConnectionParams = (params: URLSearchParams) => {
const cloned = new URLSearchParams();
params.forEach((value, key) => {
if (String(key || "").trim()) {
cloned.append(key, value);
}
});
return cloned.toString().slice(0, MAX_CONNECTION_PARAMS_LENGTH);
};
const mergeConnectionParams = (
params: URLSearchParams,
rawParams: unknown,
) => {
const text = normalizeConnectionParamsText(rawParams);
if (!text) return;
const extra = new URLSearchParams(text);
extra.forEach((value, key) => {
if (String(key || "").trim()) {
params.set(key, value);
}
});
};
const normalizeFileDbPath = (rawPath: string): string => {
let pathText = String(rawPath || "").trim();
if (!pathText) {
@@ -1199,6 +1255,7 @@ const ConnectionModal: React.FC<{
clickHouseProtocol: "http",
useSSL: isHttps,
sslMode: isHttps ? (skipVerify ? "skip-verify" : "required") : "disable",
connectionParams: serializeConnectionParams(parsed.params),
};
};
@@ -1214,15 +1271,11 @@ const ConnectionModal: React.FC<{
return null;
}
if (
type === "mysql" ||
type === "mariadb" ||
type === "diros" ||
type === "sphinx"
) {
if (isMySQLCompatibleType(type)) {
const mysqlDefaultPort = getDefaultPortByType(type);
const parsed =
parseMultiHostUri(trimmedUri, "mysql") ||
parseMultiHostUri(trimmedUri, "jdbc:mysql") ||
parseMultiHostUri(trimmedUri, "diros") ||
parseMultiHostUri(trimmedUri, "doris");
if (!parsed) {
@@ -1246,7 +1299,9 @@ const ConnectionModal: React.FC<{
const topology = String(
parsed.params.get("topology") || "",
).toLowerCase();
const tlsValue = String(parsed.params.get("tls") || "")
const tlsValue = String(
parsed.params.get("tls") || parsed.params.get("useSSL") || "",
)
.trim()
.toLowerCase();
const sslMode =
@@ -1268,6 +1323,7 @@ const ConnectionModal: React.FC<{
mysqlTopology:
hostList.length > 1 || topology === "replica" ? "replica" : "single",
mysqlReplicaHosts: hostList.slice(1),
connectionParams: serializeConnectionParams(parsed.params),
timeout:
Number.isFinite(timeoutValue) && timeoutValue > 0
? Math.min(3600, Math.trunc(timeoutValue))
@@ -1414,6 +1470,7 @@ const ConnectionModal: React.FC<{
mongoAuthSource: parsed.params.get("authSource") || "",
mongoReadPreference: parsed.params.get("readPreference") || "primary",
mongoAuthMechanism: parsed.params.get("authMechanism") || "",
connectionParams: serializeConnectionParams(parsed.params),
timeout:
Number.isFinite(timeoutMs) && timeoutMs > 0
? Math.min(MAX_TIMEOUT_SECONDS, Math.ceil(timeoutMs / 1000))
@@ -1450,6 +1507,9 @@ const ConnectionModal: React.FC<{
password: parsed.password,
database: parsed.database,
};
if (supportsConnectionParamsForType(type)) {
parsedValues.connectionParams = serializeConnectionParams(parsed.params);
}
if (supportsSSLForType(type)) {
const normalizeBool = (raw: unknown) => {
@@ -1619,12 +1679,7 @@ const ConnectionModal: React.FC<{
});
const getUriPlaceholder = () => {
if (
dbType === "mysql" ||
dbType === "mariadb" ||
dbType === "diros" ||
dbType === "sphinx"
) {
if (isMySQLCompatibleType(dbType)) {
const defaultPort = getDefaultPortByType(dbType);
const scheme = dbType === "diros" ? "doris" : "mysql";
return `${scheme}://user:pass@127.0.0.1:${defaultPort},127.0.0.2:${defaultPort}/db_name?topology=replica`;
@@ -1649,6 +1704,33 @@ const ConnectionModal: React.FC<{
return "例如: postgres://user:pass@127.0.0.1:5432/db_name";
};
const getConnectionParamsPlaceholder = () => {
if (isMySQLCompatibleType(dbType)) {
return "useUnicode=true&characterEncoding=utf8&autoReconnect=true&useSSL=false";
}
switch (dbType) {
case "postgres":
case "kingbase":
case "highgo":
case "vastbase":
return "application_name=GoNavi&statement_timeout=30000";
case "oracle":
return "PREFETCH_ROWS=5000&TRACE FILE=/tmp/go-ora.trc";
case "sqlserver":
return "app name=GoNavi&packet size=32767";
case "clickhouse":
return "max_execution_time=60&compress=lz4";
case "mongodb":
return "retryWrites=true&readPreference=secondaryPreferred";
case "dameng":
return "schema=SYSDBA&escapeProcess=true";
case "tdengine":
return "timezone=Asia%2FShanghai";
default:
return "key=value&another=value";
}
};
const buildUriFromValues = (values: any) => {
const type = String(values.type || "")
.trim()
@@ -1664,12 +1746,7 @@ const ConnectionModal: React.FC<{
? `${encodeURIComponent(user)}${password ? `:${encodeURIComponent(password)}` : ""}@`
: "";
if (
type === "mysql" ||
type === "mariadb" ||
type === "diros" ||
type === "sphinx"
) {
if (isMySQLCompatibleType(type)) {
const primary = toAddress(host, port, defaultPort);
const replicas =
values.mysqlTopology === "replica"
@@ -1695,6 +1772,7 @@ const ConnectionModal: React.FC<{
if (Number.isFinite(timeout) && timeout > 0) {
params.set("timeout", String(timeout));
}
mergeConnectionParams(params, values.connectionParams);
const dbPath = database ? `/${encodeURIComponent(database)}` : "/";
const query = params.toString();
const scheme = type === "diros" ? "doris" : "mysql";
@@ -1797,6 +1875,7 @@ const ConnectionModal: React.FC<{
params.set("connectTimeoutMS", String(timeout * 1000));
params.set("serverSelectionTimeoutMS", String(timeout * 1000));
}
mergeConnectionParams(params, values.connectionParams);
const dbPath = database ? `/${encodeURIComponent(database)}` : "/";
const query = params.toString();
return `${scheme}://${encodedAuth}${hosts.join(",")}${dbPath}${query ? `?${query}` : ""}`;
@@ -1876,6 +1955,9 @@ const ConnectionModal: React.FC<{
if (type === "clickhouse" && clickHouseProtocol !== "auto") {
params.set("protocol", clickHouseProtocol);
}
if (supportsConnectionParamsForType(type)) {
mergeConnectionParams(params, values.connectionParams);
}
const query = params.toString();
return `${scheme}://${encodedAuth}${toAddress(host, port, defaultPort)}${dbPath}${query ? `?${query}` : ""}`;
};
@@ -1909,7 +1991,13 @@ const ConnectionModal: React.FC<{
});
return;
}
form.setFieldsValue({ ...parsedValues, uri: uriText });
form.setFieldsValue(
mergeParsedUriValuesForForm(
form.getFieldsValue(true),
parsedValues,
uriText,
),
);
if (testResult) {
setTestResult(null);
}
@@ -2082,6 +2170,11 @@ const ConnectionModal: React.FC<{
password: config.password,
database: config.database,
uri: config.uri || "",
connectionParams:
config.connectionParams ||
(config.uri
? parseUriToValues(config.uri, configType)?.connectionParams || ""
: ""),
clickHouseProtocol:
configType === "clickhouse"
? normalizeClickHouseProtocolValue(config.clickHouseProtocol)
@@ -2294,11 +2387,7 @@ const ConnectionModal: React.FC<{
forceClear: !config.useHttpTunnel,
});
const mysqlReplicaEnabled =
(config.type === "mysql" ||
config.type === "mariadb" ||
config.type === "diros" ||
config.type === "sphinx") &&
config.topology === "replica";
isMySQLCompatibleType(config.type) && config.topology === "replica";
const mysqlReplicaDraft = resolveConnectionSecretDraft({
hasSecret: initialValues?.hasMySQLReplicaPassword,
valueInput: config.mysqlReplicaPassword,
@@ -2528,10 +2617,7 @@ const ConnectionModal: React.FC<{
}
if (
clearSecrets.mysqlReplicaPassword &&
(values.type === "mysql" ||
values.type === "mariadb" ||
values.type === "diros" ||
values.type === "sphinx") &&
isMySQLCompatibleType(values.type) &&
values.mysqlTopology === "replica" &&
String(values.mysqlReplicaPassword ?? "") === ""
) {
@@ -2983,12 +3069,7 @@ const ConnectionModal: React.FC<{
const savePassword =
type === "mongodb" ? mergedValues.savePassword !== false : true;
if (
type === "mysql" ||
type === "mariadb" ||
type === "diros" ||
type === "sphinx"
) {
if (isMySQLCompatibleType(type)) {
const replicas =
mergedValues.mysqlTopology === "replica"
? normalizeAddressList(mergedValues.mysqlReplicaHosts, defaultPort)
@@ -3150,6 +3231,9 @@ const ConnectionModal: React.FC<{
httpTunnel: httpTunnelConfig,
driver: mergedValues.driver,
dsn: mergedValues.dsn,
connectionParams: supportsConnectionParamsForType(type)
? normalizeConnectionParamsText(mergedValues.connectionParams)
: "",
timeout: Number(mergedValues.timeout || 30),
redisDB: Number.isFinite(Number(mergedValues.redisDB))
? Math.max(0, Math.min(15, Math.trunc(Number(mergedValues.redisDB))))
@@ -3226,6 +3310,7 @@ const ConnectionModal: React.FC<{
httpTunnelPassword: "",
timeout: 30,
uri: "",
connectionParams: "",
includeDatabases: undefined,
includeRedisDatabases: undefined,
mysqlTopology: "single",
@@ -3303,6 +3388,7 @@ const ConnectionModal: React.FC<{
mongoReplicaUser: "",
mongoReplicaPassword: "",
redisDB: 0,
connectionParams: "",
});
} else if (type !== "custom") {
const defaultUser =
@@ -3340,6 +3426,7 @@ const ConnectionModal: React.FC<{
mongoReplicaUser: "",
mongoReplicaPassword: "",
redisDB: 0,
connectionParams: "",
});
}
@@ -3749,6 +3836,19 @@ const ConnectionModal: React.FC<{
placeholder={getUriPlaceholder()}
/>
</Form.Item>
{supportsConnectionParams && (
<Form.Item
name="connectionParams"
label="额外连接参数"
help="按当前数据源驱动支持的 URI/DSN query 格式填写;认证密码请使用上方密码字段。"
>
<Input.TextArea
{...noAutoCapInputProps}
rows={2}
placeholder={getConnectionParamsPlaceholder()}
/>
</Form.Item>
)}
<Space
size={8}
style={{ marginBottom: uriFeedback ? 12 : 16 }}
@@ -5926,6 +6026,7 @@ const ConnectionModal: React.FC<{
httpTunnelPort: 8080,
timeout: 30,
uri: "",
connectionParams: "",
mysqlTopology: "single",
redisTopology: "single",
mongoTopology: "single",
@@ -5971,7 +6072,11 @@ const ConnectionModal: React.FC<{
setTestResult(null);
setTestErrorLogOpen(false);
}
if (changed.uri !== undefined || changed.type !== undefined) {
if (
changed.uri !== undefined ||
changed.connectionParams !== undefined ||
changed.type !== undefined
) {
setUriFeedback(null);
}
if (changed.useSSL !== undefined) {

View File

@@ -490,6 +490,10 @@ const sanitizeConnectionConfig = (value: unknown): ConnectionConfig => {
useHttpTunnel,
httpTunnel,
uri: toTrimmedString(raw.uri).slice(0, MAX_URI_LENGTH),
connectionParams: toTrimmedString(raw.connectionParams).slice(
0,
MAX_URI_LENGTH,
),
hosts: sanitizeAddressList(raw.hosts),
topology:
raw.topology === "replica"

View File

@@ -294,6 +294,7 @@ export interface ConnectionConfig {
httpTunnel?: HTTPTunnelConfig;
driver?: string;
dsn?: string;
connectionParams?: string;
timeout?: number;
redisDB?: number; // Redis database index (0-15)
uri?: string; // Connection URI for copy/paste

View File

@@ -52,6 +52,19 @@ describe('buildRpcConnectionConfig', () => {
expect(result.clickHouseProtocol).toBe('http');
});
it('preserves extra connection params for RPC calls', () => {
const result = buildRpcConnectionConfig({
id: 'conn-mysql',
type: 'mysql',
host: 'db.local',
port: 3306,
user: 'root',
connectionParams: 'characterEncoding=utf8&useSSL=false',
} as any);
expect(result.connectionParams).toBe('characterEncoding=utf8&useSSL=false');
});
it('fills default nested config blocks needed by RPC calls', () => {
const result = buildRpcConnectionConfig({
id: 'conn-redis',

View File

@@ -0,0 +1,77 @@
import { describe, expect, it } from "vitest";
import { mergeParsedUriValuesForForm } from "./connectionUriMerge";
describe("mergeParsedUriValuesForForm", () => {
it("keeps saved credentials when parsed URI has no auth section", () => {
const result = mergeParsedUriValuesForForm(
{
user: "root",
password: "saved-password",
host: "192.168.1.10",
port: 3306,
database: "old_db",
connectionParams: "application_name=GoNavi",
timeout: 30,
},
{
host: "192.168.1.240",
port: 3306,
user: "",
password: "",
database: "mkefu_location_dev_local",
connectionParams: "",
timeout: undefined,
useSSL: false,
},
"jdbc:mysql://192.168.1.240:3306/mkefu_location_dev_local?characterEncoding=UTF-8",
);
expect(result).toMatchObject({
uri: "jdbc:mysql://192.168.1.240:3306/mkefu_location_dev_local?characterEncoding=UTF-8",
host: "192.168.1.240",
port: 3306,
database: "mkefu_location_dev_local",
useSSL: false,
});
expect(result).not.toHaveProperty("user");
expect(result).not.toHaveProperty("password");
expect(result).not.toHaveProperty("connectionParams");
expect(result).not.toHaveProperty("timeout");
});
it("allows URI credentials to replace existing credentials when provided", () => {
const result = mergeParsedUriValuesForForm(
{
user: "root",
password: "old-password",
},
{
user: "uri_user",
password: "uri-password",
},
"mysql://uri_user:uri-password@127.0.0.1:3306/app",
);
expect(result).toMatchObject({
user: "uri_user",
password: "uri-password",
});
});
it("keeps existing database when URI omits a database path", () => {
const result = mergeParsedUriValuesForForm(
{
database: "saved_db",
},
{
host: "127.0.0.1",
database: "",
},
"mysql://127.0.0.1:3306",
);
expect(result.database).toBeUndefined();
expect(result.host).toBe("127.0.0.1");
});
});

View File

@@ -0,0 +1,36 @@
const EMPTY_PRESERVED_URI_FIELDS = new Set([
"user",
"password",
"database",
"connectionParams",
]);
const isEmptyParsedValue = (value: unknown): boolean =>
value === undefined ||
value === null ||
value === "" ||
(Array.isArray(value) && value.length === 0);
export const mergeParsedUriValuesForForm = (
currentValues: Record<string, unknown>,
parsedValues: Record<string, unknown>,
uriText: string,
): Record<string, unknown> => {
const nextValues: Record<string, unknown> = { uri: uriText };
Object.entries(parsedValues).forEach(([key, value]) => {
if (value === undefined) {
return;
}
if (
EMPTY_PRESERVED_URI_FIELDS.has(key) &&
isEmptyParsedValue(value) &&
!isEmptyParsedValue(currentValues[key])
) {
return;
}
nextValues[key] = value;
});
return nextValues;
};

View File

@@ -667,6 +667,7 @@ export namespace connection {
httpTunnel?: HTTPTunnelConfig;
driver?: string;
dsn?: string;
connectionParams?: string;
timeout?: number;
redisDB?: number;
uri?: string;
@@ -710,6 +711,7 @@ export namespace connection {
this.httpTunnel = this.convertValues(source["httpTunnel"], HTTPTunnelConfig);
this.driver = source["driver"];
this.dsn = source["dsn"];
this.connectionParams = source["connectionParams"];
this.timeout = source["timeout"];
this.redisDB = source["redisDB"];
this.uri = source["uri"];

View File

@@ -209,6 +209,7 @@ func normalizeCacheKeyConfig(config connection.ConnectionConfig) connection.Conn
normalized.User = ""
normalized.Password = ""
normalized.URI = ""
normalized.ConnectionParams = ""
normalized.Hosts = nil
normalized.Topology = ""
normalized.MySQLReplicaUser = ""
@@ -450,6 +451,9 @@ func formatConnSummary(config connection.ConnectionConfig) string {
if strings.TrimSpace(config.URI) != "" {
b.WriteString(fmt.Sprintf(" URI=已配置(长度=%d)", len(config.URI)))
}
if strings.TrimSpace(config.ConnectionParams) != "" {
b.WriteString(fmt.Sprintf(" 连接参数=已配置(长度=%d)", len(config.ConnectionParams)))
}
if strings.TrimSpace(config.MySQLReplicaUser) != "" {
b.WriteString(" MySQL从库凭据=已配置")
}

View File

@@ -81,6 +81,26 @@ func TestGetCacheKey_KeepDatabaseIsolation(t *testing.T) {
}
}
func TestGetCacheKey_KeepConnectionParamsIsolation(t *testing.T) {
base := connection.ConnectionConfig{
Type: "mysql",
Host: "127.0.0.1",
Port: 3306,
User: "root",
Password: "root",
Database: "app",
ConnectionParams: "charset=utf8",
}
modified := base
modified.ConnectionParams = "charset=utf8mb4"
left := getCacheKey(base)
right := getCacheKey(modified)
if left == right {
t.Fatalf("expected different cache key for different connection params")
}
}
func TestGetCacheKey_KeepClickHouseProtocolIsolation(t *testing.T) {
base := connection.ConnectionConfig{
Type: "clickhouse",

View File

@@ -99,6 +99,7 @@ type ConnectionConfig struct {
HTTPTunnel HTTPTunnelConfig `json:"httpTunnel,omitempty"`
Driver string `json:"driver,omitempty"` // For custom connection
DSN string `json:"dsn,omitempty"` // For custom connection
ConnectionParams string `json:"connectionParams,omitempty"` // Extra URI query parameters for built-in drivers
Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds (default: 30)
RedisDB int `json:"redisDB,omitempty"` // Redis database index (0-15)
URI string `json:"uri,omitempty"` // Connection URI for copy/paste

View File

@@ -106,6 +106,12 @@ func applyClickHouseEndpointURI(config connection.ConnectionConfig, uriText stri
if queryProtocol := normalizeClickHouseProtocol(parsed.Query().Get("protocol")); queryProtocol != clickHouseProtocolAuto {
config.ClickHouseProtocol = queryProtocol
}
if parsed.RawQuery != "" {
params := url.Values{}
mergeConnectionParamValues(params, parsed.Query())
mergeConnectionParamValues(params, connectionParamsFromText(config.ConnectionParams))
config.ConnectionParams = params.Encode()
}
endpointProtocol := normalizeClickHouseProtocol(config.ClickHouseProtocol)
if isClickHouseHTTPURLScheme(scheme) && endpointProtocol != clickHouseProtocolNative {
config.ClickHouseProtocol = clickHouseProtocolHTTP
@@ -184,9 +190,148 @@ func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig
if tlsConfig := resolveGenericTLSConfig(config); tlsConfig != nil {
opts.TLS = tlsConfig
}
applyClickHouseConnectionParams(opts, config)
return opts
}
func parseClickHouseDurationParam(raw string) (time.Duration, bool) {
text := strings.TrimSpace(raw)
if text == "" {
return 0, false
}
if n, err := strconv.Atoi(text); err == nil && n >= 0 {
return time.Duration(n) * time.Second, true
}
duration, err := time.ParseDuration(text)
return duration, err == nil
}
func parseClickHouseIntParam(raw string) (int, bool) {
n, err := strconv.Atoi(strings.TrimSpace(raw))
return n, err == nil
}
func clickHouseSettingValue(raw string) any {
text := strings.TrimSpace(raw)
switch strings.ToLower(text) {
case "true", "yes", "on":
return int(1)
case "false", "no", "off":
return int(0)
}
if n, err := strconv.Atoi(text); err == nil {
return n
}
return text
}
func applyClickHouseCompressionParam(opts *clickhouse.Options, raw string) {
value := strings.ToLower(strings.TrimSpace(raw))
if value == "" || value == "false" || value == "0" || value == "none" {
opts.Compression = &clickhouse.Compression{Method: clickhouse.CompressionNone}
return
}
if opts.Compression == nil {
opts.Compression = &clickhouse.Compression{Level: 3}
}
switch value {
case "true", "1", "lz4":
opts.Compression.Method = clickhouse.CompressionLZ4
case "zstd":
opts.Compression.Method = clickhouse.CompressionZSTD
case "lz4hc":
opts.Compression.Method = clickhouse.CompressionLZ4HC
case "gzip":
opts.Compression.Method = clickhouse.CompressionGZIP
case "deflate":
opts.Compression.Method = clickhouse.CompressionDeflate
case "br", "brotli":
opts.Compression.Method = clickhouse.CompressionBrotli
}
}
func applyClickHouseConnectionParams(opts *clickhouse.Options, config connection.ConnectionConfig) {
params := url.Values{}
mergeConnectionParamsFromConfig(params, config, "clickhouse", "http", "https")
if len(params) == 0 {
return
}
if opts.Settings == nil {
opts.Settings = clickhouse.Settings{}
}
keys := make([]string, 0, len(params))
for key := range params {
if strings.TrimSpace(key) != "" {
keys = append(keys, key)
}
}
sort.Strings(keys)
for _, key := range keys {
values := params[key]
if len(values) == 0 {
continue
}
value := values[len(values)-1]
switch strings.ToLower(strings.TrimSpace(key)) {
case "protocol", "secure", "skip_verify", "username", "password", "database":
continue
case "dial_timeout":
if duration, ok := parseClickHouseDurationParam(value); ok {
opts.DialTimeout = duration
}
case "read_timeout":
if duration, ok := parseClickHouseDurationParam(value); ok {
opts.ReadTimeout = duration
}
case "compress":
applyClickHouseCompressionParam(opts, value)
case "compress_level":
if level, ok := parseClickHouseIntParam(value); ok {
if opts.Compression == nil {
opts.Compression = &clickhouse.Compression{Method: clickhouse.CompressionNone}
}
opts.Compression.Level = level
}
case "max_open_conns":
if n, ok := parseClickHouseIntParam(value); ok {
opts.MaxOpenConns = n
}
case "max_idle_conns":
if n, ok := parseClickHouseIntParam(value); ok {
opts.MaxIdleConns = n
}
case "max_compression_buffer":
if n, ok := parseClickHouseIntParam(value); ok {
opts.MaxCompressionBuffer = n
}
case "block_buffer_size":
if n, ok := parseClickHouseIntParam(value); ok && n > 0 && n <= 255 {
opts.BlockBufferSize = uint8(n)
}
case "http_path":
path := strings.TrimSpace(value)
if path != "" && !strings.HasPrefix(path, "/") {
path = "/" + path
}
opts.HttpUrlPath = path
case "connection_open_strategy":
switch strings.ToLower(strings.TrimSpace(value)) {
case "in_order":
opts.ConnOpenStrategy = clickhouse.ConnOpenInOrder
case "round_robin":
opts.ConnOpenStrategy = clickhouse.ConnOpenRoundRobin
case "random":
opts.ConnOpenStrategy = clickhouse.ConnOpenRandom
}
default:
opts.Settings[key] = clickHouseSettingValue(value)
}
}
if len(opts.Settings) == 0 {
opts.Settings = nil
}
}
func detectClickHouseProtocol(config connection.ConnectionConfig) clickhouse.Protocol {
switch normalizeClickHouseProtocol(config.ClickHouseProtocol) {
case clickHouseProtocolHTTP:

View File

@@ -0,0 +1,117 @@
package db
import (
"net/url"
"sort"
"strings"
"GoNavi-Wails/internal/connection"
)
func parseConnectionURI(raw string, allowedSchemes ...string) (*url.URL, bool) {
text := strings.TrimSpace(raw)
if text == "" {
return nil, false
}
if strings.HasPrefix(strings.ToLower(text), "jdbc:") {
text = strings.TrimSpace(text[len("jdbc:"):])
}
parsed, err := url.Parse(text)
if err != nil {
return nil, false
}
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
for _, allowed := range allowedSchemes {
if scheme == strings.ToLower(strings.TrimSpace(allowed)) {
return parsed, true
}
}
return nil, false
}
func connectionParamsFromText(raw string) url.Values {
text := strings.TrimSpace(raw)
if text == "" {
return nil
}
if queryIndex := strings.Index(text, "?"); queryIndex >= 0 {
text = text[queryIndex+1:]
}
if hashIndex := strings.Index(text, "#"); hashIndex >= 0 {
text = text[:hashIndex]
}
text = strings.TrimLeft(strings.TrimSpace(text), "?&")
if text == "" {
return nil
}
values, err := url.ParseQuery(text)
if err != nil {
return nil
}
return values
}
func connectionParamsFromURI(raw string, allowedSchemes ...string) url.Values {
parsed, ok := parseConnectionURI(raw, allowedSchemes...)
if !ok {
return nil
}
return parsed.Query()
}
func mergeConnectionParamValues(params url.Values, values url.Values) {
if len(values) == 0 {
return
}
keys := make([]string, 0, len(values))
for key := range values {
if strings.TrimSpace(key) != "" {
keys = append(keys, key)
}
}
sort.Strings(keys)
for _, key := range keys {
for _, value := range values[key] {
params.Set(key, value)
}
}
}
func mergeConnectionParamsFromConfig(params url.Values, config connection.ConnectionConfig, allowedSchemes ...string) {
mergeConnectionParamValues(params, connectionParamsFromURI(config.URI, allowedSchemes...))
mergeConnectionParamValues(params, connectionParamsFromText(config.ConnectionParams))
}
func mergeConnectionParamsIntoRawURI(raw string, connectionParams string, allowedSchemes ...string) string {
text := strings.TrimSpace(raw)
if text == "" {
return text
}
parsed, ok := parseConnectionURI(text, allowedSchemes...)
if !ok {
return text
}
params := parsed.Query()
mergeConnectionParamValues(params, connectionParamsFromText(connectionParams))
parsed.RawQuery = params.Encode()
return parsed.String()
}
func isSafeConnectionParamKey(key string) bool {
text := strings.TrimSpace(key)
if text == "" {
return false
}
for _, r := range text {
if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' {
continue
}
switch r {
case '_', '-', '.', ' ':
continue
default:
return false
}
}
return true
}

View File

@@ -48,6 +48,7 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
// 达梦驱动要求密码包含特殊字符时password 需 PathEscape并添加 escapeProcess=true 让驱动解码。
q.Set("escapeProcess", "true")
}
mergeConnectionParamsFromConfig(q, config, "dm", "dameng")
dsn := fmt.Sprintf("dm://%s:%s@%s", config.User, escapedPassword, address)
encoded := q.Encode()

View File

@@ -5,7 +5,6 @@ package db
import (
"database/sql"
"fmt"
"net/url"
"strings"
"GoNavi-Wails/internal/connection"
@@ -40,15 +39,8 @@ func applyDirosURI(config connection.ConnectionConfig) connection.ConnectionConf
return config
}
lowerURI := strings.ToLower(uriText)
if !strings.HasPrefix(lowerURI, "diros://") &&
!strings.HasPrefix(lowerURI, "doris://") &&
!strings.HasPrefix(lowerURI, "mysql://") {
return config
}
parsed, err := url.Parse(uriText)
if err != nil {
parsed, ok := parseMySQLCompatibleURI(uriText, "diros", "doris", "mysql")
if !ok {
return config
}
@@ -147,13 +139,7 @@ func (d *DirosDB) getDSN(config connection.ConnectionConfig) (string, error) {
protocol = netName
}
timeout := getConnectTimeoutSeconds(config)
tlsMode := resolveMySQLTLSMode(config)
return fmt.Sprintf(
"%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s&multiStatements=true",
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode),
), nil
return buildMySQLCompatibleDSN(config, protocol, address, database), nil
}
func resolveDirosCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {

View File

@@ -3,11 +3,14 @@
package db
import (
"net/url"
"strings"
"testing"
"time"
"GoNavi-Wails/internal/connection"
clickhouse "github.com/ClickHouse/clickhouse-go/v2"
)
func TestPostgresDSN_EscapesPassword(t *testing.T) {
@@ -52,6 +55,32 @@ func TestPostgresDSN_SSLModeRequireWhenEnabled(t *testing.T) {
}
}
func TestPostgresDSN_MergesConnectionParams(t *testing.T) {
p := &PostgresDB{}
cfg := connection.ConnectionConfig{
Type: "postgres",
Host: "127.0.0.1",
Port: 5432,
User: "user",
Password: "pass",
Database: "db",
ConnectionParams: "application_name=GoNavi&connect_timeout=9",
}
dsn := p.getDSN(cfg)
parsed, err := url.Parse(dsn)
if err != nil {
t.Fatalf("parse postgres dsn: %v", err)
}
query := parsed.Query()
if got := query.Get("application_name"); got != "GoNavi" {
t.Fatalf("application_name = %q, want GoNavi", got)
}
if got := query.Get("connect_timeout"); got != "9" {
t.Fatalf("connect_timeout = %q, want 9", got)
}
}
func TestMySQLDSN_UsesTLSParamWhenSSLEnabled(t *testing.T) {
m := &MySQLDB{}
cfg := connection.ConnectionConfig{
@@ -65,7 +94,10 @@ func TestMySQLDSN_UsesTLSParamWhenSSLEnabled(t *testing.T) {
SSLMode: "required",
}
dsn := m.getDSN(cfg)
dsn, err := m.getDSN(cfg)
if err != nil {
t.Fatalf("getDSN failed: %v", err)
}
if !strings.Contains(dsn, "tls=true") {
t.Fatalf("dsn 缺少 tls=true 参数:%s", dsn)
}
@@ -161,6 +193,27 @@ func TestKingbaseDSN_QuotesPasswordWithSpaces(t *testing.T) {
}
}
func TestKingbaseDSN_MergesConnectionParams(t *testing.T) {
k := &KingbaseDB{}
cfg := connection.ConnectionConfig{
Type: "kingbase",
Host: "127.0.0.1",
Port: 54321,
User: "system",
Password: "pass",
Database: "TEST",
ConnectionParams: "application_name=GoNavi&connect_timeout=12",
}
dsn := k.getDSN(cfg)
if !strings.Contains(dsn, "application_name=GoNavi") {
t.Fatalf("dsn 缺少 application_name%s", dsn)
}
if !strings.Contains(dsn, "connect_timeout=12") {
t.Fatalf("dsn 缺少自定义 connect_timeout%s", dsn)
}
}
func TestTDengineDSN_UsesWebSocketFormat(t *testing.T) {
td := &TDengineDB{}
cfg := connection.ConnectionConfig{
@@ -197,6 +250,24 @@ func TestTDengineDSN_UsesSecureWebSocketWhenSSLEnabled(t *testing.T) {
}
}
func TestTDengineDSN_MergesConnectionParams(t *testing.T) {
td := &TDengineDB{}
cfg := connection.ConnectionConfig{
Type: "tdengine",
Host: "127.0.0.1",
Port: 6041,
User: "root",
Password: "taosdata",
Database: "power",
ConnectionParams: "timezone=Asia%2FShanghai&protocol=wss",
}
dsn := td.getDSN(cfg)
if !strings.Contains(dsn, "?timezone=Asia%2FShanghai") {
t.Fatalf("tdengine dsn 缺少自定义参数或错误透传 protocol%s", dsn)
}
}
func TestSQLServerDSN_EncryptMapping(t *testing.T) {
s := &SqlServerDB{}
cfg := connection.ConnectionConfig{
@@ -219,6 +290,32 @@ func TestSQLServerDSN_EncryptMapping(t *testing.T) {
}
}
func TestSQLServerDSN_MergesConnectionParams(t *testing.T) {
s := &SqlServerDB{}
cfg := connection.ConnectionConfig{
Type: "sqlserver",
Host: "127.0.0.1",
Port: 1433,
User: "sa",
Password: "pass",
Database: "master",
ConnectionParams: "app name=GoNavi&packet size=32767",
}
dsn := s.getDSN(cfg)
parsed, err := url.Parse(dsn)
if err != nil {
t.Fatalf("parse sqlserver dsn: %v", err)
}
query := parsed.Query()
if got := query.Get("app name"); got != "GoNavi" {
t.Fatalf("app name = %q, want GoNavi", got)
}
if got := query.Get("packet size"); got != "32767" {
t.Fatalf("packet size = %q, want 32767", got)
}
}
func TestClickHouseOptions_UsesStructuredTimeoutAndAuth(t *testing.T) {
c := &ClickHouseDB{}
cfg := normalizeClickHouseConfig(connection.ConnectionConfig{
@@ -264,6 +361,34 @@ func TestClickHouseOptions_UsesStructuredTimeoutAndAuth(t *testing.T) {
}
}
func TestClickHouseOptions_MergesConnectionParamsIntoOptionsAndSettings(t *testing.T) {
c := &ClickHouseDB{}
cfg := normalizeClickHouseConfig(connection.ConnectionConfig{
Type: "clickhouse",
Host: "127.0.0.1",
Port: 9000,
User: "default",
Password: "secret",
Database: "analytics",
Timeout: 15,
ConnectionParams: "max_execution_time=60&compress=lz4&read_timeout=10s",
})
opts := c.buildClickHouseOptions(cfg)
if opts == nil {
t.Fatal("options 为空")
}
if opts.ReadTimeout != 10*time.Second {
t.Fatalf("read timeout 不符合预期:%s", opts.ReadTimeout)
}
if opts.Compression == nil || opts.Compression.Method != clickhouse.CompressionLZ4 {
t.Fatalf("compression 不符合预期:%v", opts.Compression)
}
if got := opts.Settings["max_execution_time"]; got != 60 {
t.Fatalf("max_execution_time = %#v, want 60", got)
}
}
func TestClickHouseOptions_ReadTimeoutUsesLargerConfiguredTimeout(t *testing.T) {
c := &ClickHouseDB{}
cfg := normalizeClickHouseConfig(connection.ConnectionConfig{

View File

@@ -44,6 +44,7 @@ func (h *HighGoDB) getDSN(config connection.ConnectionConfig) string {
q := url.Values{}
q.Set("sslmode", resolvePostgresSSLMode(config))
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
mergeConnectionParamsFromConfig(q, config, "postgres", "postgresql", "highgo")
u.RawQuery = q.Encode()
return u.String()

View File

@@ -7,7 +7,9 @@ import (
"database/sql"
"fmt"
"net"
"net/url"
"regexp"
"sort"
"strconv"
"strings"
"time"
@@ -62,21 +64,42 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
// Kingbase DSN usually similar to Postgres:
// host=localhost port=54321 user=system password=... dbname=TEST sslmode=disable
address := config.Host
port := config.Port
params := url.Values{}
params.Set("host", config.Host)
params.Set("port", strconv.Itoa(config.Port))
params.Set("user", config.User)
params.Set("password", config.Password)
params.Set("dbname", config.Database)
params.Set("sslmode", resolvePostgresSSLMode(config))
params.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
mergeConnectionParamsFromConfig(params, config, "kingbase")
// Construct DSN
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s connect_timeout=%d",
quoteConnValue(address),
port,
quoteConnValue(config.User),
quoteConnValue(config.Password),
quoteConnValue(config.Database),
quoteConnValue(resolvePostgresSSLMode(config)),
getConnectTimeoutSeconds(config),
)
preferred := []string{"host", "port", "user", "password", "dbname", "sslmode", "connect_timeout"}
seen := make(map[string]struct{}, len(params))
parts := make([]string, 0, len(params))
for _, key := range preferred {
if values, ok := params[key]; ok && len(values) > 0 {
parts = append(parts, fmt.Sprintf("%s=%s", key, quoteConnValue(values[len(values)-1])))
seen[key] = struct{}{}
}
}
extraKeys := make([]string, 0, len(params))
for key := range params {
if _, ok := seen[key]; ok || !isSafeConnectionParamKey(key) {
continue
}
extraKeys = append(extraKeys, key)
}
sort.Strings(extraKeys)
for _, key := range extraKeys {
values := params[key]
if len(values) == 0 {
continue
}
parts = append(parts, fmt.Sprintf("%s=%s", key, quoteConnValue(values[len(values)-1])))
}
return dsn
return strings.Join(parts, " ")
}
func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {

View File

@@ -6,7 +6,6 @@ import (
"context"
"database/sql"
"fmt"
"net/url"
"strings"
"time"
@@ -37,17 +36,12 @@ func (m *MariaDB) getDSN(config connection.ConnectionConfig) (string, error) {
protocol = netName
}
timeout := getConnectTimeoutSeconds(config)
tlsMode := resolveMySQLTLSMode(config)
return fmt.Sprintf(
"%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s&multiStatements=true",
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode),
), nil
return buildMySQLCompatibleDSN(config, protocol, address, database), nil
}
func (m *MariaDB) Connect(config connection.ConnectionConfig) error {
dsn, err := m.getDSN(config)
runConfig := applyMySQLURI(config)
dsn, err := m.getDSN(runConfig)
if err != nil {
return err
}

View File

@@ -192,7 +192,7 @@ func applyMongoURI(config connection.ConnectionConfig) connection.ConnectionConf
func (m *MongoDB) getURI(config connection.ConnectionConfig) string {
if strings.TrimSpace(config.URI) != "" {
return strings.TrimSpace(config.URI)
return mergeConnectionParamsIntoRawURI(config.URI, config.ConnectionParams, "mongodb", "mongodb+srv")
}
seeds := collectMongoSeeds(config)
@@ -257,6 +257,7 @@ func (m *MongoDB) getURI(config connection.ConnectionConfig) string {
if authMechanism := strings.TrimSpace(config.MongoAuthMechanism); authMechanism != "" && !noAuth {
params.Set("authMechanism", authMechanism)
}
mergeConnectionParamValues(params, connectionParamsFromText(config.ConnectionParams))
// 单机模式且未指定副本集名称时,启用 directConnection 避免驱动自动跟随副本集成员发现
if strings.TrimSpace(config.Topology) != "replica" && strings.TrimSpace(config.ReplicaSet) == "" && !config.MongoSRV {

View File

@@ -3,6 +3,7 @@
package db
import (
"strings"
"testing"
"GoNavi-Wails/internal/connection"
@@ -37,3 +38,30 @@ func TestApplyMongoURI_ExplicitHostsDoesNotAdoptURIHosts(t *testing.T) {
t.Fatalf("expected explicit hosts to stay untouched, got %v", got.Hosts)
}
}
func TestMongoURI_MergesConnectionParams(t *testing.T) {
uri := (&MongoDB{}).getURI(connection.ConnectionConfig{
Host: "mongo.local",
Port: 27017,
Database: "app",
ConnectionParams: "retryWrites=true&readPreference=secondaryPreferred",
})
if !strings.Contains(uri, "retryWrites=true") {
t.Fatalf("uri 缺少 retryWrites 参数:%s", uri)
}
if !strings.Contains(uri, "readPreference=secondaryPreferred") {
t.Fatalf("uri 缺少 readPreference 参数:%s", uri)
}
}
func TestMongoURI_MergesConnectionParamsIntoExistingURI(t *testing.T) {
uri := (&MongoDB{}).getURI(connection.ConnectionConfig{
URI: "mongodb://mongo.local:27017/app?authSource=admin",
ConnectionParams: "retryWrites=true",
})
if !strings.Contains(uri, "authSource=admin") || !strings.Contains(uri, "retryWrites=true") {
t.Fatalf("uri 未合并已有 URI query 与额外参数:%s", uri)
}
}

View File

@@ -193,7 +193,7 @@ func applyMongoURI(config connection.ConnectionConfig) connection.ConnectionConf
func (m *MongoDBV1) getURI(config connection.ConnectionConfig) string {
if strings.TrimSpace(config.URI) != "" {
return strings.TrimSpace(config.URI)
return mergeConnectionParamsIntoRawURI(config.URI, config.ConnectionParams, "mongodb", "mongodb+srv")
}
seeds := collectMongoSeeds(config)
@@ -258,6 +258,7 @@ func (m *MongoDBV1) getURI(config connection.ConnectionConfig) string {
if authMechanism := strings.TrimSpace(config.MongoAuthMechanism); authMechanism != "" && !noAuth {
params.Set("authMechanism", authMechanism)
}
mergeConnectionParamValues(params, connectionParamsFromText(config.ConnectionParams))
// 单机模式且未指定副本集名称时,启用 directConnection 避免驱动自动跟随副本集成员发现
if strings.TrimSpace(config.Topology) != "replica" && strings.TrimSpace(config.ReplicaSet) == "" && !config.MongoSRV {

View File

@@ -0,0 +1,153 @@
package db
import (
"net/url"
"strings"
"testing"
"GoNavi-Wails/internal/connection"
)
func parseMySQLDSNQueryForTest(t *testing.T, dsn string) url.Values {
t.Helper()
parts := strings.SplitN(dsn, "?", 2)
if len(parts) != 2 {
t.Fatalf("dsn missing query: %s", dsn)
}
values, err := url.ParseQuery(parts[1])
if err != nil {
t.Fatalf("parse dsn query: %v", err)
}
return values
}
func TestMySQLDSN_MergesConnectionParamsWithDefaults(t *testing.T) {
t.Parallel()
m := &MySQLDB{}
dsn, err := m.getDSN(connection.ConnectionConfig{
Host: "db.local",
Port: 3306,
User: "root",
Password: "secret",
Database: "app",
Timeout: 30,
ConnectionParams: "charset=utf8&readTimeout=10&columnsWithAlias=true",
})
if err != nil {
t.Fatalf("getDSN failed: %v", err)
}
query := parseMySQLDSNQueryForTest(t, dsn)
if got := query.Get("charset"); got != "utf8" {
t.Fatalf("charset should be overridden by connectionParams, got=%q", got)
}
if got := query.Get("readTimeout"); got != "10s" {
t.Fatalf("numeric readTimeout should be converted to duration, got=%q", got)
}
if got := query.Get("columnsWithAlias"); got != "true" {
t.Fatalf("driver-specific parameter should be preserved, got=%q", got)
}
if got := query.Get("multiStatements"); got != "true" {
t.Fatalf("default multiStatements should remain enabled, got=%q", got)
}
}
func TestMySQLDSN_MapsCommonJDBCParamsWithoutLeakingUnsupportedKeys(t *testing.T) {
t.Parallel()
m := &MySQLDB{}
dsn, err := m.getDSN(connection.ConnectionConfig{
Host: "192.168.1.1",
Port: 3306,
User: "root",
Database: "app",
ConnectionParams: "useUnicode=true&characterEncoding=utf8&autoReconnect=true&" +
"useSSL=false&verifyServerCertificate=false&useOldAliasMetadataBehavior=true",
})
if err != nil {
t.Fatalf("getDSN failed: %v", err)
}
query := parseMySQLDSNQueryForTest(t, dsn)
if got := query.Get("charset"); got != "utf8" {
t.Fatalf("characterEncoding should map to charset, got=%q", got)
}
if got := query.Get("tls"); got != "false" {
t.Fatalf("useSSL=false should map to tls=false, got=%q", got)
}
for _, forbidden := range []string{
"useUnicode",
"characterEncoding",
"autoReconnect",
"useSSL",
"verifyServerCertificate",
"useOldAliasMetadataBehavior",
} {
if _, exists := query[forbidden]; exists {
t.Fatalf("JDBC-only parameter %s should not be passed to Go MySQL driver: %v", forbidden, query)
}
}
}
func TestMySQLDSN_MapsJDBCUTF8EncodingToMySQLCharset(t *testing.T) {
t.Parallel()
m := &MySQLDB{}
dsn, err := m.getDSN(connection.ConnectionConfig{
Host: "192.168.1.240",
Port: 3306,
User: "root",
Database: "mkefu_location_dev_local",
URI: "jdbc:mysql://192.168.1.240:3306/mkefu_location_dev_local?" +
"useUnicode=true&characterEncoding=UTF-8&serverTimezone=GMT%2B8",
})
if err != nil {
t.Fatalf("getDSN failed: %v", err)
}
query := parseMySQLDSNQueryForTest(t, dsn)
if got := query.Get("charset"); got != "utf8mb4" {
t.Fatalf("JDBC characterEncoding=UTF-8 should map to MySQL charset utf8mb4, got=%q", got)
}
if got := query.Get("characterEncoding"); got != "" {
t.Fatalf("JDBC characterEncoding should not be passed to Go MySQL driver, got=%q", got)
}
if got := query.Get("serverTimezone"); got != "" {
t.Fatalf("JDBC serverTimezone should not be passed to Go MySQL driver, got=%q", got)
}
if got := query.Get("loc"); got != "Asia%2FShanghai" && got != "Asia/Shanghai" {
t.Fatalf("serverTimezone=GMT+8 should map to loc=Asia/Shanghai, got=%q", got)
}
}
func TestMySQLDSN_URIParamsAndExplicitParamsPrecedence(t *testing.T) {
t.Parallel()
m := &MySQLDB{}
dsn, err := m.getDSN(connection.ConnectionConfig{
Host: "db.local",
Port: 3306,
User: "root",
Database: "app",
URI: "jdbc:mysql://db.local:3306/app?characterEncoding=utf8&timeout=15&topology=replica&useSSL=false",
ConnectionParams: "charset=utf8mb4&timeout=5s&socketTimeout=45000",
})
if err != nil {
t.Fatalf("getDSN failed: %v", err)
}
query := parseMySQLDSNQueryForTest(t, dsn)
if got := query.Get("charset"); got != "utf8mb4" {
t.Fatalf("connectionParams should override URI charset, got=%q", got)
}
if got := query.Get("timeout"); got != "5s" {
t.Fatalf("connectionParams should override URI timeout, got=%q", got)
}
if got := query.Get("readTimeout"); got != "45s" {
t.Fatalf("socketTimeout should map to readTimeout duration, got=%q", got)
}
if _, exists := query["topology"]; exists {
t.Fatalf("internal topology parameter should not be passed to driver: %v", query)
}
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"math"
"net/url"
"sort"
"strconv"
"strings"
"time"
@@ -26,6 +27,183 @@ type MySQLDB struct {
const defaultMySQLPort = 3306
func parseMySQLCompatibleURI(raw string, allowedSchemes ...string) (*url.URL, bool) {
return parseConnectionURI(raw, allowedSchemes...)
}
func mysqlConnectionParamsFromText(raw string) url.Values {
return connectionParamsFromText(raw)
}
func parseMySQLBoolParam(raw string) (bool, bool) {
switch strings.ToLower(strings.TrimSpace(raw)) {
case "1", "true", "yes", "on":
return true, true
case "0", "false", "no", "off":
return false, true
default:
return false, false
}
}
func normalizeMySQLDurationParam(raw string, unit time.Duration) string {
text := strings.TrimSpace(raw)
if text == "" {
return text
}
if n, err := strconv.Atoi(text); err == nil && n >= 0 {
return (time.Duration(n) * unit).String()
}
return text
}
func normalizeMySQLCharsetParam(raw string) string {
text := strings.TrimSpace(raw)
if text == "" {
return ""
}
lower := strings.ToLower(text)
switch lower {
case "utf-8", "utf_8", "unicode":
return "utf8mb4"
case "utf8", "utf8mb4", "latin1", "gbk", "gb2312", "gb18030", "big5", "sjis", "cp932":
return lower
case "iso-8859-1", "iso8859-1", "iso88591":
return "latin1"
default:
return text
}
}
func normalizeMySQLServerTimezoneParam(raw string) (string, bool) {
text := strings.TrimSpace(raw)
if text == "" {
return "", false
}
compact := strings.ToUpper(strings.ReplaceAll(text, " ", ""))
switch compact {
case "LOCAL":
return "Local", true
case "UTC", "Z", "GMT", "GMT+0", "GMT-0", "GMT+00", "GMT-00", "GMT+00:00", "GMT-00:00",
"UTC+0", "UTC-0", "UTC+00", "UTC-00", "UTC+00:00", "UTC-00:00":
return "UTC", true
case "GMT+8", "GMT+08", "GMT+08:00", "UTC+8", "UTC+08", "UTC+08:00",
"ASIA/SHANGHAI", "PRC", "CTT":
return "Asia/Shanghai", true
}
if strings.Contains(text, "/") {
if _, err := time.LoadLocation(text); err == nil {
return text, true
}
}
return "", false
}
func mergeMySQLConnectionParam(params url.Values, key string, value string) {
name := strings.TrimSpace(key)
if name == "" {
return
}
lowerName := strings.ToLower(name)
switch lowerName {
case "topology":
return
case "useunicode", "autoreconnect", "useoldaliasmetadatabehavior":
return
case "charset":
if charset := normalizeMySQLCharsetParam(value); charset != "" {
params.Set("charset", charset)
}
return
case "characterencoding":
if charset := normalizeMySQLCharsetParam(value); charset != "" {
params.Set("charset", charset)
}
return
case "servertimezone":
if loc, ok := normalizeMySQLServerTimezoneParam(value); ok {
params.Set("loc", loc)
}
return
case "usessl":
if enabled, ok := parseMySQLBoolParam(value); ok {
if enabled {
params.Set("tls", "true")
} else {
params.Set("tls", "false")
}
}
return
case "verifyservercertificate":
if verified, ok := parseMySQLBoolParam(value); ok && !verified && params.Get("tls") != "false" {
params.Set("tls", "skip-verify")
}
return
case "trustservercertificate":
if trusted, ok := parseMySQLBoolParam(value); ok && trusted && params.Get("tls") != "false" {
params.Set("tls", "skip-verify")
}
return
case "connecttimeout":
params.Set("timeout", normalizeMySQLDurationParam(value, time.Millisecond))
return
case "sockettimeout":
params.Set("readTimeout", normalizeMySQLDurationParam(value, time.Millisecond))
return
case "timeout", "readtimeout", "writetimeout":
params.Set(name, normalizeMySQLDurationParam(value, time.Second))
return
default:
params.Set(name, value)
}
}
func mergeMySQLConnectionParams(params url.Values, values url.Values) {
keys := make([]string, 0, len(values))
for key := range values {
keys = append(keys, key)
}
sort.Strings(keys)
for _, key := range keys {
lowerName := strings.ToLower(strings.TrimSpace(key))
if lowerName == "verifyservercertificate" || lowerName == "trustservercertificate" {
continue
}
for _, value := range values[key] {
mergeMySQLConnectionParam(params, key, value)
}
}
for _, key := range keys {
lowerName := strings.ToLower(strings.TrimSpace(key))
if lowerName != "verifyservercertificate" && lowerName != "trustservercertificate" {
continue
}
for _, value := range values[key] {
mergeMySQLConnectionParam(params, key, value)
}
}
}
func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, address, database string) string {
timeout := getConnectTimeoutSeconds(config)
tlsMode := resolveMySQLTLSMode(config)
params := url.Values{}
params.Set("charset", "utf8mb4")
params.Set("parseTime", "True")
params.Set("loc", "Local")
params.Set("timeout", fmt.Sprintf("%ds", timeout))
params.Set("tls", tlsMode)
params.Set("multiStatements", "true")
if parsed, ok := parseMySQLCompatibleURI(config.URI, "mysql", "doris", "diros"); ok {
mergeMySQLConnectionParams(params, parsed.Query())
}
mergeMySQLConnectionParams(params, mysqlConnectionParamsFromText(config.ConnectionParams))
return fmt.Sprintf(
"%s:%s@%s(%s)/%s?%s",
config.User, config.Password, protocol, address, database, params.Encode(),
)
}
func parseHostPortWithDefault(raw string, defaultPort int) (string, int, bool) {
text := strings.TrimSpace(raw)
if text == "" {
@@ -135,13 +313,8 @@ func applyMySQLURI(config connection.ConnectionConfig) connection.ConnectionConf
if uriText == "" {
return config
}
lowerURI := strings.ToLower(uriText)
if !strings.HasPrefix(lowerURI, "mysql://") {
return config
}
parsed, err := url.Parse(uriText)
if err != nil {
parsed, ok := parseMySQLCompatibleURI(uriText, "mysql")
if !ok {
return config
}
@@ -239,13 +412,7 @@ func (m *MySQLDB) getDSN(config connection.ConnectionConfig) (string, error) {
protocol = netName
}
timeout := getConnectTimeoutSeconds(config)
tlsMode := resolveMySQLTLSMode(config)
return fmt.Sprintf(
"%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s&multiStatements=true",
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode),
), nil
return buildMySQLCompatibleDSN(config, protocol, address, database), nil
}
func resolveMySQLCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {

View File

@@ -30,3 +30,28 @@ func TestOracleGetDSNIncludesQueryPerformanceOptions(t *testing.T) {
t.Fatalf("LOB FETCH = %q, want POST", got)
}
}
func TestOracleGetDSNMergesConnectionParams(t *testing.T) {
t.Parallel()
dsn := (&OracleDB{}).getDSN(connection.ConnectionConfig{
Host: "db.example.com",
Port: 1521,
User: "scott",
Password: "tiger",
Database: "ORCLPDB1",
ConnectionParams: "PREFETCH_ROWS=5000&TRACE FILE=/tmp/go-ora.trc",
})
parsed, err := url.Parse(dsn)
if err != nil {
t.Fatalf("解析 Oracle DSN 失败: %v", err)
}
query := parsed.Query()
if got := query.Get("PREFETCH_ROWS"); got != "5000" {
t.Fatalf("PREFETCH_ROWS = %q, want 5000", got)
}
if got := query.Get("TRACE FILE"); got != "/tmp/go-ora.trc" {
t.Fatalf("TRACE FILE = %q, want /tmp/go-ora.trc", got)
}
}

View File

@@ -48,6 +48,7 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
q.Set("PREFETCH_ROWS", "10000")
// LOB 数据延迟加载,避免大 LOB 列影响普通查询性能
q.Set("LOB FETCH", "POST")
mergeConnectionParamsFromConfig(q, config, "oracle")
if encoded := q.Encode(); encoded != "" {
u.RawQuery = encoded
}

View File

@@ -64,6 +64,7 @@ func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
q := url.Values{}
q.Set("sslmode", resolvePostgresSSLMode(config))
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
mergeConnectionParamsFromConfig(q, config, "postgres", "postgresql")
u.RawQuery = q.Encode()
return u.String()

View File

@@ -50,6 +50,7 @@ func (s *SqlServerDB) getDSN(config connection.ConnectionConfig) string {
encrypt, trustServerCertificate := resolveSQLServerTLSSettings(config)
q.Set("encrypt", encrypt)
q.Set("TrustServerCertificate", trustServerCertificate)
mergeConnectionParamsFromConfig(q, config, "sqlserver")
u.RawQuery = q.Encode()
return u.String()

View File

@@ -7,6 +7,7 @@ import (
"database/sql"
"fmt"
"net"
"net/url"
"sort"
"strconv"
"strings"
@@ -42,7 +43,16 @@ func (t *TDengineDB) getDSN(config connection.ConnectionConfig) string {
}
netType := resolveTDengineNet(config)
return fmt.Sprintf("%s:%s@%s(%s)%s", user, pass, netType, net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), path)
params := url.Values{}
mergeConnectionParamsFromConfig(params, config, "taos", "taosws", "tdengine")
params.Del("protocol")
params.Del("skip_verify")
query := params.Encode()
dsn := fmt.Sprintf("%s:%s@%s(%s)%s", user, pass, netType, net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), path)
if query == "" {
return dsn
}
return dsn + "?" + query
}
func (t *TDengineDB) Connect(config connection.ConnectionConfig) error {

View File

@@ -43,6 +43,7 @@ func (v *VastbaseDB) getDSN(config connection.ConnectionConfig) string {
q := url.Values{}
q.Set("sslmode", resolvePostgresSSLMode(config))
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
mergeConnectionParamsFromConfig(q, config, "postgres", "postgresql", "vastbase")
u.RawQuery = q.Encode()
return u.String()