mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-22 08:50:17 +08:00
🐛 fix(connection): 收敛数据库连接参数白名单
- MySQL 兼容 JDBC 参数映射并丢弃 allowPublicKeyRetrieval 等无效参数 - 为 PostgreSQL 系、SQL Server、Oracle、达梦、TDengine 接入驱动参数白名单 - 补充连接参数归一化、别名映射和未知参数过滤回归测试
This commit is contained in:
400
internal/db/connection_param_allowlists.go
Normal file
400
internal/db/connection_param_allowlists.go
Normal file
@@ -0,0 +1,400 @@
|
||||
package db
|
||||
|
||||
var postgresConnectionParamNames = newConnectionParamNameMap(
|
||||
"host",
|
||||
"hostaddr",
|
||||
"port",
|
||||
"user",
|
||||
"password",
|
||||
"passfile",
|
||||
"dbname",
|
||||
"database",
|
||||
"options",
|
||||
"application_name",
|
||||
"fallback_application_name",
|
||||
"sslmode",
|
||||
"sslnegotiation",
|
||||
"sslcert",
|
||||
"sslkey",
|
||||
"sslrootcert",
|
||||
"sslsni",
|
||||
"sslinline",
|
||||
"krbsrvname",
|
||||
"krbspn",
|
||||
"connect_timeout",
|
||||
"binary_parameters",
|
||||
"disable_prepared_binary_result",
|
||||
"client_encoding",
|
||||
"datestyle",
|
||||
"tz",
|
||||
"geqo",
|
||||
"target_session_attrs",
|
||||
"load_balance_hosts",
|
||||
"search_path",
|
||||
"work_mem",
|
||||
"statement_timeout",
|
||||
"lock_timeout",
|
||||
"idle_in_transaction_session_timeout",
|
||||
"default_transaction_read_only",
|
||||
"default_transaction_isolation",
|
||||
"TimeZone",
|
||||
)
|
||||
|
||||
var highGoConnectionParamNames = newConnectionParamNameMap(
|
||||
"host",
|
||||
"port",
|
||||
"user",
|
||||
"password",
|
||||
"dbname",
|
||||
"application_name",
|
||||
"fallback_application_name",
|
||||
"sslmode",
|
||||
"sslcert",
|
||||
"sslkey",
|
||||
"sslrootcert",
|
||||
"sslsni",
|
||||
"sslinline",
|
||||
"krbsrvname",
|
||||
"krbspn",
|
||||
"connect_timeout",
|
||||
"binary_parameters",
|
||||
"disable_prepared_binary_result",
|
||||
"search_path",
|
||||
"work_mem",
|
||||
"statement_timeout",
|
||||
"lock_timeout",
|
||||
"idle_in_transaction_session_timeout",
|
||||
"default_transaction_read_only",
|
||||
"default_transaction_isolation",
|
||||
"TimeZone",
|
||||
)
|
||||
|
||||
var kingbaseConnectionParamNames = newConnectionParamNameMap(
|
||||
"host",
|
||||
"port",
|
||||
"user",
|
||||
"password",
|
||||
"dbname",
|
||||
"application_name",
|
||||
"fallback_application_name",
|
||||
"sslmode",
|
||||
"sslcert",
|
||||
"sslkey",
|
||||
"sslrootcert",
|
||||
"krbsrvname",
|
||||
"krbspn",
|
||||
"connect_timeout",
|
||||
"binary_parameters",
|
||||
"disable_prepared_binary_result",
|
||||
"search_path",
|
||||
"work_mem",
|
||||
"statement_timeout",
|
||||
"lock_timeout",
|
||||
"idle_in_transaction_session_timeout",
|
||||
"default_transaction_read_only",
|
||||
"default_transaction_isolation",
|
||||
"TimeZone",
|
||||
)
|
||||
|
||||
var sqlServerConnectionParamNames = makeSQLServerConnectionParamNames()
|
||||
|
||||
func makeSQLServerConnectionParamNames() map[string]string {
|
||||
names := newConnectionParamNameMap(
|
||||
"database",
|
||||
"encrypt",
|
||||
"password",
|
||||
"change password",
|
||||
"user id",
|
||||
"port",
|
||||
"trustservercertificate",
|
||||
"certificate",
|
||||
"servercertificate",
|
||||
"tlsmin",
|
||||
"packet size",
|
||||
"log",
|
||||
"connection timeout",
|
||||
"hostnameincertificate",
|
||||
"keepalive",
|
||||
"serverspn",
|
||||
"workstation id",
|
||||
"app name",
|
||||
"applicationintent",
|
||||
"failoverpartner",
|
||||
"failoverport",
|
||||
"disableretry",
|
||||
"server",
|
||||
"protocol",
|
||||
"dial timeout",
|
||||
"pipe",
|
||||
"multisubnetfailover",
|
||||
"notraceid",
|
||||
"guid conversion",
|
||||
"timezone",
|
||||
"columnencryption",
|
||||
)
|
||||
addConnectionParamAlias(names, "application name", "app name")
|
||||
addConnectionParamAlias(names, "data source", "server")
|
||||
addConnectionParamAlias(names, "network address", "server")
|
||||
addConnectionParamAlias(names, "address", "server")
|
||||
addConnectionParamAlias(names, "addr", "server")
|
||||
addConnectionParamAlias(names, "user", "user id")
|
||||
addConnectionParamAlias(names, "uid", "user id")
|
||||
addConnectionParamAlias(names, "pwd", "password")
|
||||
addConnectionParamAlias(names, "initial catalog", "database")
|
||||
addConnectionParamAlias(names, "column encryption setting", "columnencryption")
|
||||
addConnectionParamAlias(names, "trust server certificate", "trustservercertificate")
|
||||
addConnectionParamAlias(names, "multi subnet failover", "multisubnetfailover")
|
||||
addConnectionParamAlias(names, "application intent", "applicationintent")
|
||||
return names
|
||||
}
|
||||
|
||||
var oracleConnectionParamNames = makeOracleConnectionParamNames()
|
||||
|
||||
func makeOracleConnectionParamNames() map[string]string {
|
||||
names := newConnectionParamNameMap(
|
||||
"CID",
|
||||
"connStr",
|
||||
"SERVER",
|
||||
"SERVICE NAME",
|
||||
"SID",
|
||||
"INSTANCE NAME",
|
||||
"WALLET",
|
||||
"WALLET PASSWORD",
|
||||
"AUTH TYPE",
|
||||
"OS USER",
|
||||
"OS PASS",
|
||||
"OS PASSWORD",
|
||||
"OS HASH",
|
||||
"OS PASSHASH",
|
||||
"OS PASSWORD HASH",
|
||||
"DOMAIN",
|
||||
"AUTH SERV",
|
||||
"ENCRYPTION",
|
||||
"DATA INTEGRITY",
|
||||
"SSL",
|
||||
"SSL VERIFY",
|
||||
"DBA PRIVILEGE",
|
||||
"TIMEOUT",
|
||||
"READ TIMEOUT",
|
||||
"SOCKET TIMEOUT",
|
||||
"CONNECT TIMEOUT",
|
||||
"CONNECTION TIMEOUT",
|
||||
"TRACE FILE",
|
||||
"TRACE DIR",
|
||||
"TRACE FOLDER",
|
||||
"TRACE DIRECTORY",
|
||||
"USE_OOB",
|
||||
"ENABLE_OOB",
|
||||
"ENABLE URGENT DATA TRANSPORT",
|
||||
"PREFETCH_ROWS",
|
||||
"UNIX SOCKET",
|
||||
"PROXY CLIENT NAME",
|
||||
"LOB FETCH",
|
||||
"LANGUAGE",
|
||||
"TERRITORY",
|
||||
"CHARSET",
|
||||
"CLIENT CHARSET",
|
||||
"PROGRAM",
|
||||
"SERVER LOCATION",
|
||||
)
|
||||
addConnectionParamAlias(names, "SERVICE_NAME", "SERVICE NAME")
|
||||
addConnectionParamAlias(names, "INSTANCE_NAME", "INSTANCE NAME")
|
||||
addConnectionParamAlias(names, "WALLET_PASSWORD", "WALLET PASSWORD")
|
||||
addConnectionParamAlias(names, "AUTH_TYPE", "AUTH TYPE")
|
||||
addConnectionParamAlias(names, "AUTH_SERV", "AUTH SERV")
|
||||
addConnectionParamAlias(names, "DATA_INTEGRITY", "DATA INTEGRITY")
|
||||
addConnectionParamAlias(names, "SSL_VERIFY", "SSL VERIFY")
|
||||
addConnectionParamAlias(names, "DBA_PRIVILEGE", "DBA PRIVILEGE")
|
||||
addConnectionParamAlias(names, "READ_TIMEOUT", "READ TIMEOUT")
|
||||
addConnectionParamAlias(names, "SOCKET_TIMEOUT", "SOCKET TIMEOUT")
|
||||
addConnectionParamAlias(names, "CONNECT_TIMEOUT", "CONNECT TIMEOUT")
|
||||
addConnectionParamAlias(names, "CONNECTION_TIMEOUT", "CONNECTION TIMEOUT")
|
||||
addConnectionParamAlias(names, "TRACE_FILE", "TRACE FILE")
|
||||
addConnectionParamAlias(names, "TRACE_DIR", "TRACE DIR")
|
||||
addConnectionParamAlias(names, "TRACE_FOLDER", "TRACE FOLDER")
|
||||
addConnectionParamAlias(names, "TRACE_DIRECTORY", "TRACE DIRECTORY")
|
||||
addConnectionParamAlias(names, "UNIX_SOCKET", "UNIX SOCKET")
|
||||
addConnectionParamAlias(names, "PROXY_CLIENT_NAME", "PROXY CLIENT NAME")
|
||||
addConnectionParamAlias(names, "LOB_FETCH", "LOB FETCH")
|
||||
addConnectionParamAlias(names, "CLIENT_CHARSET", "CLIENT CHARSET")
|
||||
addConnectionParamAlias(names, "SERVER_LOCATION", "SERVER LOCATION")
|
||||
return names
|
||||
}
|
||||
|
||||
var damengConnectionParamNames = makeDamengConnectionParamNames()
|
||||
|
||||
func makeDamengConnectionParamNames() map[string]string {
|
||||
names := newConnectionParamNameMap(
|
||||
"timeZone",
|
||||
"enRsCache",
|
||||
"rsCacheSize",
|
||||
"rsRefreshFreq",
|
||||
"loginPrimary",
|
||||
"loginMode",
|
||||
"loginStatus",
|
||||
"loginDscCtrl",
|
||||
"switchTimes",
|
||||
"switchInterval",
|
||||
"epSelector",
|
||||
"primaryKey",
|
||||
"keywords",
|
||||
"compress",
|
||||
"compressId",
|
||||
"loginEncrypt",
|
||||
"communicationEncrypt",
|
||||
"direct",
|
||||
"dec2double",
|
||||
"rwSeparate",
|
||||
"rwPercent",
|
||||
"rwAutoDistribute",
|
||||
"compatibleMode",
|
||||
"comOra",
|
||||
"cipherPath",
|
||||
"doSwitch",
|
||||
"driverReconnect",
|
||||
"cluster",
|
||||
"language",
|
||||
"dbAliveCheckFreq",
|
||||
"rwStandbyRecoverTime",
|
||||
"logLevel",
|
||||
"logDir",
|
||||
"logBufferPoolSize",
|
||||
"logBufferSize",
|
||||
"logFlusherQueueSize",
|
||||
"logFlushFreq",
|
||||
"statEnable",
|
||||
"statDir",
|
||||
"statFlushFreq",
|
||||
"statHighFreqSqlCount",
|
||||
"statSlowSqlCount",
|
||||
"statSqlMaxCount",
|
||||
"statSqlRemoveMode",
|
||||
"addressRemap",
|
||||
"userRemap",
|
||||
"connectTimeout",
|
||||
"loginCertificate",
|
||||
"url",
|
||||
"host",
|
||||
"port",
|
||||
"user",
|
||||
"password",
|
||||
"dialName",
|
||||
"rwStandby",
|
||||
"isCompress",
|
||||
"rwHA",
|
||||
"rwIgnoreSql",
|
||||
"appName",
|
||||
"osName",
|
||||
"mppLocal",
|
||||
"socketTimeout",
|
||||
"sessionTimeout",
|
||||
"continueBatchOnError",
|
||||
"batchAllowMaxErrors",
|
||||
"escapeProcess",
|
||||
"autoCommit",
|
||||
"maxRows",
|
||||
"rowPrefetch",
|
||||
"bufPrefetch",
|
||||
"LobMode",
|
||||
"StmtPoolSize",
|
||||
"AlwayseAllowCommit",
|
||||
"batchType",
|
||||
"batchNotOnCall",
|
||||
"isBdtaRS",
|
||||
"clobAsString",
|
||||
"sslCertPath",
|
||||
"sslKeyPath",
|
||||
"sslFilesPath",
|
||||
"kerberosLoginConfPath",
|
||||
"uKeyName",
|
||||
"uKeyPin",
|
||||
"columnNameUpperCase",
|
||||
"columnNameCase",
|
||||
"databaseProductName",
|
||||
"osAuthType",
|
||||
"schema",
|
||||
"catalog",
|
||||
"serverOption",
|
||||
"clobToBytes",
|
||||
"localTimezone",
|
||||
"sessEncode",
|
||||
"svcConfPath",
|
||||
"confPath",
|
||||
)
|
||||
addConnectionParamAlias(names, "ADDRESS_REMAP", "addressRemap")
|
||||
addConnectionParamAlias(names, "ALWAYS_ALLOW_COMMIT", "AlwayseAllowCommit")
|
||||
addConnectionParamAlias(names, "APP_NAME", "appName")
|
||||
addConnectionParamAlias(names, "AUTO_COMMIT", "autoCommit")
|
||||
addConnectionParamAlias(names, "BATCH_ALLOW_MAX_ERRORS", "batchAllowMaxErrors")
|
||||
addConnectionParamAlias(names, "BATCH_CONTINUE_ON_ERROR", "continueBatchOnError")
|
||||
addConnectionParamAlias(names, "CONTINUE_BATCH_ON_ERROR", "continueBatchOnError")
|
||||
addConnectionParamAlias(names, "BATCH_NOT_ON_CALL", "batchNotOnCall")
|
||||
addConnectionParamAlias(names, "BATCH_TYPE", "batchType")
|
||||
addConnectionParamAlias(names, "BUF_PREFETCH", "bufPrefetch")
|
||||
addConnectionParamAlias(names, "CIPHER_PATH", "cipherPath")
|
||||
addConnectionParamAlias(names, "COLUMN_NAME_UPPER_CASE", "columnNameUpperCase")
|
||||
addConnectionParamAlias(names, "COLUMN_NAME_CASE", "columnNameCase")
|
||||
addConnectionParamAlias(names, "COMPATIBLE_MODE", "compatibleMode")
|
||||
addConnectionParamAlias(names, "COMPRESS_MSG", "compress")
|
||||
addConnectionParamAlias(names, "COMPRESS_ID", "compressId")
|
||||
addConnectionParamAlias(names, "CONNECT_TIMEOUT", "connectTimeout")
|
||||
addConnectionParamAlias(names, "DO_SWITCH", "doSwitch")
|
||||
addConnectionParamAlias(names, "AUTO_RECONNECT", "doSwitch")
|
||||
addConnectionParamAlias(names, "ENABLE_RS_CACHE", "enRsCache")
|
||||
addConnectionParamAlias(names, "EP_SELECTION", "epSelector")
|
||||
addConnectionParamAlias(names, "ESCAPE_PROCESS", "escapeProcess")
|
||||
addConnectionParamAlias(names, "IS_BDTA_RS", "isBdtaRS")
|
||||
addConnectionParamAlias(names, "KEY_WORDS", "keywords")
|
||||
addConnectionParamAlias(names, "LOB_MODE", "LobMode")
|
||||
addConnectionParamAlias(names, "LOG_BUFFER_SIZE", "logBufferSize")
|
||||
addConnectionParamAlias(names, "LOG_DIR", "logDir")
|
||||
addConnectionParamAlias(names, "LOG_FLUSH_FREQ", "logFlushFreq")
|
||||
addConnectionParamAlias(names, "LOG_FLUSHER_QUEUESIZE", "logFlusherQueueSize")
|
||||
addConnectionParamAlias(names, "LOG_LEVEL", "logLevel")
|
||||
addConnectionParamAlias(names, "LOGIN_DSC_CTRL", "loginDscCtrl")
|
||||
addConnectionParamAlias(names, "LOGIN_ENCRYPT", "loginEncrypt")
|
||||
addConnectionParamAlias(names, "LOGIN_MODE", "loginMode")
|
||||
addConnectionParamAlias(names, "LOGIN_STATUS", "loginStatus")
|
||||
addConnectionParamAlias(names, "MAX_ROWS", "maxRows")
|
||||
addConnectionParamAlias(names, "MPP_LOCAL", "mppLocal")
|
||||
addConnectionParamAlias(names, "OS_NAME", "osName")
|
||||
addConnectionParamAlias(names, "RS_CACHE_SIZE", "rsCacheSize")
|
||||
addConnectionParamAlias(names, "RS_REFRESH_FREQ", "rsRefreshFreq")
|
||||
addConnectionParamAlias(names, "RW_HA", "rwHA")
|
||||
addConnectionParamAlias(names, "RW_IGNORE_SQL", "rwIgnoreSql")
|
||||
addConnectionParamAlias(names, "RW_PERCENT", "rwPercent")
|
||||
addConnectionParamAlias(names, "RW_SEPARATE", "rwSeparate")
|
||||
addConnectionParamAlias(names, "RW_STANDBY_RECOVER_TIME", "rwStandbyRecoverTime")
|
||||
addConnectionParamAlias(names, "SESS_ENCODE", "sessEncode")
|
||||
addConnectionParamAlias(names, "SESSION_TIMEOUT", "sessionTimeout")
|
||||
addConnectionParamAlias(names, "SOCKET_TIMEOUT", "socketTimeout")
|
||||
addConnectionParamAlias(names, "SSL_CERT_PATH", "sslCertPath")
|
||||
addConnectionParamAlias(names, "SSL_FILES_PATH", "sslFilesPath")
|
||||
addConnectionParamAlias(names, "SSL_KEY_PATH", "sslKeyPath")
|
||||
addConnectionParamAlias(names, "STAT_DIR", "statDir")
|
||||
addConnectionParamAlias(names, "STAT_ENABLE", "statEnable")
|
||||
addConnectionParamAlias(names, "STAT_FLUSH_FREQ", "statFlushFreq")
|
||||
addConnectionParamAlias(names, "STAT_HIGH_FREQ_SQL_COUNT", "statHighFreqSqlCount")
|
||||
addConnectionParamAlias(names, "STAT_SLOW_SQL_COUNT", "statSlowSqlCount")
|
||||
addConnectionParamAlias(names, "STAT_SQL_MAX_COUNT", "statSqlMaxCount")
|
||||
addConnectionParamAlias(names, "STAT_SQL_REMOVE_MODE", "statSqlRemoveMode")
|
||||
addConnectionParamAlias(names, "SWITCH_INTERVAL", "switchInterval")
|
||||
addConnectionParamAlias(names, "SWITCH_TIME", "switchTimes")
|
||||
addConnectionParamAlias(names, "SWITCH_TIMES", "switchTimes")
|
||||
addConnectionParamAlias(names, "TIME_ZONE", "timeZone")
|
||||
addConnectionParamAlias(names, "USER_REMAP", "userRemap")
|
||||
addConnectionParamAlias(names, "SERVER_OPTION", "serverOption")
|
||||
addConnectionParamAlias(names, "CLOB_TO_BYTES", "clobToBytes")
|
||||
return names
|
||||
}
|
||||
|
||||
var tdengineConnectionParamNames = newConnectionParamNameMap(
|
||||
"interpolateParams",
|
||||
"token",
|
||||
"enableCompression",
|
||||
"readTimeout",
|
||||
"writeTimeout",
|
||||
"timezone",
|
||||
"bearerToken",
|
||||
"totpCode",
|
||||
)
|
||||
@@ -77,11 +77,69 @@ func mergeConnectionParamValues(params url.Values, values url.Values) {
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeConnectionParamName(name string) string {
|
||||
return strings.ToLower(strings.Join(strings.Fields(strings.TrimSpace(name)), " "))
|
||||
}
|
||||
|
||||
func newConnectionParamNameMap(names ...string) map[string]string {
|
||||
result := make(map[string]string, len(names))
|
||||
for _, name := range names {
|
||||
addConnectionParamName(result, name)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func addConnectionParamName(names map[string]string, canonical string) {
|
||||
key := normalizeConnectionParamName(canonical)
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
names[key] = canonical
|
||||
}
|
||||
|
||||
func addConnectionParamAlias(names map[string]string, alias string, canonical string) {
|
||||
key := normalizeConnectionParamName(alias)
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
if existing, ok := names[normalizeConnectionParamName(canonical)]; ok {
|
||||
canonical = existing
|
||||
}
|
||||
names[key] = canonical
|
||||
}
|
||||
|
||||
func mergeConnectionParamValuesWithAllowlist(params url.Values, values url.Values, canonicalNames map[string]string) {
|
||||
if len(values) == 0 || len(canonicalNames) == 0 {
|
||||
return
|
||||
}
|
||||
keys := make([]string, 0, len(values))
|
||||
for key := range values {
|
||||
if strings.TrimSpace(key) != "" {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
sort.Strings(keys)
|
||||
for _, key := range keys {
|
||||
canonical, ok := canonicalNames[normalizeConnectionParamName(key)]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, value := range values[key] {
|
||||
params.Set(canonical, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func mergeConnectionParamsFromConfig(params url.Values, config connection.ConnectionConfig, allowedSchemes ...string) {
|
||||
mergeConnectionParamValues(params, connectionParamsFromURI(config.URI, allowedSchemes...))
|
||||
mergeConnectionParamValues(params, connectionParamsFromText(config.ConnectionParams))
|
||||
}
|
||||
|
||||
func mergeConnectionParamsFromConfigWithAllowlist(params url.Values, config connection.ConnectionConfig, canonicalNames map[string]string, allowedSchemes ...string) {
|
||||
mergeConnectionParamValuesWithAllowlist(params, connectionParamsFromURI(config.URI, allowedSchemes...), canonicalNames)
|
||||
mergeConnectionParamValuesWithAllowlist(params, connectionParamsFromText(config.ConnectionParams), canonicalNames)
|
||||
}
|
||||
|
||||
func mergeConnectionParamsIntoRawURI(raw string, connectionParams string, allowedSchemes ...string) string {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
|
||||
94
internal/db/connection_params_test.go
Normal file
94
internal/db/connection_params_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
func TestMergeConnectionParamsFromConfigWithAllowlistCanonicalizesAndFilters(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
params := url.Values{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
URI: "postgres://u:p@db.local/app?application_name=from-uri&unknown_uri=bad",
|
||||
ConnectionParams: "Application_Name=from-config&statement_timeout=3000&timezone=Asia%2FShanghai&unknown_config=bad",
|
||||
}
|
||||
|
||||
mergeConnectionParamsFromConfigWithAllowlist(params, cfg, postgresConnectionParamNames, "postgres")
|
||||
|
||||
if got := params.Get("application_name"); got != "from-config" {
|
||||
t.Fatalf("application_name = %q, want from-config", got)
|
||||
}
|
||||
if got := params.Get("statement_timeout"); got != "3000" {
|
||||
t.Fatalf("statement_timeout = %q, want 3000", got)
|
||||
}
|
||||
if got := params.Get("TimeZone"); got != "Asia/Shanghai" {
|
||||
t.Fatalf("TimeZone = %q, want Asia/Shanghai", got)
|
||||
}
|
||||
if got := params.Get("unknown_uri"); got != "" {
|
||||
t.Fatalf("unknown_uri should be filtered, got %q", got)
|
||||
}
|
||||
if got := params.Get("unknown_config"); got != "" {
|
||||
t.Fatalf("unknown_config should be filtered, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLServerConnectionParamAllowlistMapsADOSynonyms(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
params := url.Values{}
|
||||
mergeConnectionParamValuesWithAllowlist(params, url.Values{
|
||||
"Application Name": []string{"GoNavi"},
|
||||
"Initial Catalog": []string{"appdb"},
|
||||
"UID": []string{"sa"},
|
||||
"Trust Server Certificate": []string{"true"},
|
||||
"Column Encryption Setting": []string{"Enabled"},
|
||||
"ignored": []string{"bad"},
|
||||
}, sqlServerConnectionParamNames)
|
||||
|
||||
if got := params.Get("app name"); got != "GoNavi" {
|
||||
t.Fatalf("app name = %q, want GoNavi", got)
|
||||
}
|
||||
if got := params.Get("database"); got != "appdb" {
|
||||
t.Fatalf("database = %q, want appdb", got)
|
||||
}
|
||||
if got := params.Get("user id"); got != "sa" {
|
||||
t.Fatalf("user id = %q, want sa", got)
|
||||
}
|
||||
if got := params.Get("trustservercertificate"); got != "true" {
|
||||
t.Fatalf("trustservercertificate = %q, want true", got)
|
||||
}
|
||||
if got := params.Get("columnencryption"); got != "Enabled" {
|
||||
t.Fatalf("columnencryption = %q, want Enabled", got)
|
||||
}
|
||||
if got := params.Get("ignored"); got != "" {
|
||||
t.Fatalf("ignored should be filtered, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDamengConnectionParamAllowlistMapsUppercaseAliases(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
params := url.Values{}
|
||||
mergeConnectionParamValuesWithAllowlist(params, url.Values{
|
||||
"SSL_CERT_PATH": []string{"/cert.pem"},
|
||||
"SSL_KEY_PATH": []string{"/key.pem"},
|
||||
"CONNECT_TIMEOUT": []string{"5000"},
|
||||
"unknown": []string{"bad"},
|
||||
}, damengConnectionParamNames)
|
||||
|
||||
if got := params.Get("sslCertPath"); got != "/cert.pem" {
|
||||
t.Fatalf("sslCertPath = %q, want /cert.pem", got)
|
||||
}
|
||||
if got := params.Get("sslKeyPath"); got != "/key.pem" {
|
||||
t.Fatalf("sslKeyPath = %q, want /key.pem", got)
|
||||
}
|
||||
if got := params.Get("connectTimeout"); got != "5000" {
|
||||
t.Fatalf("connectTimeout = %q, want 5000", got)
|
||||
}
|
||||
if got := params.Get("unknown"); got != "" {
|
||||
t.Fatalf("unknown should be filtered, got %q", got)
|
||||
}
|
||||
}
|
||||
@@ -37,13 +37,13 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
if config.UseSSL {
|
||||
if certPath := strings.TrimSpace(config.SSLCertPath); certPath != "" {
|
||||
q.Set("SSL_CERT_PATH", certPath)
|
||||
q.Set("sslCertPath", certPath)
|
||||
}
|
||||
if keyPath := strings.TrimSpace(config.SSLKeyPath); keyPath != "" {
|
||||
q.Set("SSL_KEY_PATH", keyPath)
|
||||
q.Set("sslKeyPath", keyPath)
|
||||
}
|
||||
}
|
||||
mergeConnectionParamsFromConfig(q, config, "dm", "dameng")
|
||||
mergeConnectionParamsFromConfigWithAllowlist(q, config, damengConnectionParamNames, "dm", "dameng")
|
||||
|
||||
// 当前达梦 Go 驱动使用字符串切分解析 DSN,认证信息不会做 URL 反解码。
|
||||
// 密码保持原样传入,避免 p%40ss 这类转义文本被当作真实密码登录。
|
||||
|
||||
@@ -64,7 +64,7 @@ func TestPostgresDSN_MergesConnectionParams(t *testing.T) {
|
||||
User: "user",
|
||||
Password: "pass",
|
||||
Database: "db",
|
||||
ConnectionParams: "application_name=GoNavi&connect_timeout=9",
|
||||
ConnectionParams: "application_name=GoNavi&connect_timeout=9&statement_timeout=3000&allowPublicKeyRetrieval=true",
|
||||
}
|
||||
|
||||
dsn := p.getDSN(cfg)
|
||||
@@ -79,6 +79,12 @@ func TestPostgresDSN_MergesConnectionParams(t *testing.T) {
|
||||
if got := query.Get("connect_timeout"); got != "9" {
|
||||
t.Fatalf("connect_timeout = %q, want 9", got)
|
||||
}
|
||||
if got := query.Get("statement_timeout"); got != "3000" {
|
||||
t.Fatalf("statement_timeout = %q, want 3000", got)
|
||||
}
|
||||
if got := query.Get("allowPublicKeyRetrieval"); got != "" {
|
||||
t.Fatalf("unsupported postgres param should be filtered, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQLDSN_UsesTLSParamWhenSSLEnabled(t *testing.T) {
|
||||
@@ -184,11 +190,35 @@ func TestDamengDSN_AppendsSSLCertAndKeyParams(t *testing.T) {
|
||||
}
|
||||
|
||||
dsn := d.getDSN(cfg)
|
||||
if !strings.Contains(dsn, "SSL_CERT_PATH=") {
|
||||
t.Fatalf("dsn 缺少 SSL_CERT_PATH 参数:%s", dsn)
|
||||
if !strings.Contains(dsn, "sslCertPath=") {
|
||||
t.Fatalf("dsn 缺少 sslCertPath 参数:%s", dsn)
|
||||
}
|
||||
if !strings.Contains(dsn, "SSL_KEY_PATH=") {
|
||||
t.Fatalf("dsn 缺少 SSL_KEY_PATH 参数:%s", dsn)
|
||||
if !strings.Contains(dsn, "sslKeyPath=") {
|
||||
t.Fatalf("dsn 缺少 sslKeyPath 参数:%s", dsn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDamengDSN_FiltersUnsupportedConnectionParams(t *testing.T) {
|
||||
d := &DamengDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "dameng",
|
||||
Host: "127.0.0.1",
|
||||
Port: 5236,
|
||||
User: "SYSDBA",
|
||||
Password: "pass",
|
||||
Database: "DBName",
|
||||
ConnectionParams: "SSL_CERT_PATH=/cert.pem&CONNECT_TIMEOUT=5000&unknown=bad",
|
||||
}
|
||||
|
||||
dsn := d.getDSN(cfg)
|
||||
if !strings.Contains(dsn, "sslCertPath=%2Fcert.pem") {
|
||||
t.Fatalf("dsn 缺少规范化 sslCertPath 参数:%s", dsn)
|
||||
}
|
||||
if !strings.Contains(dsn, "connectTimeout=5000") {
|
||||
t.Fatalf("dsn 缺少规范化 connectTimeout 参数:%s", dsn)
|
||||
}
|
||||
if strings.Contains(dsn, "SSL_CERT_PATH") || strings.Contains(dsn, "unknown=bad") {
|
||||
t.Fatalf("dsn 不应透传达梦未知或非规范参数:%s", dsn)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -218,7 +248,7 @@ func TestKingbaseDSN_MergesConnectionParams(t *testing.T) {
|
||||
User: "system",
|
||||
Password: "pass",
|
||||
Database: "TEST",
|
||||
ConnectionParams: "application_name=GoNavi&connect_timeout=12",
|
||||
ConnectionParams: "application_name=GoNavi&connect_timeout=12&statement_timeout=3000&unknown=bad",
|
||||
}
|
||||
|
||||
dsn := k.getDSN(cfg)
|
||||
@@ -228,6 +258,12 @@ func TestKingbaseDSN_MergesConnectionParams(t *testing.T) {
|
||||
if !strings.Contains(dsn, "connect_timeout=12") {
|
||||
t.Fatalf("dsn 缺少自定义 connect_timeout:%s", dsn)
|
||||
}
|
||||
if !strings.Contains(dsn, "statement_timeout=3000") {
|
||||
t.Fatalf("dsn 缺少允许的 runtime 参数:%s", dsn)
|
||||
}
|
||||
if strings.Contains(dsn, "unknown=bad") {
|
||||
t.Fatalf("dsn 不应透传未知 Kingbase 参数:%s", dsn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTDengineDSN_UsesWebSocketFormat(t *testing.T) {
|
||||
@@ -275,13 +311,19 @@ func TestTDengineDSN_MergesConnectionParams(t *testing.T) {
|
||||
User: "root",
|
||||
Password: "taosdata",
|
||||
Database: "power",
|
||||
ConnectionParams: "timezone=Asia%2FShanghai&protocol=wss",
|
||||
ConnectionParams: "timezone=Asia%2FShanghai&protocol=wss&readTimeout=10s&unknown=bad",
|
||||
}
|
||||
|
||||
dsn := td.getDSN(cfg)
|
||||
if !strings.Contains(dsn, "?timezone=Asia%2FShanghai") {
|
||||
if !strings.Contains(dsn, "timezone=Asia%2FShanghai") {
|
||||
t.Fatalf("tdengine dsn 缺少自定义参数或错误透传 protocol:%s", dsn)
|
||||
}
|
||||
if !strings.Contains(dsn, "readTimeout=10s") {
|
||||
t.Fatalf("tdengine dsn 缺少 readTimeout 参数:%s", dsn)
|
||||
}
|
||||
if strings.Contains(dsn, "protocol=wss") || strings.Contains(dsn, "unknown=bad") {
|
||||
t.Fatalf("tdengine dsn 不应透传协议控制项或未知参数:%s", dsn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLServerDSN_EncryptMapping(t *testing.T) {
|
||||
@@ -315,7 +357,7 @@ func TestSQLServerDSN_MergesConnectionParams(t *testing.T) {
|
||||
User: "sa",
|
||||
Password: "pass",
|
||||
Database: "master",
|
||||
ConnectionParams: "app name=GoNavi&packet size=32767",
|
||||
ConnectionParams: "Application Name=GoNavi&Initial Catalog=appdb&packet size=32767&unknown=bad",
|
||||
}
|
||||
|
||||
dsn := s.getDSN(cfg)
|
||||
@@ -327,9 +369,15 @@ func TestSQLServerDSN_MergesConnectionParams(t *testing.T) {
|
||||
if got := query.Get("app name"); got != "GoNavi" {
|
||||
t.Fatalf("app name = %q, want GoNavi", got)
|
||||
}
|
||||
if got := query.Get("database"); got != "appdb" {
|
||||
t.Fatalf("database = %q, want appdb", got)
|
||||
}
|
||||
if got := query.Get("packet size"); got != "32767" {
|
||||
t.Fatalf("packet size = %q, want 32767", got)
|
||||
}
|
||||
if got := query.Get("unknown"); got != "" {
|
||||
t.Fatalf("unknown should be filtered, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClickHouseOptions_UsesStructuredTimeoutAndAuth(t *testing.T) {
|
||||
|
||||
@@ -44,7 +44,7 @@ func (h *HighGoDB) getDSN(config connection.ConnectionConfig) string {
|
||||
q := url.Values{}
|
||||
q.Set("sslmode", resolvePostgresSSLMode(config))
|
||||
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
|
||||
mergeConnectionParamsFromConfig(q, config, "postgres", "postgresql", "highgo")
|
||||
mergeConnectionParamsFromConfigWithAllowlist(q, config, highGoConnectionParamNames, "postgres", "postgresql", "highgo")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String()
|
||||
|
||||
@@ -71,7 +71,7 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
|
||||
params.Set("dbname", config.Database)
|
||||
params.Set("sslmode", resolvePostgresSSLMode(config))
|
||||
params.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
|
||||
mergeConnectionParamsFromConfig(params, config, "kingbase")
|
||||
mergeConnectionParamsFromConfigWithAllowlist(params, config, kingbaseConnectionParamNames, "kingbase")
|
||||
|
||||
preferred := []string{"host", "port", "user", "password", "dbname", "sslmode", "connect_timeout"}
|
||||
seen := make(map[string]struct{}, len(params))
|
||||
|
||||
@@ -64,7 +64,7 @@ func TestMySQLDSN_MapsCommonJDBCParamsWithoutLeakingUnsupportedKeys(t *testing.T
|
||||
User: "root",
|
||||
Database: "app",
|
||||
ConnectionParams: "useUnicode=true&characterEncoding=utf8&autoReconnect=true&" +
|
||||
"useSSL=false&verifyServerCertificate=false&useOldAliasMetadataBehavior=true",
|
||||
"allowPublicKeyRetrieval=true&useSSL=false&verifyServerCertificate=false&useOldAliasMetadataBehavior=true",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("getDSN failed: %v", err)
|
||||
@@ -81,6 +81,7 @@ func TestMySQLDSN_MapsCommonJDBCParamsWithoutLeakingUnsupportedKeys(t *testing.T
|
||||
"useUnicode",
|
||||
"characterEncoding",
|
||||
"autoReconnect",
|
||||
"allowPublicKeyRetrieval",
|
||||
"useSSL",
|
||||
"verifyServerCertificate",
|
||||
"useOldAliasMetadataBehavior",
|
||||
@@ -122,6 +123,153 @@ func TestMySQLDSN_MapsJDBCUTF8EncodingToMySQLCharset(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQLDSN_DropsJDBCAllowPublicKeyRetrievalParam(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := &MySQLDB{}
|
||||
dsn, err := m.getDSN(connection.ConnectionConfig{
|
||||
Host: "db.local",
|
||||
Port: 3306,
|
||||
User: "root",
|
||||
Database: "app",
|
||||
URI: "jdbc:mysql://db.local:3306/app?allowPublicKeyRetrieval=true&useSSL=false",
|
||||
ConnectionParams: "allowPublicKeyRetrieval=true&readtimeout=10&writetimeout=11",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("getDSN failed: %v", err)
|
||||
}
|
||||
|
||||
query := parseMySQLDSNQueryForTest(t, dsn)
|
||||
if _, exists := query["allowPublicKeyRetrieval"]; exists {
|
||||
t.Fatalf("JDBC allowPublicKeyRetrieval should not be passed to Go MySQL driver: %v", query)
|
||||
}
|
||||
if got := query.Get("tls"); got != "false" {
|
||||
t.Fatalf("useSSL=false should still map to tls=false, got=%q", got)
|
||||
}
|
||||
if got := query.Get("readTimeout"); got != "10s" {
|
||||
t.Fatalf("readtimeout should canonicalize to readTimeout duration, got=%q", got)
|
||||
}
|
||||
if got := query.Get("writeTimeout"); got != "11s" {
|
||||
t.Fatalf("writetimeout should canonicalize to writeTimeout duration, got=%q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQLDSN_PreservesSupportedGoDriverParamsAndDropsUnknownParams(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := &MySQLDB{}
|
||||
dsn, err := m.getDSN(connection.ConnectionConfig{
|
||||
Host: "db.local",
|
||||
Port: 3306,
|
||||
User: "root",
|
||||
Database: "app",
|
||||
ConnectionParams: strings.Join([]string{
|
||||
"allowAllFiles=true",
|
||||
"allowCleartextPasswords=true",
|
||||
"allowFallbackToPlaintext=true",
|
||||
"allowNativePasswords=false",
|
||||
"allowOldPasswords=true",
|
||||
"checkConnLiveness=false",
|
||||
"clientFoundRows=true",
|
||||
"charset=latin1",
|
||||
"collation=utf8mb4_unicode_ci",
|
||||
"columnsWithAlias=true",
|
||||
"compress=true",
|
||||
"connectionAttributes=program_name:GoNavi",
|
||||
"interpolateParams=true",
|
||||
"loc=UTC",
|
||||
"maxAllowedPacket=1048576",
|
||||
"multiStatements=false",
|
||||
"parseTime=false",
|
||||
"readtimeout=7",
|
||||
"rejectReadOnly=true",
|
||||
"serverPubKey=testKey",
|
||||
"timeTruncate=2",
|
||||
"timeout=8",
|
||||
"tls=preferred",
|
||||
"writetimeout=9",
|
||||
"strict=true",
|
||||
"unsupportedJdbcParam=true",
|
||||
}, "&"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("getDSN failed: %v", err)
|
||||
}
|
||||
|
||||
query := parseMySQLDSNQueryForTest(t, dsn)
|
||||
want := map[string]string{
|
||||
"allowAllFiles": "true",
|
||||
"allowCleartextPasswords": "true",
|
||||
"allowFallbackToPlaintext": "true",
|
||||
"allowNativePasswords": "false",
|
||||
"allowOldPasswords": "true",
|
||||
"checkConnLiveness": "false",
|
||||
"clientFoundRows": "true",
|
||||
"charset": "latin1",
|
||||
"collation": "utf8mb4_unicode_ci",
|
||||
"columnsWithAlias": "true",
|
||||
"compress": "true",
|
||||
"connectionAttributes": "program_name:GoNavi",
|
||||
"interpolateParams": "true",
|
||||
"loc": "UTC",
|
||||
"maxAllowedPacket": "1048576",
|
||||
"multiStatements": "false",
|
||||
"parseTime": "false",
|
||||
"readTimeout": "7s",
|
||||
"rejectReadOnly": "true",
|
||||
"serverPubKey": "testKey",
|
||||
"timeTruncate": "2s",
|
||||
"timeout": "8s",
|
||||
"tls": "preferred",
|
||||
"writeTimeout": "9s",
|
||||
}
|
||||
for key, value := range want {
|
||||
if got := query.Get(key); got != value {
|
||||
t.Fatalf("%s should be %q, got %q; query=%v", key, value, got, query)
|
||||
}
|
||||
}
|
||||
for _, forbidden := range []string{"strict", "unsupportedJdbcParam"} {
|
||||
if _, exists := query[forbidden]; exists {
|
||||
t.Fatalf("unsupported parameter %s should not be passed to Go MySQL driver: %v", forbidden, query)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQLDSN_MapsAdditionalJDBCAliases(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := &MySQLDB{}
|
||||
dsn, err := m.getDSN(connection.ConnectionConfig{
|
||||
Host: "db.local",
|
||||
Port: 3306,
|
||||
User: "root",
|
||||
Database: "app",
|
||||
ConnectionParams: strings.Join([]string{
|
||||
"sslMode=required",
|
||||
"allowMultiQueries=false",
|
||||
"useCompression=true",
|
||||
"connectionCollation=utf8mb4_bin",
|
||||
}, "&"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("getDSN failed: %v", err)
|
||||
}
|
||||
|
||||
query := parseMySQLDSNQueryForTest(t, dsn)
|
||||
if got := query.Get("tls"); got != "true" {
|
||||
t.Fatalf("sslMode=required should map to tls=true, got=%q", got)
|
||||
}
|
||||
if got := query.Get("multiStatements"); got != "false" {
|
||||
t.Fatalf("allowMultiQueries=false should map to multiStatements=false, got=%q", got)
|
||||
}
|
||||
if got := query.Get("compress"); got != "true" {
|
||||
t.Fatalf("useCompression=true should map to compress=true, got=%q", got)
|
||||
}
|
||||
if got := query.Get("collation"); got != "utf8mb4_bin" {
|
||||
t.Fatalf("connectionCollation should map to collation, got=%q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQLDSN_AsiaShanghaiLocationAcceptedByDriver(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -99,6 +99,73 @@ func normalizeMySQLServerTimezoneParam(raw string) (string, bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var mysqlSupportedDriverParamNames = map[string]string{
|
||||
"allowallfiles": "allowAllFiles",
|
||||
"allowcleartextpasswords": "allowCleartextPasswords",
|
||||
"allowfallbacktoplaintext": "allowFallbackToPlaintext",
|
||||
"allownativepasswords": "allowNativePasswords",
|
||||
"allowoldpasswords": "allowOldPasswords",
|
||||
"checkconnliveness": "checkConnLiveness",
|
||||
"clientfoundrows": "clientFoundRows",
|
||||
"charset": "charset",
|
||||
"collation": "collation",
|
||||
"columnswithalias": "columnsWithAlias",
|
||||
"compress": "compress",
|
||||
"connectionattributes": "connectionAttributes",
|
||||
"interpolateparams": "interpolateParams",
|
||||
"loc": "loc",
|
||||
"maxallowedpacket": "maxAllowedPacket",
|
||||
"multistatements": "multiStatements",
|
||||
"parsetime": "parseTime",
|
||||
"readtimeout": "readTimeout",
|
||||
"rejectreadonly": "rejectReadOnly",
|
||||
"serverpubkey": "serverPubKey",
|
||||
"timetruncate": "timeTruncate",
|
||||
"timeout": "timeout",
|
||||
"tls": "tls",
|
||||
"writetimeout": "writeTimeout",
|
||||
}
|
||||
|
||||
var mysqlBoolDriverParamNames = map[string]struct{}{
|
||||
"allowAllFiles": {},
|
||||
"allowCleartextPasswords": {},
|
||||
"allowFallbackToPlaintext": {},
|
||||
"allowNativePasswords": {},
|
||||
"allowOldPasswords": {},
|
||||
"checkConnLiveness": {},
|
||||
"clientFoundRows": {},
|
||||
"columnsWithAlias": {},
|
||||
"compress": {},
|
||||
"interpolateParams": {},
|
||||
"multiStatements": {},
|
||||
"parseTime": {},
|
||||
"rejectReadOnly": {},
|
||||
}
|
||||
|
||||
func canonicalMySQLDriverParamName(name string) (string, bool) {
|
||||
canonical, ok := mysqlSupportedDriverParamNames[strings.ToLower(strings.TrimSpace(name))]
|
||||
return canonical, ok
|
||||
}
|
||||
|
||||
func setMySQLDriverParam(params url.Values, name string, value string) {
|
||||
switch name {
|
||||
case "charset":
|
||||
if charset := normalizeMySQLCharsetParam(value); charset != "" {
|
||||
params.Set("charset", charset)
|
||||
}
|
||||
case "timeout", "readTimeout", "writeTimeout", "timeTruncate":
|
||||
params.Set(name, normalizeMySQLDurationParam(value, time.Second))
|
||||
default:
|
||||
if _, ok := mysqlBoolDriverParamNames[name]; ok {
|
||||
if enabled, ok := parseMySQLBoolParam(value); ok {
|
||||
params.Set(name, strconv.FormatBool(enabled))
|
||||
return
|
||||
}
|
||||
}
|
||||
params.Set(name, value)
|
||||
}
|
||||
}
|
||||
|
||||
func mergeMySQLConnectionParam(params url.Values, key string, value string) {
|
||||
name := strings.TrimSpace(key)
|
||||
if name == "" {
|
||||
@@ -108,12 +175,7 @@ func mergeMySQLConnectionParam(params url.Values, key string, value string) {
|
||||
switch lowerName {
|
||||
case "topology":
|
||||
return
|
||||
case "useunicode", "autoreconnect", "useoldaliasmetadatabehavior":
|
||||
return
|
||||
case "charset":
|
||||
if charset := normalizeMySQLCharsetParam(value); charset != "" {
|
||||
params.Set("charset", charset)
|
||||
}
|
||||
case "useunicode", "autoreconnect", "useoldaliasmetadatabehavior", "allowpublickeyretrieval":
|
||||
return
|
||||
case "characterencoding":
|
||||
if charset := normalizeMySQLCharsetParam(value); charset != "" {
|
||||
@@ -144,17 +206,41 @@ func mergeMySQLConnectionParam(params url.Values, key string, value string) {
|
||||
params.Set("tls", "skip-verify")
|
||||
}
|
||||
return
|
||||
case "sslmode":
|
||||
switch normalizeSSLModeValue(value) {
|
||||
case sslModeDisable:
|
||||
params.Set("tls", "false")
|
||||
case sslModeRequired:
|
||||
params.Set("tls", "true")
|
||||
case sslModeSkipVerify:
|
||||
params.Set("tls", "skip-verify")
|
||||
default:
|
||||
params.Set("tls", "preferred")
|
||||
}
|
||||
return
|
||||
case "connecttimeout":
|
||||
params.Set("timeout", normalizeMySQLDurationParam(value, time.Millisecond))
|
||||
return
|
||||
case "sockettimeout":
|
||||
params.Set("readTimeout", normalizeMySQLDurationParam(value, time.Millisecond))
|
||||
return
|
||||
case "timeout", "readtimeout", "writetimeout":
|
||||
params.Set(name, normalizeMySQLDurationParam(value, time.Second))
|
||||
case "allowmultiqueries":
|
||||
if enabled, ok := parseMySQLBoolParam(value); ok {
|
||||
params.Set("multiStatements", strconv.FormatBool(enabled))
|
||||
}
|
||||
return
|
||||
case "usecompression":
|
||||
if enabled, ok := parseMySQLBoolParam(value); ok {
|
||||
params.Set("compress", strconv.FormatBool(enabled))
|
||||
}
|
||||
return
|
||||
case "connectioncollation":
|
||||
params.Set("collation", value)
|
||||
return
|
||||
default:
|
||||
params.Set(name, value)
|
||||
if canonical, ok := canonicalMySQLDriverParamName(name); ok {
|
||||
setMySQLDriverParam(params, canonical, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ func TestOracleGetDSNMergesConnectionParams(t *testing.T) {
|
||||
User: "scott",
|
||||
Password: "tiger",
|
||||
Database: "ORCLPDB1",
|
||||
ConnectionParams: "PREFETCH_ROWS=5000&TRACE FILE=/tmp/go-ora.trc",
|
||||
ConnectionParams: "PREFETCH_ROWS=5000&TRACE FILE=/tmp/go-ora.trc&connect_timeout=10&FAILOVER=3&unknown=bad",
|
||||
})
|
||||
|
||||
parsed, err := url.Parse(dsn)
|
||||
@@ -54,4 +54,13 @@ func TestOracleGetDSNMergesConnectionParams(t *testing.T) {
|
||||
if got := query.Get("TRACE FILE"); got != "/tmp/go-ora.trc" {
|
||||
t.Fatalf("TRACE FILE = %q, want /tmp/go-ora.trc", got)
|
||||
}
|
||||
if got := query.Get("CONNECT TIMEOUT"); got != "10" {
|
||||
t.Fatalf("CONNECT TIMEOUT = %q, want 10", got)
|
||||
}
|
||||
if got := query.Get("FAILOVER"); got != "" {
|
||||
t.Fatalf("FAILOVER should be filtered because go-ora no longer supports it, got %q", got)
|
||||
}
|
||||
if got := query.Get("unknown"); got != "" {
|
||||
t.Fatalf("unknown should be filtered, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
|
||||
q.Set("PREFETCH_ROWS", "10000")
|
||||
// LOB 数据延迟加载,避免大 LOB 列影响普通查询性能
|
||||
q.Set("LOB FETCH", "POST")
|
||||
mergeConnectionParamsFromConfig(q, config, "oracle")
|
||||
mergeConnectionParamsFromConfigWithAllowlist(q, config, oracleConnectionParamNames, "oracle")
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
u.RawQuery = encoded
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
|
||||
q := url.Values{}
|
||||
q.Set("sslmode", resolvePostgresSSLMode(config))
|
||||
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
|
||||
mergeConnectionParamsFromConfig(q, config, "postgres", "postgresql", "opengauss")
|
||||
mergeConnectionParamsFromConfigWithAllowlist(q, config, postgresConnectionParamNames, "postgres", "postgresql", "opengauss")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String()
|
||||
|
||||
@@ -49,8 +49,8 @@ func (s *SqlServerDB) getDSN(config connection.ConnectionConfig) string {
|
||||
q.Set("connection timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
|
||||
encrypt, trustServerCertificate := resolveSQLServerTLSSettings(config)
|
||||
q.Set("encrypt", encrypt)
|
||||
q.Set("TrustServerCertificate", trustServerCertificate)
|
||||
mergeConnectionParamsFromConfig(q, config, "sqlserver")
|
||||
q.Set("trustservercertificate", trustServerCertificate)
|
||||
mergeConnectionParamsFromConfigWithAllowlist(q, config, sqlServerConnectionParamNames, "sqlserver")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String()
|
||||
|
||||
@@ -44,9 +44,7 @@ func (t *TDengineDB) getDSN(config connection.ConnectionConfig) string {
|
||||
|
||||
netType := resolveTDengineNet(config)
|
||||
params := url.Values{}
|
||||
mergeConnectionParamsFromConfig(params, config, "taos", "taosws", "tdengine")
|
||||
params.Del("protocol")
|
||||
params.Del("skip_verify")
|
||||
mergeConnectionParamsFromConfigWithAllowlist(params, config, tdengineConnectionParamNames, "taos", "taosws", "tdengine")
|
||||
query := params.Encode()
|
||||
dsn := fmt.Sprintf("%s:%s@%s(%s)%s", user, pass, netType, net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), path)
|
||||
if query == "" {
|
||||
|
||||
@@ -43,7 +43,7 @@ func (v *VastbaseDB) getDSN(config connection.ConnectionConfig) string {
|
||||
q := url.Values{}
|
||||
q.Set("sslmode", resolvePostgresSSLMode(config))
|
||||
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
|
||||
mergeConnectionParamsFromConfig(q, config, "postgres", "postgresql", "vastbase")
|
||||
mergeConnectionParamsFromConfigWithAllowlist(q, config, postgresConnectionParamNames, "postgres", "postgresql", "vastbase")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String()
|
||||
|
||||
Reference in New Issue
Block a user