🐛 fix(connection): 收敛数据库连接参数白名单

- MySQL 兼容 JDBC 参数映射并丢弃 allowPublicKeyRetrieval 等无效参数
- 为 PostgreSQL 系、SQL Server、Oracle、达梦、TDengine 接入驱动参数白名单
- 补充连接参数归一化、别名映射和未知参数过滤回归测试
This commit is contained in:
Syngnat
2026-05-13 17:51:02 +08:00
parent e6a1333f83
commit b2b1e6b944
15 changed files with 874 additions and 33 deletions

View File

@@ -0,0 +1,400 @@
package db
var postgresConnectionParamNames = newConnectionParamNameMap(
"host",
"hostaddr",
"port",
"user",
"password",
"passfile",
"dbname",
"database",
"options",
"application_name",
"fallback_application_name",
"sslmode",
"sslnegotiation",
"sslcert",
"sslkey",
"sslrootcert",
"sslsni",
"sslinline",
"krbsrvname",
"krbspn",
"connect_timeout",
"binary_parameters",
"disable_prepared_binary_result",
"client_encoding",
"datestyle",
"tz",
"geqo",
"target_session_attrs",
"load_balance_hosts",
"search_path",
"work_mem",
"statement_timeout",
"lock_timeout",
"idle_in_transaction_session_timeout",
"default_transaction_read_only",
"default_transaction_isolation",
"TimeZone",
)
var highGoConnectionParamNames = newConnectionParamNameMap(
"host",
"port",
"user",
"password",
"dbname",
"application_name",
"fallback_application_name",
"sslmode",
"sslcert",
"sslkey",
"sslrootcert",
"sslsni",
"sslinline",
"krbsrvname",
"krbspn",
"connect_timeout",
"binary_parameters",
"disable_prepared_binary_result",
"search_path",
"work_mem",
"statement_timeout",
"lock_timeout",
"idle_in_transaction_session_timeout",
"default_transaction_read_only",
"default_transaction_isolation",
"TimeZone",
)
var kingbaseConnectionParamNames = newConnectionParamNameMap(
"host",
"port",
"user",
"password",
"dbname",
"application_name",
"fallback_application_name",
"sslmode",
"sslcert",
"sslkey",
"sslrootcert",
"krbsrvname",
"krbspn",
"connect_timeout",
"binary_parameters",
"disable_prepared_binary_result",
"search_path",
"work_mem",
"statement_timeout",
"lock_timeout",
"idle_in_transaction_session_timeout",
"default_transaction_read_only",
"default_transaction_isolation",
"TimeZone",
)
var sqlServerConnectionParamNames = makeSQLServerConnectionParamNames()
func makeSQLServerConnectionParamNames() map[string]string {
names := newConnectionParamNameMap(
"database",
"encrypt",
"password",
"change password",
"user id",
"port",
"trustservercertificate",
"certificate",
"servercertificate",
"tlsmin",
"packet size",
"log",
"connection timeout",
"hostnameincertificate",
"keepalive",
"serverspn",
"workstation id",
"app name",
"applicationintent",
"failoverpartner",
"failoverport",
"disableretry",
"server",
"protocol",
"dial timeout",
"pipe",
"multisubnetfailover",
"notraceid",
"guid conversion",
"timezone",
"columnencryption",
)
addConnectionParamAlias(names, "application name", "app name")
addConnectionParamAlias(names, "data source", "server")
addConnectionParamAlias(names, "network address", "server")
addConnectionParamAlias(names, "address", "server")
addConnectionParamAlias(names, "addr", "server")
addConnectionParamAlias(names, "user", "user id")
addConnectionParamAlias(names, "uid", "user id")
addConnectionParamAlias(names, "pwd", "password")
addConnectionParamAlias(names, "initial catalog", "database")
addConnectionParamAlias(names, "column encryption setting", "columnencryption")
addConnectionParamAlias(names, "trust server certificate", "trustservercertificate")
addConnectionParamAlias(names, "multi subnet failover", "multisubnetfailover")
addConnectionParamAlias(names, "application intent", "applicationintent")
return names
}
var oracleConnectionParamNames = makeOracleConnectionParamNames()
func makeOracleConnectionParamNames() map[string]string {
names := newConnectionParamNameMap(
"CID",
"connStr",
"SERVER",
"SERVICE NAME",
"SID",
"INSTANCE NAME",
"WALLET",
"WALLET PASSWORD",
"AUTH TYPE",
"OS USER",
"OS PASS",
"OS PASSWORD",
"OS HASH",
"OS PASSHASH",
"OS PASSWORD HASH",
"DOMAIN",
"AUTH SERV",
"ENCRYPTION",
"DATA INTEGRITY",
"SSL",
"SSL VERIFY",
"DBA PRIVILEGE",
"TIMEOUT",
"READ TIMEOUT",
"SOCKET TIMEOUT",
"CONNECT TIMEOUT",
"CONNECTION TIMEOUT",
"TRACE FILE",
"TRACE DIR",
"TRACE FOLDER",
"TRACE DIRECTORY",
"USE_OOB",
"ENABLE_OOB",
"ENABLE URGENT DATA TRANSPORT",
"PREFETCH_ROWS",
"UNIX SOCKET",
"PROXY CLIENT NAME",
"LOB FETCH",
"LANGUAGE",
"TERRITORY",
"CHARSET",
"CLIENT CHARSET",
"PROGRAM",
"SERVER LOCATION",
)
addConnectionParamAlias(names, "SERVICE_NAME", "SERVICE NAME")
addConnectionParamAlias(names, "INSTANCE_NAME", "INSTANCE NAME")
addConnectionParamAlias(names, "WALLET_PASSWORD", "WALLET PASSWORD")
addConnectionParamAlias(names, "AUTH_TYPE", "AUTH TYPE")
addConnectionParamAlias(names, "AUTH_SERV", "AUTH SERV")
addConnectionParamAlias(names, "DATA_INTEGRITY", "DATA INTEGRITY")
addConnectionParamAlias(names, "SSL_VERIFY", "SSL VERIFY")
addConnectionParamAlias(names, "DBA_PRIVILEGE", "DBA PRIVILEGE")
addConnectionParamAlias(names, "READ_TIMEOUT", "READ TIMEOUT")
addConnectionParamAlias(names, "SOCKET_TIMEOUT", "SOCKET TIMEOUT")
addConnectionParamAlias(names, "CONNECT_TIMEOUT", "CONNECT TIMEOUT")
addConnectionParamAlias(names, "CONNECTION_TIMEOUT", "CONNECTION TIMEOUT")
addConnectionParamAlias(names, "TRACE_FILE", "TRACE FILE")
addConnectionParamAlias(names, "TRACE_DIR", "TRACE DIR")
addConnectionParamAlias(names, "TRACE_FOLDER", "TRACE FOLDER")
addConnectionParamAlias(names, "TRACE_DIRECTORY", "TRACE DIRECTORY")
addConnectionParamAlias(names, "UNIX_SOCKET", "UNIX SOCKET")
addConnectionParamAlias(names, "PROXY_CLIENT_NAME", "PROXY CLIENT NAME")
addConnectionParamAlias(names, "LOB_FETCH", "LOB FETCH")
addConnectionParamAlias(names, "CLIENT_CHARSET", "CLIENT CHARSET")
addConnectionParamAlias(names, "SERVER_LOCATION", "SERVER LOCATION")
return names
}
var damengConnectionParamNames = makeDamengConnectionParamNames()
func makeDamengConnectionParamNames() map[string]string {
names := newConnectionParamNameMap(
"timeZone",
"enRsCache",
"rsCacheSize",
"rsRefreshFreq",
"loginPrimary",
"loginMode",
"loginStatus",
"loginDscCtrl",
"switchTimes",
"switchInterval",
"epSelector",
"primaryKey",
"keywords",
"compress",
"compressId",
"loginEncrypt",
"communicationEncrypt",
"direct",
"dec2double",
"rwSeparate",
"rwPercent",
"rwAutoDistribute",
"compatibleMode",
"comOra",
"cipherPath",
"doSwitch",
"driverReconnect",
"cluster",
"language",
"dbAliveCheckFreq",
"rwStandbyRecoverTime",
"logLevel",
"logDir",
"logBufferPoolSize",
"logBufferSize",
"logFlusherQueueSize",
"logFlushFreq",
"statEnable",
"statDir",
"statFlushFreq",
"statHighFreqSqlCount",
"statSlowSqlCount",
"statSqlMaxCount",
"statSqlRemoveMode",
"addressRemap",
"userRemap",
"connectTimeout",
"loginCertificate",
"url",
"host",
"port",
"user",
"password",
"dialName",
"rwStandby",
"isCompress",
"rwHA",
"rwIgnoreSql",
"appName",
"osName",
"mppLocal",
"socketTimeout",
"sessionTimeout",
"continueBatchOnError",
"batchAllowMaxErrors",
"escapeProcess",
"autoCommit",
"maxRows",
"rowPrefetch",
"bufPrefetch",
"LobMode",
"StmtPoolSize",
"AlwayseAllowCommit",
"batchType",
"batchNotOnCall",
"isBdtaRS",
"clobAsString",
"sslCertPath",
"sslKeyPath",
"sslFilesPath",
"kerberosLoginConfPath",
"uKeyName",
"uKeyPin",
"columnNameUpperCase",
"columnNameCase",
"databaseProductName",
"osAuthType",
"schema",
"catalog",
"serverOption",
"clobToBytes",
"localTimezone",
"sessEncode",
"svcConfPath",
"confPath",
)
addConnectionParamAlias(names, "ADDRESS_REMAP", "addressRemap")
addConnectionParamAlias(names, "ALWAYS_ALLOW_COMMIT", "AlwayseAllowCommit")
addConnectionParamAlias(names, "APP_NAME", "appName")
addConnectionParamAlias(names, "AUTO_COMMIT", "autoCommit")
addConnectionParamAlias(names, "BATCH_ALLOW_MAX_ERRORS", "batchAllowMaxErrors")
addConnectionParamAlias(names, "BATCH_CONTINUE_ON_ERROR", "continueBatchOnError")
addConnectionParamAlias(names, "CONTINUE_BATCH_ON_ERROR", "continueBatchOnError")
addConnectionParamAlias(names, "BATCH_NOT_ON_CALL", "batchNotOnCall")
addConnectionParamAlias(names, "BATCH_TYPE", "batchType")
addConnectionParamAlias(names, "BUF_PREFETCH", "bufPrefetch")
addConnectionParamAlias(names, "CIPHER_PATH", "cipherPath")
addConnectionParamAlias(names, "COLUMN_NAME_UPPER_CASE", "columnNameUpperCase")
addConnectionParamAlias(names, "COLUMN_NAME_CASE", "columnNameCase")
addConnectionParamAlias(names, "COMPATIBLE_MODE", "compatibleMode")
addConnectionParamAlias(names, "COMPRESS_MSG", "compress")
addConnectionParamAlias(names, "COMPRESS_ID", "compressId")
addConnectionParamAlias(names, "CONNECT_TIMEOUT", "connectTimeout")
addConnectionParamAlias(names, "DO_SWITCH", "doSwitch")
addConnectionParamAlias(names, "AUTO_RECONNECT", "doSwitch")
addConnectionParamAlias(names, "ENABLE_RS_CACHE", "enRsCache")
addConnectionParamAlias(names, "EP_SELECTION", "epSelector")
addConnectionParamAlias(names, "ESCAPE_PROCESS", "escapeProcess")
addConnectionParamAlias(names, "IS_BDTA_RS", "isBdtaRS")
addConnectionParamAlias(names, "KEY_WORDS", "keywords")
addConnectionParamAlias(names, "LOB_MODE", "LobMode")
addConnectionParamAlias(names, "LOG_BUFFER_SIZE", "logBufferSize")
addConnectionParamAlias(names, "LOG_DIR", "logDir")
addConnectionParamAlias(names, "LOG_FLUSH_FREQ", "logFlushFreq")
addConnectionParamAlias(names, "LOG_FLUSHER_QUEUESIZE", "logFlusherQueueSize")
addConnectionParamAlias(names, "LOG_LEVEL", "logLevel")
addConnectionParamAlias(names, "LOGIN_DSC_CTRL", "loginDscCtrl")
addConnectionParamAlias(names, "LOGIN_ENCRYPT", "loginEncrypt")
addConnectionParamAlias(names, "LOGIN_MODE", "loginMode")
addConnectionParamAlias(names, "LOGIN_STATUS", "loginStatus")
addConnectionParamAlias(names, "MAX_ROWS", "maxRows")
addConnectionParamAlias(names, "MPP_LOCAL", "mppLocal")
addConnectionParamAlias(names, "OS_NAME", "osName")
addConnectionParamAlias(names, "RS_CACHE_SIZE", "rsCacheSize")
addConnectionParamAlias(names, "RS_REFRESH_FREQ", "rsRefreshFreq")
addConnectionParamAlias(names, "RW_HA", "rwHA")
addConnectionParamAlias(names, "RW_IGNORE_SQL", "rwIgnoreSql")
addConnectionParamAlias(names, "RW_PERCENT", "rwPercent")
addConnectionParamAlias(names, "RW_SEPARATE", "rwSeparate")
addConnectionParamAlias(names, "RW_STANDBY_RECOVER_TIME", "rwStandbyRecoverTime")
addConnectionParamAlias(names, "SESS_ENCODE", "sessEncode")
addConnectionParamAlias(names, "SESSION_TIMEOUT", "sessionTimeout")
addConnectionParamAlias(names, "SOCKET_TIMEOUT", "socketTimeout")
addConnectionParamAlias(names, "SSL_CERT_PATH", "sslCertPath")
addConnectionParamAlias(names, "SSL_FILES_PATH", "sslFilesPath")
addConnectionParamAlias(names, "SSL_KEY_PATH", "sslKeyPath")
addConnectionParamAlias(names, "STAT_DIR", "statDir")
addConnectionParamAlias(names, "STAT_ENABLE", "statEnable")
addConnectionParamAlias(names, "STAT_FLUSH_FREQ", "statFlushFreq")
addConnectionParamAlias(names, "STAT_HIGH_FREQ_SQL_COUNT", "statHighFreqSqlCount")
addConnectionParamAlias(names, "STAT_SLOW_SQL_COUNT", "statSlowSqlCount")
addConnectionParamAlias(names, "STAT_SQL_MAX_COUNT", "statSqlMaxCount")
addConnectionParamAlias(names, "STAT_SQL_REMOVE_MODE", "statSqlRemoveMode")
addConnectionParamAlias(names, "SWITCH_INTERVAL", "switchInterval")
addConnectionParamAlias(names, "SWITCH_TIME", "switchTimes")
addConnectionParamAlias(names, "SWITCH_TIMES", "switchTimes")
addConnectionParamAlias(names, "TIME_ZONE", "timeZone")
addConnectionParamAlias(names, "USER_REMAP", "userRemap")
addConnectionParamAlias(names, "SERVER_OPTION", "serverOption")
addConnectionParamAlias(names, "CLOB_TO_BYTES", "clobToBytes")
return names
}
var tdengineConnectionParamNames = newConnectionParamNameMap(
"interpolateParams",
"token",
"enableCompression",
"readTimeout",
"writeTimeout",
"timezone",
"bearerToken",
"totpCode",
)

