diff --git a/frontend/src/components/ConnectionModal.tsx b/frontend/src/components/ConnectionModal.tsx index 2ea26d2..77341ce 100644 --- a/frontend/src/components/ConnectionModal.tsx +++ b/frontend/src/components/ConnectionModal.tsx @@ -61,6 +61,7 @@ import { } from "../utils/connectionModalPresentation"; import { resolveConnectionSecretDraft } from "../utils/connectionSecretDraft"; import { getCustomConnectionDsnValidationMessage } from "../utils/customConnectionDsn"; +import { mergeParsedUriValuesForForm } from "../utils/connectionUriMerge"; import { CUSTOM_CONNECTION_DRIVER_HELP } from "../utils/driverImportGuidance"; import { applyNoAutoCapAttributes, @@ -97,6 +98,7 @@ type ChoiceCardOption = { }; type ClickHouseProtocolChoice = "auto" | "http" | "native"; const MAX_URI_LENGTH = 4096; +const MAX_CONNECTION_PARAMS_LENGTH = 4096; const MAX_URI_HOSTS = 32; const MAX_TIMEOUT_SECONDS = 3600; const CONNECTION_MODAL_WIDTH = 960; @@ -232,6 +234,26 @@ const supportsSSLForType = (type: string) => const isFileDatabaseType = (type: string) => type === "sqlite" || type === "duckdb"; +const isMySQLCompatibleType = (type: string) => + type === "mysql" || + type === "mariadb" || + type === "doris" || + type === "diros" || + type === "sphinx"; + +const supportsConnectionParamsForType = (type: string) => + isMySQLCompatibleType(type) || + type === "postgres" || + type === "kingbase" || + type === "highgo" || + type === "vastbase" || + type === "oracle" || + type === "sqlserver" || + type === "clickhouse" || + type === "mongodb" || + type === "dameng" || + type === "tdengine"; + type DriverStatusSnapshot = { type: string; name: string; @@ -355,12 +377,8 @@ const ConnectionModal: React.FC<{ }), [jvmAllowedModes, jvmPreferredMode], ); - const isMySQLLike = - dbType === "mysql" || - dbType === "mariadb" || - dbType === "doris" || - dbType === "diros" || - dbType === "sphinx"; + const isMySQLLike = isMySQLCompatibleType(dbType); + const supportsConnectionParams = supportsConnectionParamsForType(dbType); const isSSLType = supportsSSLForType(dbType); const sslHintText = isMySQLLike ? "当 MySQL/MariaDB/Doris/Sphinx 开启安全传输策略时,请启用 SSL;本地自签证书场景可先用 Preferred 或 Skip Verify。" @@ -1047,6 +1065,44 @@ const ConnectionModal: React.FC<{ return text === "1" || text === "true" || text === "yes" || text === "on"; }; + const normalizeConnectionParamsText = (raw: unknown) => { + let text = String(raw || "").trim(); + if (!text) return ""; + const queryIndex = text.indexOf("?"); + if (queryIndex >= 0) { + text = text.slice(queryIndex + 1); + } + const hashIndex = text.indexOf("#"); + if (hashIndex >= 0) { + text = text.slice(0, hashIndex); + } + return text.replace(/^[?&]+/, "").trim().slice(0, MAX_CONNECTION_PARAMS_LENGTH); + }; + + const serializeConnectionParams = (params: URLSearchParams) => { + const cloned = new URLSearchParams(); + params.forEach((value, key) => { + if (String(key || "").trim()) { + cloned.append(key, value); + } + }); + return cloned.toString().slice(0, MAX_CONNECTION_PARAMS_LENGTH); + }; + + const mergeConnectionParams = ( + params: URLSearchParams, + rawParams: unknown, + ) => { + const text = normalizeConnectionParamsText(rawParams); + if (!text) return; + const extra = new URLSearchParams(text); + extra.forEach((value, key) => { + if (String(key || "").trim()) { + params.set(key, value); + } + }); + }; + const normalizeFileDbPath = (rawPath: string): string => { let pathText = String(rawPath || "").trim(); if (!pathText) { @@ -1199,6 +1255,7 @@ const ConnectionModal: React.FC<{ clickHouseProtocol: "http", useSSL: isHttps, sslMode: isHttps ? (skipVerify ? "skip-verify" : "required") : "disable", + connectionParams: serializeConnectionParams(parsed.params), }; }; @@ -1214,15 +1271,11 @@ const ConnectionModal: React.FC<{ return null; } - if ( - type === "mysql" || - type === "mariadb" || - type === "diros" || - type === "sphinx" - ) { + if (isMySQLCompatibleType(type)) { const mysqlDefaultPort = getDefaultPortByType(type); const parsed = parseMultiHostUri(trimmedUri, "mysql") || + parseMultiHostUri(trimmedUri, "jdbc:mysql") || parseMultiHostUri(trimmedUri, "diros") || parseMultiHostUri(trimmedUri, "doris"); if (!parsed) { @@ -1246,7 +1299,9 @@ const ConnectionModal: React.FC<{ const topology = String( parsed.params.get("topology") || "", ).toLowerCase(); - const tlsValue = String(parsed.params.get("tls") || "") + const tlsValue = String( + parsed.params.get("tls") || parsed.params.get("useSSL") || "", + ) .trim() .toLowerCase(); const sslMode = @@ -1268,6 +1323,7 @@ const ConnectionModal: React.FC<{ mysqlTopology: hostList.length > 1 || topology === "replica" ? "replica" : "single", mysqlReplicaHosts: hostList.slice(1), + connectionParams: serializeConnectionParams(parsed.params), timeout: Number.isFinite(timeoutValue) && timeoutValue > 0 ? Math.min(3600, Math.trunc(timeoutValue)) @@ -1414,6 +1470,7 @@ const ConnectionModal: React.FC<{ mongoAuthSource: parsed.params.get("authSource") || "", mongoReadPreference: parsed.params.get("readPreference") || "primary", mongoAuthMechanism: parsed.params.get("authMechanism") || "", + connectionParams: serializeConnectionParams(parsed.params), timeout: Number.isFinite(timeoutMs) && timeoutMs > 0 ? Math.min(MAX_TIMEOUT_SECONDS, Math.ceil(timeoutMs / 1000)) @@ -1450,6 +1507,9 @@ const ConnectionModal: React.FC<{ password: parsed.password, database: parsed.database, }; + if (supportsConnectionParamsForType(type)) { + parsedValues.connectionParams = serializeConnectionParams(parsed.params); + } if (supportsSSLForType(type)) { const normalizeBool = (raw: unknown) => { @@ -1619,12 +1679,7 @@ const ConnectionModal: React.FC<{ }); const getUriPlaceholder = () => { - if ( - dbType === "mysql" || - dbType === "mariadb" || - dbType === "diros" || - dbType === "sphinx" - ) { + if (isMySQLCompatibleType(dbType)) { const defaultPort = getDefaultPortByType(dbType); const scheme = dbType === "diros" ? "doris" : "mysql"; return `${scheme}://user:pass@127.0.0.1:${defaultPort},127.0.0.2:${defaultPort}/db_name?topology=replica`; @@ -1649,6 +1704,33 @@ const ConnectionModal: React.FC<{ return "例如: postgres://user:pass@127.0.0.1:5432/db_name"; }; + const getConnectionParamsPlaceholder = () => { + if (isMySQLCompatibleType(dbType)) { + return "useUnicode=true&characterEncoding=utf8&autoReconnect=true&useSSL=false"; + } + switch (dbType) { + case "postgres": + case "kingbase": + case "highgo": + case "vastbase": + return "application_name=GoNavi&statement_timeout=30000"; + case "oracle": + return "PREFETCH_ROWS=5000&TRACE FILE=/tmp/go-ora.trc"; + case "sqlserver": + return "app name=GoNavi&packet size=32767"; + case "clickhouse": + return "max_execution_time=60&compress=lz4"; + case "mongodb": + return "retryWrites=true&readPreference=secondaryPreferred"; + case "dameng": + return "schema=SYSDBA&escapeProcess=true"; + case "tdengine": + return "timezone=Asia%2FShanghai"; + default: + return "key=value&another=value"; + } + }; + const buildUriFromValues = (values: any) => { const type = String(values.type || "") .trim() @@ -1664,12 +1746,7 @@ const ConnectionModal: React.FC<{ ? `${encodeURIComponent(user)}${password ? `:${encodeURIComponent(password)}` : ""}@` : ""; - if ( - type === "mysql" || - type === "mariadb" || - type === "diros" || - type === "sphinx" - ) { + if (isMySQLCompatibleType(type)) { const primary = toAddress(host, port, defaultPort); const replicas = values.mysqlTopology === "replica" @@ -1695,6 +1772,7 @@ const ConnectionModal: React.FC<{ if (Number.isFinite(timeout) && timeout > 0) { params.set("timeout", String(timeout)); } + mergeConnectionParams(params, values.connectionParams); const dbPath = database ? `/${encodeURIComponent(database)}` : "/"; const query = params.toString(); const scheme = type === "diros" ? "doris" : "mysql"; @@ -1797,6 +1875,7 @@ const ConnectionModal: React.FC<{ params.set("connectTimeoutMS", String(timeout * 1000)); params.set("serverSelectionTimeoutMS", String(timeout * 1000)); } + mergeConnectionParams(params, values.connectionParams); const dbPath = database ? `/${encodeURIComponent(database)}` : "/"; const query = params.toString(); return `${scheme}://${encodedAuth}${hosts.join(",")}${dbPath}${query ? `?${query}` : ""}`; @@ -1876,6 +1955,9 @@ const ConnectionModal: React.FC<{ if (type === "clickhouse" && clickHouseProtocol !== "auto") { params.set("protocol", clickHouseProtocol); } + if (supportsConnectionParamsForType(type)) { + mergeConnectionParams(params, values.connectionParams); + } const query = params.toString(); return `${scheme}://${encodedAuth}${toAddress(host, port, defaultPort)}${dbPath}${query ? `?${query}` : ""}`; }; @@ -1909,7 +1991,13 @@ const ConnectionModal: React.FC<{ }); return; } - form.setFieldsValue({ ...parsedValues, uri: uriText }); + form.setFieldsValue( + mergeParsedUriValuesForForm( + form.getFieldsValue(true), + parsedValues, + uriText, + ), + ); if (testResult) { setTestResult(null); } @@ -2082,6 +2170,11 @@ const ConnectionModal: React.FC<{ password: config.password, database: config.database, uri: config.uri || "", + connectionParams: + config.connectionParams || + (config.uri + ? parseUriToValues(config.uri, configType)?.connectionParams || "" + : ""), clickHouseProtocol: configType === "clickhouse" ? normalizeClickHouseProtocolValue(config.clickHouseProtocol) @@ -2294,11 +2387,7 @@ const ConnectionModal: React.FC<{ forceClear: !config.useHttpTunnel, }); const mysqlReplicaEnabled = - (config.type === "mysql" || - config.type === "mariadb" || - config.type === "diros" || - config.type === "sphinx") && - config.topology === "replica"; + isMySQLCompatibleType(config.type) && config.topology === "replica"; const mysqlReplicaDraft = resolveConnectionSecretDraft({ hasSecret: initialValues?.hasMySQLReplicaPassword, valueInput: config.mysqlReplicaPassword, @@ -2528,10 +2617,7 @@ const ConnectionModal: React.FC<{ } if ( clearSecrets.mysqlReplicaPassword && - (values.type === "mysql" || - values.type === "mariadb" || - values.type === "diros" || - values.type === "sphinx") && + isMySQLCompatibleType(values.type) && values.mysqlTopology === "replica" && String(values.mysqlReplicaPassword ?? "") === "" ) { @@ -2983,12 +3069,7 @@ const ConnectionModal: React.FC<{ const savePassword = type === "mongodb" ? mergedValues.savePassword !== false : true; - if ( - type === "mysql" || - type === "mariadb" || - type === "diros" || - type === "sphinx" - ) { + if (isMySQLCompatibleType(type)) { const replicas = mergedValues.mysqlTopology === "replica" ? normalizeAddressList(mergedValues.mysqlReplicaHosts, defaultPort) @@ -3150,6 +3231,9 @@ const ConnectionModal: React.FC<{ httpTunnel: httpTunnelConfig, driver: mergedValues.driver, dsn: mergedValues.dsn, + connectionParams: supportsConnectionParamsForType(type) + ? normalizeConnectionParamsText(mergedValues.connectionParams) + : "", timeout: Number(mergedValues.timeout || 30), redisDB: Number.isFinite(Number(mergedValues.redisDB)) ? Math.max(0, Math.min(15, Math.trunc(Number(mergedValues.redisDB)))) @@ -3226,6 +3310,7 @@ const ConnectionModal: React.FC<{ httpTunnelPassword: "", timeout: 30, uri: "", + connectionParams: "", includeDatabases: undefined, includeRedisDatabases: undefined, mysqlTopology: "single", @@ -3303,6 +3388,7 @@ const ConnectionModal: React.FC<{ mongoReplicaUser: "", mongoReplicaPassword: "", redisDB: 0, + connectionParams: "", }); } else if (type !== "custom") { const defaultUser = @@ -3340,6 +3426,7 @@ const ConnectionModal: React.FC<{ mongoReplicaUser: "", mongoReplicaPassword: "", redisDB: 0, + connectionParams: "", }); } @@ -3749,6 +3836,19 @@ const ConnectionModal: React.FC<{ placeholder={getUriPlaceholder()} /> + {supportsConnectionParams && ( + + + + )} { useHttpTunnel, httpTunnel, uri: toTrimmedString(raw.uri).slice(0, MAX_URI_LENGTH), + connectionParams: toTrimmedString(raw.connectionParams).slice( + 0, + MAX_URI_LENGTH, + ), hosts: sanitizeAddressList(raw.hosts), topology: raw.topology === "replica" diff --git a/frontend/src/types.ts b/frontend/src/types.ts index 9067cc0..c952504 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -294,6 +294,7 @@ export interface ConnectionConfig { httpTunnel?: HTTPTunnelConfig; driver?: string; dsn?: string; + connectionParams?: string; timeout?: number; redisDB?: number; // Redis database index (0-15) uri?: string; // Connection URI for copy/paste diff --git a/frontend/src/utils/connectionRpcConfig.test.ts b/frontend/src/utils/connectionRpcConfig.test.ts index 05d4e41..90cfdec 100644 --- a/frontend/src/utils/connectionRpcConfig.test.ts +++ b/frontend/src/utils/connectionRpcConfig.test.ts @@ -52,6 +52,19 @@ describe('buildRpcConnectionConfig', () => { expect(result.clickHouseProtocol).toBe('http'); }); + it('preserves extra connection params for RPC calls', () => { + const result = buildRpcConnectionConfig({ + id: 'conn-mysql', + type: 'mysql', + host: 'db.local', + port: 3306, + user: 'root', + connectionParams: 'characterEncoding=utf8&useSSL=false', + } as any); + + expect(result.connectionParams).toBe('characterEncoding=utf8&useSSL=false'); + }); + it('fills default nested config blocks needed by RPC calls', () => { const result = buildRpcConnectionConfig({ id: 'conn-redis', diff --git a/frontend/src/utils/connectionUriMerge.test.ts b/frontend/src/utils/connectionUriMerge.test.ts new file mode 100644 index 0000000..b679bda --- /dev/null +++ b/frontend/src/utils/connectionUriMerge.test.ts @@ -0,0 +1,77 @@ +import { describe, expect, it } from "vitest"; + +import { mergeParsedUriValuesForForm } from "./connectionUriMerge"; + +describe("mergeParsedUriValuesForForm", () => { + it("keeps saved credentials when parsed URI has no auth section", () => { + const result = mergeParsedUriValuesForForm( + { + user: "root", + password: "saved-password", + host: "192.168.1.10", + port: 3306, + database: "old_db", + connectionParams: "application_name=GoNavi", + timeout: 30, + }, + { + host: "192.168.1.240", + port: 3306, + user: "", + password: "", + database: "mkefu_location_dev_local", + connectionParams: "", + timeout: undefined, + useSSL: false, + }, + "jdbc:mysql://192.168.1.240:3306/mkefu_location_dev_local?characterEncoding=UTF-8", + ); + + expect(result).toMatchObject({ + uri: "jdbc:mysql://192.168.1.240:3306/mkefu_location_dev_local?characterEncoding=UTF-8", + host: "192.168.1.240", + port: 3306, + database: "mkefu_location_dev_local", + useSSL: false, + }); + expect(result).not.toHaveProperty("user"); + expect(result).not.toHaveProperty("password"); + expect(result).not.toHaveProperty("connectionParams"); + expect(result).not.toHaveProperty("timeout"); + }); + + it("allows URI credentials to replace existing credentials when provided", () => { + const result = mergeParsedUriValuesForForm( + { + user: "root", + password: "old-password", + }, + { + user: "uri_user", + password: "uri-password", + }, + "mysql://uri_user:uri-password@127.0.0.1:3306/app", + ); + + expect(result).toMatchObject({ + user: "uri_user", + password: "uri-password", + }); + }); + + it("keeps existing database when URI omits a database path", () => { + const result = mergeParsedUriValuesForForm( + { + database: "saved_db", + }, + { + host: "127.0.0.1", + database: "", + }, + "mysql://127.0.0.1:3306", + ); + + expect(result.database).toBeUndefined(); + expect(result.host).toBe("127.0.0.1"); + }); +}); diff --git a/frontend/src/utils/connectionUriMerge.ts b/frontend/src/utils/connectionUriMerge.ts new file mode 100644 index 0000000..3712877 --- /dev/null +++ b/frontend/src/utils/connectionUriMerge.ts @@ -0,0 +1,36 @@ +const EMPTY_PRESERVED_URI_FIELDS = new Set([ + "user", + "password", + "database", + "connectionParams", +]); + +const isEmptyParsedValue = (value: unknown): boolean => + value === undefined || + value === null || + value === "" || + (Array.isArray(value) && value.length === 0); + +export const mergeParsedUriValuesForForm = ( + currentValues: Record, + parsedValues: Record, + uriText: string, +): Record => { + const nextValues: Record = { uri: uriText }; + + Object.entries(parsedValues).forEach(([key, value]) => { + if (value === undefined) { + return; + } + if ( + EMPTY_PRESERVED_URI_FIELDS.has(key) && + isEmptyParsedValue(value) && + !isEmptyParsedValue(currentValues[key]) + ) { + return; + } + nextValues[key] = value; + }); + + return nextValues; +}; diff --git a/frontend/wailsjs/go/models.ts b/frontend/wailsjs/go/models.ts index 06c8b9b..a1c858e 100755 --- a/frontend/wailsjs/go/models.ts +++ b/frontend/wailsjs/go/models.ts @@ -667,6 +667,7 @@ export namespace connection { httpTunnel?: HTTPTunnelConfig; driver?: string; dsn?: string; + connectionParams?: string; timeout?: number; redisDB?: number; uri?: string; @@ -710,6 +711,7 @@ export namespace connection { this.httpTunnel = this.convertValues(source["httpTunnel"], HTTPTunnelConfig); this.driver = source["driver"]; this.dsn = source["dsn"]; + this.connectionParams = source["connectionParams"]; this.timeout = source["timeout"]; this.redisDB = source["redisDB"]; this.uri = source["uri"]; diff --git a/internal/app/app.go b/internal/app/app.go index e9aff14..25943b2 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -209,6 +209,7 @@ func normalizeCacheKeyConfig(config connection.ConnectionConfig) connection.Conn normalized.User = "" normalized.Password = "" normalized.URI = "" + normalized.ConnectionParams = "" normalized.Hosts = nil normalized.Topology = "" normalized.MySQLReplicaUser = "" @@ -450,6 +451,9 @@ func formatConnSummary(config connection.ConnectionConfig) string { if strings.TrimSpace(config.URI) != "" { b.WriteString(fmt.Sprintf(" URI=已配置(长度=%d)", len(config.URI))) } + if strings.TrimSpace(config.ConnectionParams) != "" { + b.WriteString(fmt.Sprintf(" 连接参数=已配置(长度=%d)", len(config.ConnectionParams))) + } if strings.TrimSpace(config.MySQLReplicaUser) != "" { b.WriteString(" MySQL从库凭据=已配置") } diff --git a/internal/app/app_cache_key_test.go b/internal/app/app_cache_key_test.go index de4b616..4d84fe9 100644 --- a/internal/app/app_cache_key_test.go +++ b/internal/app/app_cache_key_test.go @@ -81,6 +81,26 @@ func TestGetCacheKey_KeepDatabaseIsolation(t *testing.T) { } } +func TestGetCacheKey_KeepConnectionParamsIsolation(t *testing.T) { + base := connection.ConnectionConfig{ + Type: "mysql", + Host: "127.0.0.1", + Port: 3306, + User: "root", + Password: "root", + Database: "app", + ConnectionParams: "charset=utf8", + } + modified := base + modified.ConnectionParams = "charset=utf8mb4" + + left := getCacheKey(base) + right := getCacheKey(modified) + if left == right { + t.Fatalf("expected different cache key for different connection params") + } +} + func TestGetCacheKey_KeepClickHouseProtocolIsolation(t *testing.T) { base := connection.ConnectionConfig{ Type: "clickhouse", diff --git a/internal/connection/types.go b/internal/connection/types.go index 6fef564..d75a390 100644 --- a/internal/connection/types.go +++ b/internal/connection/types.go @@ -99,6 +99,7 @@ type ConnectionConfig struct { HTTPTunnel HTTPTunnelConfig `json:"httpTunnel,omitempty"` Driver string `json:"driver,omitempty"` // For custom connection DSN string `json:"dsn,omitempty"` // For custom connection + ConnectionParams string `json:"connectionParams,omitempty"` // Extra URI query parameters for built-in drivers Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds (default: 30) RedisDB int `json:"redisDB,omitempty"` // Redis database index (0-15) URI string `json:"uri,omitempty"` // Connection URI for copy/paste diff --git a/internal/db/clickhouse_impl.go b/internal/db/clickhouse_impl.go index c3c63f0..cb2bcc3 100644 --- a/internal/db/clickhouse_impl.go +++ b/internal/db/clickhouse_impl.go @@ -106,6 +106,12 @@ func applyClickHouseEndpointURI(config connection.ConnectionConfig, uriText stri if queryProtocol := normalizeClickHouseProtocol(parsed.Query().Get("protocol")); queryProtocol != clickHouseProtocolAuto { config.ClickHouseProtocol = queryProtocol } + if parsed.RawQuery != "" { + params := url.Values{} + mergeConnectionParamValues(params, parsed.Query()) + mergeConnectionParamValues(params, connectionParamsFromText(config.ConnectionParams)) + config.ConnectionParams = params.Encode() + } endpointProtocol := normalizeClickHouseProtocol(config.ClickHouseProtocol) if isClickHouseHTTPURLScheme(scheme) && endpointProtocol != clickHouseProtocolNative { config.ClickHouseProtocol = clickHouseProtocolHTTP @@ -184,9 +190,148 @@ func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig if tlsConfig := resolveGenericTLSConfig(config); tlsConfig != nil { opts.TLS = tlsConfig } + applyClickHouseConnectionParams(opts, config) return opts } +func parseClickHouseDurationParam(raw string) (time.Duration, bool) { + text := strings.TrimSpace(raw) + if text == "" { + return 0, false + } + if n, err := strconv.Atoi(text); err == nil && n >= 0 { + return time.Duration(n) * time.Second, true + } + duration, err := time.ParseDuration(text) + return duration, err == nil +} + +func parseClickHouseIntParam(raw string) (int, bool) { + n, err := strconv.Atoi(strings.TrimSpace(raw)) + return n, err == nil +} + +func clickHouseSettingValue(raw string) any { + text := strings.TrimSpace(raw) + switch strings.ToLower(text) { + case "true", "yes", "on": + return int(1) + case "false", "no", "off": + return int(0) + } + if n, err := strconv.Atoi(text); err == nil { + return n + } + return text +} + +func applyClickHouseCompressionParam(opts *clickhouse.Options, raw string) { + value := strings.ToLower(strings.TrimSpace(raw)) + if value == "" || value == "false" || value == "0" || value == "none" { + opts.Compression = &clickhouse.Compression{Method: clickhouse.CompressionNone} + return + } + if opts.Compression == nil { + opts.Compression = &clickhouse.Compression{Level: 3} + } + switch value { + case "true", "1", "lz4": + opts.Compression.Method = clickhouse.CompressionLZ4 + case "zstd": + opts.Compression.Method = clickhouse.CompressionZSTD + case "lz4hc": + opts.Compression.Method = clickhouse.CompressionLZ4HC + case "gzip": + opts.Compression.Method = clickhouse.CompressionGZIP + case "deflate": + opts.Compression.Method = clickhouse.CompressionDeflate + case "br", "brotli": + opts.Compression.Method = clickhouse.CompressionBrotli + } +} + +func applyClickHouseConnectionParams(opts *clickhouse.Options, config connection.ConnectionConfig) { + params := url.Values{} + mergeConnectionParamsFromConfig(params, config, "clickhouse", "http", "https") + if len(params) == 0 { + return + } + if opts.Settings == nil { + opts.Settings = clickhouse.Settings{} + } + keys := make([]string, 0, len(params)) + for key := range params { + if strings.TrimSpace(key) != "" { + keys = append(keys, key) + } + } + sort.Strings(keys) + for _, key := range keys { + values := params[key] + if len(values) == 0 { + continue + } + value := values[len(values)-1] + switch strings.ToLower(strings.TrimSpace(key)) { + case "protocol", "secure", "skip_verify", "username", "password", "database": + continue + case "dial_timeout": + if duration, ok := parseClickHouseDurationParam(value); ok { + opts.DialTimeout = duration + } + case "read_timeout": + if duration, ok := parseClickHouseDurationParam(value); ok { + opts.ReadTimeout = duration + } + case "compress": + applyClickHouseCompressionParam(opts, value) + case "compress_level": + if level, ok := parseClickHouseIntParam(value); ok { + if opts.Compression == nil { + opts.Compression = &clickhouse.Compression{Method: clickhouse.CompressionNone} + } + opts.Compression.Level = level + } + case "max_open_conns": + if n, ok := parseClickHouseIntParam(value); ok { + opts.MaxOpenConns = n + } + case "max_idle_conns": + if n, ok := parseClickHouseIntParam(value); ok { + opts.MaxIdleConns = n + } + case "max_compression_buffer": + if n, ok := parseClickHouseIntParam(value); ok { + opts.MaxCompressionBuffer = n + } + case "block_buffer_size": + if n, ok := parseClickHouseIntParam(value); ok && n > 0 && n <= 255 { + opts.BlockBufferSize = uint8(n) + } + case "http_path": + path := strings.TrimSpace(value) + if path != "" && !strings.HasPrefix(path, "/") { + path = "/" + path + } + opts.HttpUrlPath = path + case "connection_open_strategy": + switch strings.ToLower(strings.TrimSpace(value)) { + case "in_order": + opts.ConnOpenStrategy = clickhouse.ConnOpenInOrder + case "round_robin": + opts.ConnOpenStrategy = clickhouse.ConnOpenRoundRobin + case "random": + opts.ConnOpenStrategy = clickhouse.ConnOpenRandom + } + default: + opts.Settings[key] = clickHouseSettingValue(value) + } + } + if len(opts.Settings) == 0 { + opts.Settings = nil + } +} + func detectClickHouseProtocol(config connection.ConnectionConfig) clickhouse.Protocol { switch normalizeClickHouseProtocol(config.ClickHouseProtocol) { case clickHouseProtocolHTTP: diff --git a/internal/db/connection_params.go b/internal/db/connection_params.go new file mode 100644 index 0000000..cceb625 --- /dev/null +++ b/internal/db/connection_params.go @@ -0,0 +1,117 @@ +package db + +import ( + "net/url" + "sort" + "strings" + + "GoNavi-Wails/internal/connection" +) + +func parseConnectionURI(raw string, allowedSchemes ...string) (*url.URL, bool) { + text := strings.TrimSpace(raw) + if text == "" { + return nil, false + } + if strings.HasPrefix(strings.ToLower(text), "jdbc:") { + text = strings.TrimSpace(text[len("jdbc:"):]) + } + parsed, err := url.Parse(text) + if err != nil { + return nil, false + } + scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme)) + for _, allowed := range allowedSchemes { + if scheme == strings.ToLower(strings.TrimSpace(allowed)) { + return parsed, true + } + } + return nil, false +} + +func connectionParamsFromText(raw string) url.Values { + text := strings.TrimSpace(raw) + if text == "" { + return nil + } + if queryIndex := strings.Index(text, "?"); queryIndex >= 0 { + text = text[queryIndex+1:] + } + if hashIndex := strings.Index(text, "#"); hashIndex >= 0 { + text = text[:hashIndex] + } + text = strings.TrimLeft(strings.TrimSpace(text), "?&") + if text == "" { + return nil + } + values, err := url.ParseQuery(text) + if err != nil { + return nil + } + return values +} + +func connectionParamsFromURI(raw string, allowedSchemes ...string) url.Values { + parsed, ok := parseConnectionURI(raw, allowedSchemes...) + if !ok { + return nil + } + return parsed.Query() +} + +func mergeConnectionParamValues(params url.Values, values url.Values) { + if len(values) == 0 { + return + } + keys := make([]string, 0, len(values)) + for key := range values { + if strings.TrimSpace(key) != "" { + keys = append(keys, key) + } + } + sort.Strings(keys) + for _, key := range keys { + for _, value := range values[key] { + params.Set(key, value) + } + } +} + +func mergeConnectionParamsFromConfig(params url.Values, config connection.ConnectionConfig, allowedSchemes ...string) { + mergeConnectionParamValues(params, connectionParamsFromURI(config.URI, allowedSchemes...)) + mergeConnectionParamValues(params, connectionParamsFromText(config.ConnectionParams)) +} + +func mergeConnectionParamsIntoRawURI(raw string, connectionParams string, allowedSchemes ...string) string { + text := strings.TrimSpace(raw) + if text == "" { + return text + } + parsed, ok := parseConnectionURI(text, allowedSchemes...) + if !ok { + return text + } + params := parsed.Query() + mergeConnectionParamValues(params, connectionParamsFromText(connectionParams)) + parsed.RawQuery = params.Encode() + return parsed.String() +} + +func isSafeConnectionParamKey(key string) bool { + text := strings.TrimSpace(key) + if text == "" { + return false + } + for _, r := range text { + if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' { + continue + } + switch r { + case '_', '-', '.', ' ': + continue + default: + return false + } + } + return true +} diff --git a/internal/db/dameng_impl.go b/internal/db/dameng_impl.go index 0159b9a..d606b1b 100644 --- a/internal/db/dameng_impl.go +++ b/internal/db/dameng_impl.go @@ -48,6 +48,7 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string { // 达梦驱动要求:密码包含特殊字符时,password 需 PathEscape,并添加 escapeProcess=true 让驱动解码。 q.Set("escapeProcess", "true") } + mergeConnectionParamsFromConfig(q, config, "dm", "dameng") dsn := fmt.Sprintf("dm://%s:%s@%s", config.User, escapedPassword, address) encoded := q.Encode() diff --git a/internal/db/diros_impl.go b/internal/db/diros_impl.go index 3af7465..38d7664 100644 --- a/internal/db/diros_impl.go +++ b/internal/db/diros_impl.go @@ -5,7 +5,6 @@ package db import ( "database/sql" "fmt" - "net/url" "strings" "GoNavi-Wails/internal/connection" @@ -40,15 +39,8 @@ func applyDirosURI(config connection.ConnectionConfig) connection.ConnectionConf return config } - lowerURI := strings.ToLower(uriText) - if !strings.HasPrefix(lowerURI, "diros://") && - !strings.HasPrefix(lowerURI, "doris://") && - !strings.HasPrefix(lowerURI, "mysql://") { - return config - } - - parsed, err := url.Parse(uriText) - if err != nil { + parsed, ok := parseMySQLCompatibleURI(uriText, "diros", "doris", "mysql") + if !ok { return config } @@ -147,13 +139,7 @@ func (d *DirosDB) getDSN(config connection.ConnectionConfig) (string, error) { protocol = netName } - timeout := getConnectTimeoutSeconds(config) - tlsMode := resolveMySQLTLSMode(config) - - return fmt.Sprintf( - "%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s&multiStatements=true", - config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode), - ), nil + return buildMySQLCompatibleDSN(config, protocol, address, database), nil } func resolveDirosCredential(config connection.ConnectionConfig, addressIndex int) (string, string) { diff --git a/internal/db/dsn_test.go b/internal/db/dsn_test.go index 6bed9fe..0deee14 100644 --- a/internal/db/dsn_test.go +++ b/internal/db/dsn_test.go @@ -3,11 +3,14 @@ package db import ( + "net/url" "strings" "testing" "time" "GoNavi-Wails/internal/connection" + + clickhouse "github.com/ClickHouse/clickhouse-go/v2" ) func TestPostgresDSN_EscapesPassword(t *testing.T) { @@ -52,6 +55,32 @@ func TestPostgresDSN_SSLModeRequireWhenEnabled(t *testing.T) { } } +func TestPostgresDSN_MergesConnectionParams(t *testing.T) { + p := &PostgresDB{} + cfg := connection.ConnectionConfig{ + Type: "postgres", + Host: "127.0.0.1", + Port: 5432, + User: "user", + Password: "pass", + Database: "db", + ConnectionParams: "application_name=GoNavi&connect_timeout=9", + } + + dsn := p.getDSN(cfg) + parsed, err := url.Parse(dsn) + if err != nil { + t.Fatalf("parse postgres dsn: %v", err) + } + query := parsed.Query() + if got := query.Get("application_name"); got != "GoNavi" { + t.Fatalf("application_name = %q, want GoNavi", got) + } + if got := query.Get("connect_timeout"); got != "9" { + t.Fatalf("connect_timeout = %q, want 9", got) + } +} + func TestMySQLDSN_UsesTLSParamWhenSSLEnabled(t *testing.T) { m := &MySQLDB{} cfg := connection.ConnectionConfig{ @@ -65,7 +94,10 @@ func TestMySQLDSN_UsesTLSParamWhenSSLEnabled(t *testing.T) { SSLMode: "required", } - dsn := m.getDSN(cfg) + dsn, err := m.getDSN(cfg) + if err != nil { + t.Fatalf("getDSN failed: %v", err) + } if !strings.Contains(dsn, "tls=true") { t.Fatalf("dsn 缺少 tls=true 参数:%s", dsn) } @@ -161,6 +193,27 @@ func TestKingbaseDSN_QuotesPasswordWithSpaces(t *testing.T) { } } +func TestKingbaseDSN_MergesConnectionParams(t *testing.T) { + k := &KingbaseDB{} + cfg := connection.ConnectionConfig{ + Type: "kingbase", + Host: "127.0.0.1", + Port: 54321, + User: "system", + Password: "pass", + Database: "TEST", + ConnectionParams: "application_name=GoNavi&connect_timeout=12", + } + + dsn := k.getDSN(cfg) + if !strings.Contains(dsn, "application_name=GoNavi") { + t.Fatalf("dsn 缺少 application_name:%s", dsn) + } + if !strings.Contains(dsn, "connect_timeout=12") { + t.Fatalf("dsn 缺少自定义 connect_timeout:%s", dsn) + } +} + func TestTDengineDSN_UsesWebSocketFormat(t *testing.T) { td := &TDengineDB{} cfg := connection.ConnectionConfig{ @@ -197,6 +250,24 @@ func TestTDengineDSN_UsesSecureWebSocketWhenSSLEnabled(t *testing.T) { } } +func TestTDengineDSN_MergesConnectionParams(t *testing.T) { + td := &TDengineDB{} + cfg := connection.ConnectionConfig{ + Type: "tdengine", + Host: "127.0.0.1", + Port: 6041, + User: "root", + Password: "taosdata", + Database: "power", + ConnectionParams: "timezone=Asia%2FShanghai&protocol=wss", + } + + dsn := td.getDSN(cfg) + if !strings.Contains(dsn, "?timezone=Asia%2FShanghai") { + t.Fatalf("tdengine dsn 缺少自定义参数或错误透传 protocol:%s", dsn) + } +} + func TestSQLServerDSN_EncryptMapping(t *testing.T) { s := &SqlServerDB{} cfg := connection.ConnectionConfig{ @@ -219,6 +290,32 @@ func TestSQLServerDSN_EncryptMapping(t *testing.T) { } } +func TestSQLServerDSN_MergesConnectionParams(t *testing.T) { + s := &SqlServerDB{} + cfg := connection.ConnectionConfig{ + Type: "sqlserver", + Host: "127.0.0.1", + Port: 1433, + User: "sa", + Password: "pass", + Database: "master", + ConnectionParams: "app name=GoNavi&packet size=32767", + } + + dsn := s.getDSN(cfg) + parsed, err := url.Parse(dsn) + if err != nil { + t.Fatalf("parse sqlserver dsn: %v", err) + } + query := parsed.Query() + if got := query.Get("app name"); got != "GoNavi" { + t.Fatalf("app name = %q, want GoNavi", got) + } + if got := query.Get("packet size"); got != "32767" { + t.Fatalf("packet size = %q, want 32767", got) + } +} + func TestClickHouseOptions_UsesStructuredTimeoutAndAuth(t *testing.T) { c := &ClickHouseDB{} cfg := normalizeClickHouseConfig(connection.ConnectionConfig{ @@ -264,6 +361,34 @@ func TestClickHouseOptions_UsesStructuredTimeoutAndAuth(t *testing.T) { } } +func TestClickHouseOptions_MergesConnectionParamsIntoOptionsAndSettings(t *testing.T) { + c := &ClickHouseDB{} + cfg := normalizeClickHouseConfig(connection.ConnectionConfig{ + Type: "clickhouse", + Host: "127.0.0.1", + Port: 9000, + User: "default", + Password: "secret", + Database: "analytics", + Timeout: 15, + ConnectionParams: "max_execution_time=60&compress=lz4&read_timeout=10s", + }) + + opts := c.buildClickHouseOptions(cfg) + if opts == nil { + t.Fatal("options 为空") + } + if opts.ReadTimeout != 10*time.Second { + t.Fatalf("read timeout 不符合预期:%s", opts.ReadTimeout) + } + if opts.Compression == nil || opts.Compression.Method != clickhouse.CompressionLZ4 { + t.Fatalf("compression 不符合预期:%v", opts.Compression) + } + if got := opts.Settings["max_execution_time"]; got != 60 { + t.Fatalf("max_execution_time = %#v, want 60", got) + } +} + func TestClickHouseOptions_ReadTimeoutUsesLargerConfiguredTimeout(t *testing.T) { c := &ClickHouseDB{} cfg := normalizeClickHouseConfig(connection.ConnectionConfig{ diff --git a/internal/db/highgo_impl.go b/internal/db/highgo_impl.go index 87c9f4b..3e30b22 100644 --- a/internal/db/highgo_impl.go +++ b/internal/db/highgo_impl.go @@ -44,6 +44,7 @@ func (h *HighGoDB) getDSN(config connection.ConnectionConfig) string { q := url.Values{} q.Set("sslmode", resolvePostgresSSLMode(config)) q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config))) + mergeConnectionParamsFromConfig(q, config, "postgres", "postgresql", "highgo") u.RawQuery = q.Encode() return u.String() diff --git a/internal/db/kingbase_impl.go b/internal/db/kingbase_impl.go index 8afb619..ab35b42 100644 --- a/internal/db/kingbase_impl.go +++ b/internal/db/kingbase_impl.go @@ -7,7 +7,9 @@ import ( "database/sql" "fmt" "net" + "net/url" "regexp" + "sort" "strconv" "strings" "time" @@ -62,21 +64,42 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string { // Kingbase DSN usually similar to Postgres: // host=localhost port=54321 user=system password=... dbname=TEST sslmode=disable - address := config.Host - port := config.Port + params := url.Values{} + params.Set("host", config.Host) + params.Set("port", strconv.Itoa(config.Port)) + params.Set("user", config.User) + params.Set("password", config.Password) + params.Set("dbname", config.Database) + params.Set("sslmode", resolvePostgresSSLMode(config)) + params.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config))) + mergeConnectionParamsFromConfig(params, config, "kingbase") - // Construct DSN - dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s connect_timeout=%d", - quoteConnValue(address), - port, - quoteConnValue(config.User), - quoteConnValue(config.Password), - quoteConnValue(config.Database), - quoteConnValue(resolvePostgresSSLMode(config)), - getConnectTimeoutSeconds(config), - ) + preferred := []string{"host", "port", "user", "password", "dbname", "sslmode", "connect_timeout"} + seen := make(map[string]struct{}, len(params)) + parts := make([]string, 0, len(params)) + for _, key := range preferred { + if values, ok := params[key]; ok && len(values) > 0 { + parts = append(parts, fmt.Sprintf("%s=%s", key, quoteConnValue(values[len(values)-1]))) + seen[key] = struct{}{} + } + } + extraKeys := make([]string, 0, len(params)) + for key := range params { + if _, ok := seen[key]; ok || !isSafeConnectionParamKey(key) { + continue + } + extraKeys = append(extraKeys, key) + } + sort.Strings(extraKeys) + for _, key := range extraKeys { + values := params[key] + if len(values) == 0 { + continue + } + parts = append(parts, fmt.Sprintf("%s=%s", key, quoteConnValue(values[len(values)-1]))) + } - return dsn + return strings.Join(parts, " ") } func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error { diff --git a/internal/db/mariadb_impl.go b/internal/db/mariadb_impl.go index e2f59e0..8ab24bb 100644 --- a/internal/db/mariadb_impl.go +++ b/internal/db/mariadb_impl.go @@ -6,7 +6,6 @@ import ( "context" "database/sql" "fmt" - "net/url" "strings" "time" @@ -37,17 +36,12 @@ func (m *MariaDB) getDSN(config connection.ConnectionConfig) (string, error) { protocol = netName } - timeout := getConnectTimeoutSeconds(config) - tlsMode := resolveMySQLTLSMode(config) - - return fmt.Sprintf( - "%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s&multiStatements=true", - config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode), - ), nil + return buildMySQLCompatibleDSN(config, protocol, address, database), nil } func (m *MariaDB) Connect(config connection.ConnectionConfig) error { - dsn, err := m.getDSN(config) + runConfig := applyMySQLURI(config) + dsn, err := m.getDSN(runConfig) if err != nil { return err } diff --git a/internal/db/mongodb_impl.go b/internal/db/mongodb_impl.go index ec73c46..ebec203 100644 --- a/internal/db/mongodb_impl.go +++ b/internal/db/mongodb_impl.go @@ -192,7 +192,7 @@ func applyMongoURI(config connection.ConnectionConfig) connection.ConnectionConf func (m *MongoDB) getURI(config connection.ConnectionConfig) string { if strings.TrimSpace(config.URI) != "" { - return strings.TrimSpace(config.URI) + return mergeConnectionParamsIntoRawURI(config.URI, config.ConnectionParams, "mongodb", "mongodb+srv") } seeds := collectMongoSeeds(config) @@ -257,6 +257,7 @@ func (m *MongoDB) getURI(config connection.ConnectionConfig) string { if authMechanism := strings.TrimSpace(config.MongoAuthMechanism); authMechanism != "" && !noAuth { params.Set("authMechanism", authMechanism) } + mergeConnectionParamValues(params, connectionParamsFromText(config.ConnectionParams)) // 单机模式且未指定副本集名称时,启用 directConnection 避免驱动自动跟随副本集成员发现 if strings.TrimSpace(config.Topology) != "replica" && strings.TrimSpace(config.ReplicaSet) == "" && !config.MongoSRV { diff --git a/internal/db/mongodb_impl_uri_test.go b/internal/db/mongodb_impl_uri_test.go index 020b293..2ff6814 100644 --- a/internal/db/mongodb_impl_uri_test.go +++ b/internal/db/mongodb_impl_uri_test.go @@ -3,6 +3,7 @@ package db import ( + "strings" "testing" "GoNavi-Wails/internal/connection" @@ -37,3 +38,30 @@ func TestApplyMongoURI_ExplicitHostsDoesNotAdoptURIHosts(t *testing.T) { t.Fatalf("expected explicit hosts to stay untouched, got %v", got.Hosts) } } + +func TestMongoURI_MergesConnectionParams(t *testing.T) { + uri := (&MongoDB{}).getURI(connection.ConnectionConfig{ + Host: "mongo.local", + Port: 27017, + Database: "app", + ConnectionParams: "retryWrites=true&readPreference=secondaryPreferred", + }) + + if !strings.Contains(uri, "retryWrites=true") { + t.Fatalf("uri 缺少 retryWrites 参数:%s", uri) + } + if !strings.Contains(uri, "readPreference=secondaryPreferred") { + t.Fatalf("uri 缺少 readPreference 参数:%s", uri) + } +} + +func TestMongoURI_MergesConnectionParamsIntoExistingURI(t *testing.T) { + uri := (&MongoDB{}).getURI(connection.ConnectionConfig{ + URI: "mongodb://mongo.local:27017/app?authSource=admin", + ConnectionParams: "retryWrites=true", + }) + + if !strings.Contains(uri, "authSource=admin") || !strings.Contains(uri, "retryWrites=true") { + t.Fatalf("uri 未合并已有 URI query 与额外参数:%s", uri) + } +} diff --git a/internal/db/mongodb_impl_v1.go b/internal/db/mongodb_impl_v1.go index 17bf146..f3095f9 100644 --- a/internal/db/mongodb_impl_v1.go +++ b/internal/db/mongodb_impl_v1.go @@ -193,7 +193,7 @@ func applyMongoURI(config connection.ConnectionConfig) connection.ConnectionConf func (m *MongoDBV1) getURI(config connection.ConnectionConfig) string { if strings.TrimSpace(config.URI) != "" { - return strings.TrimSpace(config.URI) + return mergeConnectionParamsIntoRawURI(config.URI, config.ConnectionParams, "mongodb", "mongodb+srv") } seeds := collectMongoSeeds(config) @@ -258,6 +258,7 @@ func (m *MongoDBV1) getURI(config connection.ConnectionConfig) string { if authMechanism := strings.TrimSpace(config.MongoAuthMechanism); authMechanism != "" && !noAuth { params.Set("authMechanism", authMechanism) } + mergeConnectionParamValues(params, connectionParamsFromText(config.ConnectionParams)) // 单机模式且未指定副本集名称时,启用 directConnection 避免驱动自动跟随副本集成员发现 if strings.TrimSpace(config.Topology) != "replica" && strings.TrimSpace(config.ReplicaSet) == "" && !config.MongoSRV { diff --git a/internal/db/mysql_connection_params_test.go b/internal/db/mysql_connection_params_test.go new file mode 100644 index 0000000..765d867 --- /dev/null +++ b/internal/db/mysql_connection_params_test.go @@ -0,0 +1,153 @@ +package db + +import ( + "net/url" + "strings" + "testing" + + "GoNavi-Wails/internal/connection" +) + +func parseMySQLDSNQueryForTest(t *testing.T, dsn string) url.Values { + t.Helper() + parts := strings.SplitN(dsn, "?", 2) + if len(parts) != 2 { + t.Fatalf("dsn missing query: %s", dsn) + } + values, err := url.ParseQuery(parts[1]) + if err != nil { + t.Fatalf("parse dsn query: %v", err) + } + return values +} + +func TestMySQLDSN_MergesConnectionParamsWithDefaults(t *testing.T) { + t.Parallel() + + m := &MySQLDB{} + dsn, err := m.getDSN(connection.ConnectionConfig{ + Host: "db.local", + Port: 3306, + User: "root", + Password: "secret", + Database: "app", + Timeout: 30, + ConnectionParams: "charset=utf8&readTimeout=10&columnsWithAlias=true", + }) + if err != nil { + t.Fatalf("getDSN failed: %v", err) + } + + query := parseMySQLDSNQueryForTest(t, dsn) + if got := query.Get("charset"); got != "utf8" { + t.Fatalf("charset should be overridden by connectionParams, got=%q", got) + } + if got := query.Get("readTimeout"); got != "10s" { + t.Fatalf("numeric readTimeout should be converted to duration, got=%q", got) + } + if got := query.Get("columnsWithAlias"); got != "true" { + t.Fatalf("driver-specific parameter should be preserved, got=%q", got) + } + if got := query.Get("multiStatements"); got != "true" { + t.Fatalf("default multiStatements should remain enabled, got=%q", got) + } +} + +func TestMySQLDSN_MapsCommonJDBCParamsWithoutLeakingUnsupportedKeys(t *testing.T) { + t.Parallel() + + m := &MySQLDB{} + dsn, err := m.getDSN(connection.ConnectionConfig{ + Host: "192.168.1.1", + Port: 3306, + User: "root", + Database: "app", + ConnectionParams: "useUnicode=true&characterEncoding=utf8&autoReconnect=true&" + + "useSSL=false&verifyServerCertificate=false&useOldAliasMetadataBehavior=true", + }) + if err != nil { + t.Fatalf("getDSN failed: %v", err) + } + + query := parseMySQLDSNQueryForTest(t, dsn) + if got := query.Get("charset"); got != "utf8" { + t.Fatalf("characterEncoding should map to charset, got=%q", got) + } + if got := query.Get("tls"); got != "false" { + t.Fatalf("useSSL=false should map to tls=false, got=%q", got) + } + for _, forbidden := range []string{ + "useUnicode", + "characterEncoding", + "autoReconnect", + "useSSL", + "verifyServerCertificate", + "useOldAliasMetadataBehavior", + } { + if _, exists := query[forbidden]; exists { + t.Fatalf("JDBC-only parameter %s should not be passed to Go MySQL driver: %v", forbidden, query) + } + } +} + +func TestMySQLDSN_MapsJDBCUTF8EncodingToMySQLCharset(t *testing.T) { + t.Parallel() + + m := &MySQLDB{} + dsn, err := m.getDSN(connection.ConnectionConfig{ + Host: "192.168.1.240", + Port: 3306, + User: "root", + Database: "mkefu_location_dev_local", + URI: "jdbc:mysql://192.168.1.240:3306/mkefu_location_dev_local?" + + "useUnicode=true&characterEncoding=UTF-8&serverTimezone=GMT%2B8", + }) + if err != nil { + t.Fatalf("getDSN failed: %v", err) + } + + query := parseMySQLDSNQueryForTest(t, dsn) + if got := query.Get("charset"); got != "utf8mb4" { + t.Fatalf("JDBC characterEncoding=UTF-8 should map to MySQL charset utf8mb4, got=%q", got) + } + if got := query.Get("characterEncoding"); got != "" { + t.Fatalf("JDBC characterEncoding should not be passed to Go MySQL driver, got=%q", got) + } + if got := query.Get("serverTimezone"); got != "" { + t.Fatalf("JDBC serverTimezone should not be passed to Go MySQL driver, got=%q", got) + } + if got := query.Get("loc"); got != "Asia%2FShanghai" && got != "Asia/Shanghai" { + t.Fatalf("serverTimezone=GMT+8 should map to loc=Asia/Shanghai, got=%q", got) + } +} + +func TestMySQLDSN_URIParamsAndExplicitParamsPrecedence(t *testing.T) { + t.Parallel() + + m := &MySQLDB{} + dsn, err := m.getDSN(connection.ConnectionConfig{ + Host: "db.local", + Port: 3306, + User: "root", + Database: "app", + URI: "jdbc:mysql://db.local:3306/app?characterEncoding=utf8&timeout=15&topology=replica&useSSL=false", + ConnectionParams: "charset=utf8mb4&timeout=5s&socketTimeout=45000", + }) + if err != nil { + t.Fatalf("getDSN failed: %v", err) + } + + query := parseMySQLDSNQueryForTest(t, dsn) + if got := query.Get("charset"); got != "utf8mb4" { + t.Fatalf("connectionParams should override URI charset, got=%q", got) + } + if got := query.Get("timeout"); got != "5s" { + t.Fatalf("connectionParams should override URI timeout, got=%q", got) + } + if got := query.Get("readTimeout"); got != "45s" { + t.Fatalf("socketTimeout should map to readTimeout duration, got=%q", got) + } + if _, exists := query["topology"]; exists { + t.Fatalf("internal topology parameter should not be passed to driver: %v", query) + } +} diff --git a/internal/db/mysql_impl.go b/internal/db/mysql_impl.go index 744f948..952a663 100644 --- a/internal/db/mysql_impl.go +++ b/internal/db/mysql_impl.go @@ -7,6 +7,7 @@ import ( "fmt" "math" "net/url" + "sort" "strconv" "strings" "time" @@ -26,6 +27,183 @@ type MySQLDB struct { const defaultMySQLPort = 3306 +func parseMySQLCompatibleURI(raw string, allowedSchemes ...string) (*url.URL, bool) { + return parseConnectionURI(raw, allowedSchemes...) +} + +func mysqlConnectionParamsFromText(raw string) url.Values { + return connectionParamsFromText(raw) +} + +func parseMySQLBoolParam(raw string) (bool, bool) { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "1", "true", "yes", "on": + return true, true + case "0", "false", "no", "off": + return false, true + default: + return false, false + } +} + +func normalizeMySQLDurationParam(raw string, unit time.Duration) string { + text := strings.TrimSpace(raw) + if text == "" { + return text + } + if n, err := strconv.Atoi(text); err == nil && n >= 0 { + return (time.Duration(n) * unit).String() + } + return text +} + +func normalizeMySQLCharsetParam(raw string) string { + text := strings.TrimSpace(raw) + if text == "" { + return "" + } + lower := strings.ToLower(text) + switch lower { + case "utf-8", "utf_8", "unicode": + return "utf8mb4" + case "utf8", "utf8mb4", "latin1", "gbk", "gb2312", "gb18030", "big5", "sjis", "cp932": + return lower + case "iso-8859-1", "iso8859-1", "iso88591": + return "latin1" + default: + return text + } +} + +func normalizeMySQLServerTimezoneParam(raw string) (string, bool) { + text := strings.TrimSpace(raw) + if text == "" { + return "", false + } + compact := strings.ToUpper(strings.ReplaceAll(text, " ", "")) + switch compact { + case "LOCAL": + return "Local", true + case "UTC", "Z", "GMT", "GMT+0", "GMT-0", "GMT+00", "GMT-00", "GMT+00:00", "GMT-00:00", + "UTC+0", "UTC-0", "UTC+00", "UTC-00", "UTC+00:00", "UTC-00:00": + return "UTC", true + case "GMT+8", "GMT+08", "GMT+08:00", "UTC+8", "UTC+08", "UTC+08:00", + "ASIA/SHANGHAI", "PRC", "CTT": + return "Asia/Shanghai", true + } + if strings.Contains(text, "/") { + if _, err := time.LoadLocation(text); err == nil { + return text, true + } + } + return "", false +} + +func mergeMySQLConnectionParam(params url.Values, key string, value string) { + name := strings.TrimSpace(key) + if name == "" { + return + } + lowerName := strings.ToLower(name) + switch lowerName { + case "topology": + return + case "useunicode", "autoreconnect", "useoldaliasmetadatabehavior": + return + case "charset": + if charset := normalizeMySQLCharsetParam(value); charset != "" { + params.Set("charset", charset) + } + return + case "characterencoding": + if charset := normalizeMySQLCharsetParam(value); charset != "" { + params.Set("charset", charset) + } + return + case "servertimezone": + if loc, ok := normalizeMySQLServerTimezoneParam(value); ok { + params.Set("loc", loc) + } + return + case "usessl": + if enabled, ok := parseMySQLBoolParam(value); ok { + if enabled { + params.Set("tls", "true") + } else { + params.Set("tls", "false") + } + } + return + case "verifyservercertificate": + if verified, ok := parseMySQLBoolParam(value); ok && !verified && params.Get("tls") != "false" { + params.Set("tls", "skip-verify") + } + return + case "trustservercertificate": + if trusted, ok := parseMySQLBoolParam(value); ok && trusted && params.Get("tls") != "false" { + params.Set("tls", "skip-verify") + } + return + case "connecttimeout": + params.Set("timeout", normalizeMySQLDurationParam(value, time.Millisecond)) + return + case "sockettimeout": + params.Set("readTimeout", normalizeMySQLDurationParam(value, time.Millisecond)) + return + case "timeout", "readtimeout", "writetimeout": + params.Set(name, normalizeMySQLDurationParam(value, time.Second)) + return + default: + params.Set(name, value) + } +} + +func mergeMySQLConnectionParams(params url.Values, values url.Values) { + keys := make([]string, 0, len(values)) + for key := range values { + keys = append(keys, key) + } + sort.Strings(keys) + for _, key := range keys { + lowerName := strings.ToLower(strings.TrimSpace(key)) + if lowerName == "verifyservercertificate" || lowerName == "trustservercertificate" { + continue + } + for _, value := range values[key] { + mergeMySQLConnectionParam(params, key, value) + } + } + for _, key := range keys { + lowerName := strings.ToLower(strings.TrimSpace(key)) + if lowerName != "verifyservercertificate" && lowerName != "trustservercertificate" { + continue + } + for _, value := range values[key] { + mergeMySQLConnectionParam(params, key, value) + } + } +} + +func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, address, database string) string { + timeout := getConnectTimeoutSeconds(config) + tlsMode := resolveMySQLTLSMode(config) + params := url.Values{} + params.Set("charset", "utf8mb4") + params.Set("parseTime", "True") + params.Set("loc", "Local") + params.Set("timeout", fmt.Sprintf("%ds", timeout)) + params.Set("tls", tlsMode) + params.Set("multiStatements", "true") + if parsed, ok := parseMySQLCompatibleURI(config.URI, "mysql", "doris", "diros"); ok { + mergeMySQLConnectionParams(params, parsed.Query()) + } + mergeMySQLConnectionParams(params, mysqlConnectionParamsFromText(config.ConnectionParams)) + return fmt.Sprintf( + "%s:%s@%s(%s)/%s?%s", + config.User, config.Password, protocol, address, database, params.Encode(), + ) +} + func parseHostPortWithDefault(raw string, defaultPort int) (string, int, bool) { text := strings.TrimSpace(raw) if text == "" { @@ -135,13 +313,8 @@ func applyMySQLURI(config connection.ConnectionConfig) connection.ConnectionConf if uriText == "" { return config } - lowerURI := strings.ToLower(uriText) - if !strings.HasPrefix(lowerURI, "mysql://") { - return config - } - - parsed, err := url.Parse(uriText) - if err != nil { + parsed, ok := parseMySQLCompatibleURI(uriText, "mysql") + if !ok { return config } @@ -239,13 +412,7 @@ func (m *MySQLDB) getDSN(config connection.ConnectionConfig) (string, error) { protocol = netName } - timeout := getConnectTimeoutSeconds(config) - tlsMode := resolveMySQLTLSMode(config) - - return fmt.Sprintf( - "%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s&multiStatements=true", - config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode), - ), nil + return buildMySQLCompatibleDSN(config, protocol, address, database), nil } func resolveMySQLCredential(config connection.ConnectionConfig, addressIndex int) (string, string) { diff --git a/internal/db/oracle_dsn_test.go b/internal/db/oracle_dsn_test.go index c4557b6..ff45b10 100644 --- a/internal/db/oracle_dsn_test.go +++ b/internal/db/oracle_dsn_test.go @@ -30,3 +30,28 @@ func TestOracleGetDSNIncludesQueryPerformanceOptions(t *testing.T) { t.Fatalf("LOB FETCH = %q, want POST", got) } } + +func TestOracleGetDSNMergesConnectionParams(t *testing.T) { + t.Parallel() + + dsn := (&OracleDB{}).getDSN(connection.ConnectionConfig{ + Host: "db.example.com", + Port: 1521, + User: "scott", + Password: "tiger", + Database: "ORCLPDB1", + ConnectionParams: "PREFETCH_ROWS=5000&TRACE FILE=/tmp/go-ora.trc", + }) + + parsed, err := url.Parse(dsn) + if err != nil { + t.Fatalf("解析 Oracle DSN 失败: %v", err) + } + query := parsed.Query() + if got := query.Get("PREFETCH_ROWS"); got != "5000" { + t.Fatalf("PREFETCH_ROWS = %q, want 5000", got) + } + if got := query.Get("TRACE FILE"); got != "/tmp/go-ora.trc" { + t.Fatalf("TRACE FILE = %q, want /tmp/go-ora.trc", got) + } +} diff --git a/internal/db/oracle_impl.go b/internal/db/oracle_impl.go index 21f53b0..d31bc7a 100644 --- a/internal/db/oracle_impl.go +++ b/internal/db/oracle_impl.go @@ -48,6 +48,7 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string { q.Set("PREFETCH_ROWS", "10000") // LOB 数据延迟加载,避免大 LOB 列影响普通查询性能 q.Set("LOB FETCH", "POST") + mergeConnectionParamsFromConfig(q, config, "oracle") if encoded := q.Encode(); encoded != "" { u.RawQuery = encoded } diff --git a/internal/db/postgres_impl.go b/internal/db/postgres_impl.go index e5cf94f..911676e 100644 --- a/internal/db/postgres_impl.go +++ b/internal/db/postgres_impl.go @@ -64,6 +64,7 @@ func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string { q := url.Values{} q.Set("sslmode", resolvePostgresSSLMode(config)) q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config))) + mergeConnectionParamsFromConfig(q, config, "postgres", "postgresql") u.RawQuery = q.Encode() return u.String() diff --git a/internal/db/sqlserver_impl.go b/internal/db/sqlserver_impl.go index ed06fd5..aa23626 100644 --- a/internal/db/sqlserver_impl.go +++ b/internal/db/sqlserver_impl.go @@ -50,6 +50,7 @@ func (s *SqlServerDB) getDSN(config connection.ConnectionConfig) string { encrypt, trustServerCertificate := resolveSQLServerTLSSettings(config) q.Set("encrypt", encrypt) q.Set("TrustServerCertificate", trustServerCertificate) + mergeConnectionParamsFromConfig(q, config, "sqlserver") u.RawQuery = q.Encode() return u.String() diff --git a/internal/db/tdengine_impl.go b/internal/db/tdengine_impl.go index 0ca37fb..ecf7fc7 100644 --- a/internal/db/tdengine_impl.go +++ b/internal/db/tdengine_impl.go @@ -7,6 +7,7 @@ import ( "database/sql" "fmt" "net" + "net/url" "sort" "strconv" "strings" @@ -42,7 +43,16 @@ func (t *TDengineDB) getDSN(config connection.ConnectionConfig) string { } netType := resolveTDengineNet(config) - return fmt.Sprintf("%s:%s@%s(%s)%s", user, pass, netType, net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), path) + params := url.Values{} + mergeConnectionParamsFromConfig(params, config, "taos", "taosws", "tdengine") + params.Del("protocol") + params.Del("skip_verify") + query := params.Encode() + dsn := fmt.Sprintf("%s:%s@%s(%s)%s", user, pass, netType, net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), path) + if query == "" { + return dsn + } + return dsn + "?" + query } func (t *TDengineDB) Connect(config connection.ConnectionConfig) error { diff --git a/internal/db/vastbase_impl.go b/internal/db/vastbase_impl.go index bb23713..3994e79 100644 --- a/internal/db/vastbase_impl.go +++ b/internal/db/vastbase_impl.go @@ -43,6 +43,7 @@ func (v *VastbaseDB) getDSN(config connection.ConnectionConfig) string { q := url.Values{} q.Set("sslmode", resolvePostgresSSLMode(config)) q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config))) + mergeConnectionParamsFromConfig(q, config, "postgres", "postgresql", "vastbase") u.RawQuery = q.Encode() return u.String()