View File

@@ -77,11 +77,69 @@ func mergeConnectionParamValues(params url.Values, values url.Values) {
}
}
func normalizeConnectionParamName(name string) string {
return strings.ToLower(strings.Join(strings.Fields(strings.TrimSpace(name)), " "))
}
func newConnectionParamNameMap(names ...string) map[string]string {
result := make(map[string]string, len(names))
for _, name := range names {
addConnectionParamName(result, name)
}
return result
}
func addConnectionParamName(names map[string]string, canonical string) {
key := normalizeConnectionParamName(canonical)
if key == "" {
return
}
names[key] = canonical
}
func addConnectionParamAlias(names map[string]string, alias string, canonical string) {
key := normalizeConnectionParamName(alias)
if key == "" {
return
}
if existing, ok := names[normalizeConnectionParamName(canonical)]; ok {
canonical = existing
}
names[key] = canonical
}
func mergeConnectionParamValuesWithAllowlist(params url.Values, values url.Values, canonicalNames map[string]string) {
if len(values) == 0 || len(canonicalNames) == 0 {
return
}
keys := make([]string, 0, len(values))
for key := range values {
if strings.TrimSpace(key) != "" {
keys = append(keys, key)
}
}
sort.Strings(keys)
for _, key := range keys {
canonical, ok := canonicalNames[normalizeConnectionParamName(key)]
if !ok {
continue
}
for _, value := range values[key] {
params.Set(canonical, value)
}
}
}
func mergeConnectionParamsFromConfig(params url.Values, config connection.ConnectionConfig, allowedSchemes ...string) {
mergeConnectionParamValues(params, connectionParamsFromURI(config.URI, allowedSchemes...))
mergeConnectionParamValues(params, connectionParamsFromText(config.ConnectionParams))
}
func mergeConnectionParamsFromConfigWithAllowlist(params url.Values, config connection.ConnectionConfig, canonicalNames map[string]string, allowedSchemes ...string) {
mergeConnectionParamValuesWithAllowlist(params, connectionParamsFromURI(config.URI, allowedSchemes...), canonicalNames)
mergeConnectionParamValuesWithAllowlist(params, connectionParamsFromText(config.ConnectionParams), canonicalNames)
}
func mergeConnectionParamsIntoRawURI(raw string, connectionParams string, allowedSchemes ...string) string {
text := strings.TrimSpace(raw)
if text == "" {

View File

@@ -0,0 +1,94 @@
package db
import (
"net/url"
"testing"
"GoNavi-Wails/internal/connection"
)
func TestMergeConnectionParamsFromConfigWithAllowlistCanonicalizesAndFilters(t *testing.T) {
t.Parallel()
params := url.Values{}
cfg := connection.ConnectionConfig{
URI: "postgres://u:p@db.local/app?application_name=from-uri&unknown_uri=bad",
ConnectionParams: "Application_Name=from-config&statement_timeout=3000&timezone=Asia%2FShanghai&unknown_config=bad",
}
mergeConnectionParamsFromConfigWithAllowlist(params, cfg, postgresConnectionParamNames, "postgres")
if got := params.Get("application_name"); got != "from-config" {
t.Fatalf("application_name = %q, want from-config", got)
}
if got := params.Get("statement_timeout"); got != "3000" {
t.Fatalf("statement_timeout = %q, want 3000", got)
}
if got := params.Get("TimeZone"); got != "Asia/Shanghai" {
t.Fatalf("TimeZone = %q, want Asia/Shanghai", got)
}
if got := params.Get("unknown_uri"); got != "" {
t.Fatalf("unknown_uri should be filtered, got %q", got)
}
if got := params.Get("unknown_config"); got != "" {
t.Fatalf("unknown_config should be filtered, got %q", got)
}
}
func TestSQLServerConnectionParamAllowlistMapsADOSynonyms(t *testing.T) {
t.Parallel()
params := url.Values{}
mergeConnectionParamValuesWithAllowlist(params, url.Values{
"Application Name": []string{"GoNavi"},
"Initial Catalog": []string{"appdb"},
"UID": []string{"sa"},
"Trust Server Certificate": []string{"true"},
"Column Encryption Setting": []string{"Enabled"},
"ignored": []string{"bad"},
}, sqlServerConnectionParamNames)
if got := params.Get("app name"); got != "GoNavi" {
t.Fatalf("app name = %q, want GoNavi", got)
}
if got := params.Get("database"); got != "appdb" {
t.Fatalf("database = %q, want appdb", got)
}
if got := params.Get("user id"); got != "sa" {
t.Fatalf("user id = %q, want sa", got)
}
if got := params.Get("trustservercertificate"); got != "true" {
t.Fatalf("trustservercertificate = %q, want true", got)
}
if got := params.Get("columnencryption"); got != "Enabled" {
t.Fatalf("columnencryption = %q, want Enabled", got)
}
if got := params.Get("ignored"); got != "" {
t.Fatalf("ignored should be filtered, got %q", got)
}
}
func TestDamengConnectionParamAllowlistMapsUppercaseAliases(t *testing.T) {
t.Parallel()
params := url.Values{}
mergeConnectionParamValuesWithAllowlist(params, url.Values{
"SSL_CERT_PATH": []string{"/cert.pem"},
"SSL_KEY_PATH": []string{"/key.pem"},
"CONNECT_TIMEOUT": []string{"5000"},
"unknown": []string{"bad"},
}, damengConnectionParamNames)
if got := params.Get("sslCertPath"); got != "/cert.pem" {
t.Fatalf("sslCertPath = %q, want /cert.pem", got)
}
if got := params.Get("sslKeyPath"); got != "/key.pem" {
t.Fatalf("sslKeyPath = %q, want /key.pem", got)
}
if got := params.Get("connectTimeout"); got != "5000" {
t.Fatalf("connectTimeout = %q, want 5000", got)
}
if got := params.Get("unknown"); got != "" {
t.Fatalf("unknown should be filtered, got %q", got)
}
}

View File

@@ -37,13 +37,13 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
}
if config.UseSSL {
if certPath := strings.TrimSpace(config.SSLCertPath); certPath != "" {
q.Set("SSL_CERT_PATH", certPath)
q.Set("sslCertPath", certPath)
}
if keyPath := strings.TrimSpace(config.SSLKeyPath); keyPath != "" {
q.Set("SSL_KEY_PATH", keyPath)
q.Set("sslKeyPath", keyPath)
}
}
mergeConnectionParamsFromConfig(q, config, "dm", "dameng")
mergeConnectionParamsFromConfigWithAllowlist(q, config, damengConnectionParamNames, "dm", "dameng")
// 当前达梦 Go 驱动使用字符串切分解析 DSN认证信息不会做 URL 反解码。
// 密码保持原样传入,避免 p%40ss 这类转义文本被当作真实密码登录。

View File

@@ -64,7 +64,7 @@ func TestPostgresDSN_MergesConnectionParams(t *testing.T) {
User: "user",
Password: "pass",
Database: "db",
ConnectionParams: "application_name=GoNavi&connect_timeout=9",
ConnectionParams: "application_name=GoNavi&connect_timeout=9&statement_timeout=3000&allowPublicKeyRetrieval=true",
}
dsn := p.getDSN(cfg)
@@ -79,6 +79,12 @@ func TestPostgresDSN_MergesConnectionParams(t *testing.T) {
if got := query.Get("connect_timeout"); got != "9" {
t.Fatalf("connect_timeout = %q, want 9", got)
}
if got := query.Get("statement_timeout"); got != "3000" {
t.Fatalf("statement_timeout = %q, want 3000", got)
}
if got := query.Get("allowPublicKeyRetrieval"); got != "" {
t.Fatalf("unsupported postgres param should be filtered, got %q", got)
}
}
func TestMySQLDSN_UsesTLSParamWhenSSLEnabled(t *testing.T) {
@@ -184,11 +190,35 @@ func TestDamengDSN_AppendsSSLCertAndKeyParams(t *testing.T) {
}
dsn := d.getDSN(cfg)
if !strings.Contains(dsn, "SSL_CERT_PATH=") {
t.Fatalf("dsn 缺少 SSL_CERT_PATH 参数:%s", dsn)
if !strings.Contains(dsn, "sslCertPath=") {
t.Fatalf("dsn 缺少 sslCertPath 参数:%s", dsn)
}
if !strings.Contains(dsn, "SSL_KEY_PATH=") {
t.Fatalf("dsn 缺少 SSL_KEY_PATH 参数:%s", dsn)
if !strings.Contains(dsn, "sslKeyPath=") {
t.Fatalf("dsn 缺少 sslKeyPath 参数:%s", dsn)
}
}
func TestDamengDSN_FiltersUnsupportedConnectionParams(t *testing.T) {
d := &DamengDB{}
cfg := connection.ConnectionConfig{
Type: "dameng",
Host: "127.0.0.1",
Port: 5236,
User: "SYSDBA",
Password: "pass",
Database: "DBName",
ConnectionParams: "SSL_CERT_PATH=/cert.pem&CONNECT_TIMEOUT=5000&unknown=bad",
}
dsn := d.getDSN(cfg)
if !strings.Contains(dsn, "sslCertPath=%2Fcert.pem") {
t.Fatalf("dsn 缺少规范化 sslCertPath 参数:%s", dsn)
}
if !strings.Contains(dsn, "connectTimeout=5000") {
t.Fatalf("dsn 缺少规范化 connectTimeout 参数:%s", dsn)
}
if strings.Contains(dsn, "SSL_CERT_PATH") || strings.Contains(dsn, "unknown=bad") {
t.Fatalf("dsn 不应透传达梦未知或非规范参数:%s", dsn)
}
}
@@ -218,7 +248,7 @@ func TestKingbaseDSN_MergesConnectionParams(t *testing.T) {
User: "system",
Password: "pass",
Database: "TEST",
ConnectionParams: "application_name=GoNavi&connect_timeout=12",
ConnectionParams: "application_name=GoNavi&connect_timeout=12&statement_timeout=3000&unknown=bad",
}
dsn := k.getDSN(cfg)
@@ -228,6 +258,12 @@ func TestKingbaseDSN_MergesConnectionParams(t *testing.T) {
if !strings.Contains(dsn, "connect_timeout=12") {
t.Fatalf("dsn 缺少自定义 connect_timeout%s", dsn)
}
if !strings.Contains(dsn, "statement_timeout=3000") {
t.Fatalf("dsn 缺少允许的 runtime 参数:%s", dsn)
}
if strings.Contains(dsn, "unknown=bad") {
t.Fatalf("dsn 不应透传未知 Kingbase 参数:%s", dsn)
}
}
func TestTDengineDSN_UsesWebSocketFormat(t *testing.T) {
@@ -275,13 +311,19 @@ func TestTDengineDSN_MergesConnectionParams(t *testing.T) {
User: "root",
Password: "taosdata",
Database: "power",
ConnectionParams: "timezone=Asia%2FShanghai&protocol=wss",
ConnectionParams: "timezone=Asia%2FShanghai&protocol=wss&readTimeout=10s&unknown=bad",
}
dsn := td.getDSN(cfg)
if !strings.Contains(dsn, "?timezone=Asia%2FShanghai") {
if !strings.Contains(dsn, "timezone=Asia%2FShanghai") {
t.Fatalf("tdengine dsn 缺少自定义参数或错误透传 protocol%s", dsn)
}
if !strings.Contains(dsn, "readTimeout=10s") {
t.Fatalf("tdengine dsn 缺少 readTimeout 参数:%s", dsn)
}
if strings.Contains(dsn, "protocol=wss") || strings.Contains(dsn, "unknown=bad") {
t.Fatalf("tdengine dsn 不应透传协议控制项或未知参数:%s", dsn)
}
}
func TestSQLServerDSN_EncryptMapping(t *testing.T) {
@@ -315,7 +357,7 @@ func TestSQLServerDSN_MergesConnectionParams(t *testing.T) {
User: "sa",
Password: "pass",
Database: "master",
ConnectionParams: "app name=GoNavi&packet size=32767",
ConnectionParams: "Application Name=GoNavi&Initial Catalog=appdb&packet size=32767&unknown=bad",
}
dsn := s.getDSN(cfg)
@@ -327,9 +369,15 @@ func TestSQLServerDSN_MergesConnectionParams(t *testing.T) {
if got := query.Get("app name"); got != "GoNavi" {
t.Fatalf("app name = %q, want GoNavi", got)
}
if got := query.Get("database"); got != "appdb" {
t.Fatalf("database = %q, want appdb", got)
}
if got := query.Get("packet size"); got != "32767" {
t.Fatalf("packet size = %q, want 32767", got)
}
if got := query.Get("unknown"); got != "" {
t.Fatalf("unknown should be filtered, got %q", got)
}
}
func TestClickHouseOptions_UsesStructuredTimeoutAndAuth(t *testing.T) {

View File

@@ -44,7 +44,7 @@ func (h *HighGoDB) getDSN(config connection.ConnectionConfig) string {
q := url.Values{}
q.Set("sslmode", resolvePostgresSSLMode(config))
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
mergeConnectionParamsFromConfig(q, config, "postgres", "postgresql", "highgo")
mergeConnectionParamsFromConfigWithAllowlist(q, config, highGoConnectionParamNames, "postgres", "postgresql", "highgo")
u.RawQuery = q.Encode()
return u.String()

View File

@@ -71,7 +71,7 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
params.Set("dbname", config.Database)
params.Set("sslmode", resolvePostgresSSLMode(config))
params.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
mergeConnectionParamsFromConfig(params, config, "kingbase")
mergeConnectionParamsFromConfigWithAllowlist(params, config, kingbaseConnectionParamNames, "kingbase")
preferred := []string{"host", "port", "user", "password", "dbname", "sslmode", "connect_timeout"}
seen := make(map[string]struct{}, len(params))

View File

@@ -64,7 +64,7 @@ func TestMySQLDSN_MapsCommonJDBCParamsWithoutLeakingUnsupportedKeys(t *testing.T
User: "root",
Database: "app",
ConnectionParams: "useUnicode=true&characterEncoding=utf8&autoReconnect=true&" +
"useSSL=false&verifyServerCertificate=false&useOldAliasMetadataBehavior=true",
"allowPublicKeyRetrieval=true&useSSL=false&verifyServerCertificate=false&useOldAliasMetadataBehavior=true",
})
if err != nil {
t.Fatalf("getDSN failed: %v", err)
@@ -81,6 +81,7 @@ func TestMySQLDSN_MapsCommonJDBCParamsWithoutLeakingUnsupportedKeys(t *testing.T
"useUnicode",
"characterEncoding",
"autoReconnect",
"allowPublicKeyRetrieval",
"useSSL",
"verifyServerCertificate",
"useOldAliasMetadataBehavior",
@@ -122,6 +123,153 @@ func TestMySQLDSN_MapsJDBCUTF8EncodingToMySQLCharset(t *testing.T) {
}
}
func TestMySQLDSN_DropsJDBCAllowPublicKeyRetrievalParam(t *testing.T) {
t.Parallel()
m := &MySQLDB{}
dsn, err := m.getDSN(connection.ConnectionConfig{
Host: "db.local",
Port: 3306,
User: "root",
Database: "app",
URI: "jdbc:mysql://db.local:3306/app?allowPublicKeyRetrieval=true&useSSL=false",
ConnectionParams: "allowPublicKeyRetrieval=true&readtimeout=10&writetimeout=11",
})
if err != nil {
t.Fatalf("getDSN failed: %v", err)
}
query := parseMySQLDSNQueryForTest(t, dsn)
if _, exists := query["allowPublicKeyRetrieval"]; exists {
t.Fatalf("JDBC allowPublicKeyRetrieval should not be passed to Go MySQL driver: %v", query)
}
if got := query.Get("tls"); got != "false" {
t.Fatalf("useSSL=false should still map to tls=false, got=%q", got)
}
if got := query.Get("readTimeout"); got != "10s" {
t.Fatalf("readtimeout should canonicalize to readTimeout duration, got=%q", got)
}
if got := query.Get("writeTimeout"); got != "11s" {
t.Fatalf("writetimeout should canonicalize to writeTimeout duration, got=%q", got)
}
}
func TestMySQLDSN_PreservesSupportedGoDriverParamsAndDropsUnknownParams(t *testing.T) {
t.Parallel()
m := &MySQLDB{}
dsn, err := m.getDSN(connection.ConnectionConfig{
Host: "db.local",
Port: 3306,
User: "root",
Database: "app",
ConnectionParams: strings.Join([]string{
"allowAllFiles=true",
"allowCleartextPasswords=true",
"allowFallbackToPlaintext=true",
"allowNativePasswords=false",
"allowOldPasswords=true",
"checkConnLiveness=false",
"clientFoundRows=true",
"charset=latin1",
"collation=utf8mb4_unicode_ci",
"columnsWithAlias=true",
"compress=true",
"connectionAttributes=program_name:GoNavi",
"interpolateParams=true",
"loc=UTC",
"maxAllowedPacket=1048576",
"multiStatements=false",
"parseTime=false",
"readtimeout=7",
"rejectReadOnly=true",
"serverPubKey=testKey",
"timeTruncate=2",
"timeout=8",
"tls=preferred",
"writetimeout=9",
"strict=true",
"unsupportedJdbcParam=true",
}, "&"),
})
if err != nil {
t.Fatalf("getDSN failed: %v", err)
}
query := parseMySQLDSNQueryForTest(t, dsn)
want := map[string]string{
"allowAllFiles": "true",
"allowCleartextPasswords": "true",
"allowFallbackToPlaintext": "true",
"allowNativePasswords": "false",
"allowOldPasswords": "true",
"checkConnLiveness": "false",
"clientFoundRows": "true",
"charset": "latin1",
"collation": "utf8mb4_unicode_ci",
"columnsWithAlias": "true",
"compress": "true",
"connectionAttributes": "program_name:GoNavi",
"interpolateParams": "true",
"loc": "UTC",
"maxAllowedPacket": "1048576",
"multiStatements": "false",
"parseTime": "false",
"readTimeout": "7s",
"rejectReadOnly": "true",
"serverPubKey": "testKey",
"timeTruncate": "2s",
"timeout": "8s",
"tls": "preferred",
"writeTimeout": "9s",
}
for key, value := range want {
if got := query.Get(key); got != value {
t.Fatalf("%s should be %q, got %q; query=%v", key, value, got, query)
}
}
for _, forbidden := range []string{"strict", "unsupportedJdbcParam"} {
if _, exists := query[forbidden]; exists {
t.Fatalf("unsupported parameter %s should not be passed to Go MySQL driver: %v", forbidden, query)
}
}
}
func TestMySQLDSN_MapsAdditionalJDBCAliases(t *testing.T) {
t.Parallel()
m := &MySQLDB{}
dsn, err := m.getDSN(connection.ConnectionConfig{
Host: "db.local",
Port: 3306,
User: "root",
Database: "app",
ConnectionParams: strings.Join([]string{
"sslMode=required",
"allowMultiQueries=false",
"useCompression=true",
"connectionCollation=utf8mb4_bin",
}, "&"),
})
if err != nil {
t.Fatalf("getDSN failed: %v", err)
}
query := parseMySQLDSNQueryForTest(t, dsn)
if got := query.Get("tls"); got != "true" {
t.Fatalf("sslMode=required should map to tls=true, got=%q", got)
}
if got := query.Get("multiStatements"); got != "false" {
t.Fatalf("allowMultiQueries=false should map to multiStatements=false, got=%q", got)
}
if got := query.Get("compress"); got != "true" {
t.Fatalf("useCompression=true should map to compress=true, got=%q", got)
}
if got := query.Get("collation"); got != "utf8mb4_bin" {
t.Fatalf("connectionCollation should map to collation, got=%q", got)
}
}
func TestMySQLDSN_AsiaShanghaiLocationAcceptedByDriver(t *testing.T) {
t.Parallel()

View File

@@ -99,6 +99,73 @@ func normalizeMySQLServerTimezoneParam(raw string) (string, bool) {
return "", false
}
var mysqlSupportedDriverParamNames = map[string]string{
"allowallfiles": "allowAllFiles",
"allowcleartextpasswords": "allowCleartextPasswords",
"allowfallbacktoplaintext": "allowFallbackToPlaintext",
"allownativepasswords": "allowNativePasswords",
"allowoldpasswords": "allowOldPasswords",
"checkconnliveness": "checkConnLiveness",
"clientfoundrows": "clientFoundRows",
"charset": "charset",
"collation": "collation",
"columnswithalias": "columnsWithAlias",
"compress": "compress",
"connectionattributes": "connectionAttributes",
"interpolateparams": "interpolateParams",
"loc": "loc",
"maxallowedpacket": "maxAllowedPacket",
"multistatements": "multiStatements",
"parsetime": "parseTime",
"readtimeout": "readTimeout",
"rejectreadonly": "rejectReadOnly",
"serverpubkey": "serverPubKey",
"timetruncate": "timeTruncate",
"timeout": "timeout",
"tls": "tls",
"writetimeout": "writeTimeout",
}
var mysqlBoolDriverParamNames = map[string]struct{}{
"allowAllFiles": {},
"allowCleartextPasswords": {},
"allowFallbackToPlaintext": {},
"allowNativePasswords": {},
"allowOldPasswords": {},
"checkConnLiveness": {},
"clientFoundRows": {},
"columnsWithAlias": {},
"compress": {},
"interpolateParams": {},
"multiStatements": {},
"parseTime": {},
"rejectReadOnly": {},
}
func canonicalMySQLDriverParamName(name string) (string, bool) {
canonical, ok := mysqlSupportedDriverParamNames[strings.ToLower(strings.TrimSpace(name))]
return canonical, ok
}
func setMySQLDriverParam(params url.Values, name string, value string) {
switch name {
case "charset":
if charset := normalizeMySQLCharsetParam(value); charset != "" {
params.Set("charset", charset)
}
case "timeout", "readTimeout", "writeTimeout", "timeTruncate":
params.Set(name, normalizeMySQLDurationParam(value, time.Second))
default:
if _, ok := mysqlBoolDriverParamNames[name]; ok {
if enabled, ok := parseMySQLBoolParam(value); ok {
params.Set(name, strconv.FormatBool(enabled))
return
}
}
params.Set(name, value)
}
}
func mergeMySQLConnectionParam(params url.Values, key string, value string) {
name := strings.TrimSpace(key)
if name == "" {
@@ -108,12 +175,7 @@ func mergeMySQLConnectionParam(params url.Values, key string, value string) {
switch lowerName {
case "topology":
return
case "useunicode", "autoreconnect", "useoldaliasmetadatabehavior":
return
case "charset":
if charset := normalizeMySQLCharsetParam(value); charset != "" {
params.Set("charset", charset)
}
case "useunicode", "autoreconnect", "useoldaliasmetadatabehavior", "allowpublickeyretrieval":
return
case "characterencoding":
if charset := normalizeMySQLCharsetParam(value); charset != "" {
@@ -144,17 +206,41 @@ func mergeMySQLConnectionParam(params url.Values, key string, value string) {
params.Set("tls", "skip-verify")
}
return
case "sslmode":
switch normalizeSSLModeValue(value) {
case sslModeDisable:
params.Set("tls", "false")
case sslModeRequired:
params.Set("tls", "true")
case sslModeSkipVerify:
params.Set("tls", "skip-verify")
default:
params.Set("tls", "preferred")
}
return
case "connecttimeout":
params.Set("timeout", normalizeMySQLDurationParam(value, time.Millisecond))
return
case "sockettimeout":
params.Set("readTimeout", normalizeMySQLDurationParam(value, time.Millisecond))
return
case "timeout", "readtimeout", "writetimeout":
params.Set(name, normalizeMySQLDurationParam(value, time.Second))
case "allowmultiqueries":
if enabled, ok := parseMySQLBoolParam(value); ok {
params.Set("multiStatements", strconv.FormatBool(enabled))
}
return
case "usecompression":
if enabled, ok := parseMySQLBoolParam(value); ok {
params.Set("compress", strconv.FormatBool(enabled))
}
return
case "connectioncollation":
params.Set("collation", value)
return
default:
params.Set(name, value)
if canonical, ok := canonicalMySQLDriverParamName(name); ok {
setMySQLDriverParam(params, canonical, value)
}
}
}

View File

@@ -40,7 +40,7 @@ func TestOracleGetDSNMergesConnectionParams(t *testing.T) {
User: "scott",
Password: "tiger",
Database: "ORCLPDB1",
ConnectionParams: "PREFETCH_ROWS=5000&TRACE FILE=/tmp/go-ora.trc",
ConnectionParams: "PREFETCH_ROWS=5000&TRACE FILE=/tmp/go-ora.trc&connect_timeout=10&FAILOVER=3&unknown=bad",
})
parsed, err := url.Parse(dsn)
@@ -54,4 +54,13 @@ func TestOracleGetDSNMergesConnectionParams(t *testing.T) {
if got := query.Get("TRACE FILE"); got != "/tmp/go-ora.trc" {
t.Fatalf("TRACE FILE = %q, want /tmp/go-ora.trc", got)
}
if got := query.Get("CONNECT TIMEOUT"); got != "10" {
t.Fatalf("CONNECT TIMEOUT = %q, want 10", got)
}
if got := query.Get("FAILOVER"); got != "" {
t.Fatalf("FAILOVER should be filtered because go-ora no longer supports it, got %q", got)
}
if got := query.Get("unknown"); got != "" {
t.Fatalf("unknown should be filtered, got %q", got)
}
}

View File

@@ -48,7 +48,7 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
q.Set("PREFETCH_ROWS", "10000")
// LOB 数据延迟加载,避免大 LOB 列影响普通查询性能
q.Set("LOB FETCH", "POST")
mergeConnectionParamsFromConfig(q, config, "oracle")
mergeConnectionParamsFromConfigWithAllowlist(q, config, oracleConnectionParamNames, "oracle")
if encoded := q.Encode(); encoded != "" {
u.RawQuery = encoded
}

View File

@@ -64,7 +64,7 @@ func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
q := url.Values{}
q.Set("sslmode", resolvePostgresSSLMode(config))
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
mergeConnectionParamsFromConfig(q, config, "postgres", "postgresql", "opengauss")
mergeConnectionParamsFromConfigWithAllowlist(q, config, postgresConnectionParamNames, "postgres", "postgresql", "opengauss")
u.RawQuery = q.Encode()
return u.String()

View File

@@ -49,8 +49,8 @@ func (s *SqlServerDB) getDSN(config connection.ConnectionConfig) string {
q.Set("connection timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
encrypt, trustServerCertificate := resolveSQLServerTLSSettings(config)
q.Set("encrypt", encrypt)
q.Set("TrustServerCertificate", trustServerCertificate)
mergeConnectionParamsFromConfig(q, config, "sqlserver")
q.Set("trustservercertificate", trustServerCertificate)
mergeConnectionParamsFromConfigWithAllowlist(q, config, sqlServerConnectionParamNames, "sqlserver")
u.RawQuery = q.Encode()
return u.String()

View File

@@ -44,9 +44,7 @@ func (t *TDengineDB) getDSN(config connection.ConnectionConfig) string {
netType := resolveTDengineNet(config)
params := url.Values{}
mergeConnectionParamsFromConfig(params, config, "taos", "taosws", "tdengine")
params.Del("protocol")
params.Del("skip_verify")
mergeConnectionParamsFromConfigWithAllowlist(params, config, tdengineConnectionParamNames, "taos", "taosws", "tdengine")
query := params.Encode()
dsn := fmt.Sprintf("%s:%s@%s(%s)%s", user, pass, netType, net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), path)
if query == "" {

View File

@@ -43,7 +43,7 @@ func (v *VastbaseDB) getDSN(config connection.ConnectionConfig) string {
q := url.Values{}
q.Set("sslmode", resolvePostgresSSLMode(config))
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
mergeConnectionParamsFromConfig(q, config, "postgres", "postgresql", "vastbase")
mergeConnectionParamsFromConfigWithAllowlist(q, config, postgresConnectionParamNames, "postgres", "postgresql", "vastbase")
u.RawQuery = q.Encode()
return u.String()