🐛 fix(mysql): 修复 GDB 连接参数不兼容导致的握手失败

- 优化 MySQL 兼容 DSN 默认参数
- 在连接验证阶段增加 multiStatements 兼容回退
- 补充相关单元测试覆盖

Refs #543
This commit is contained in:
Syngnat
2026-06-07 14:50:42 +08:00
parent 6932abe674
commit dda8bbb6e3
2 changed files with 226 additions and 22 deletions

View File

@@ -54,6 +54,33 @@ func TestMySQLDSN_MergesConnectionParamsWithDefaults(t *testing.T) {
}
}
func TestMySQLDSN_EmptyDatabaseUsesCompatibilityDefaults(t *testing.T) {
t.Parallel()
m := &MySQLDB{}
dsn, err := m.getDSN(connection.ConnectionConfig{
Host: "gdb.local",
Port: 1523,
User: "glzc",
Password: "secret",
Timeout: 30,
})
if err != nil {
t.Fatalf("getDSN failed: %v", err)
}
if !strings.Contains(dsn, "@tcp(gdb.local:1523)/?") {
t.Fatalf("empty database should still keep DSN slash separator, got=%q", dsn)
}
query := parseMySQLDSNQueryForTest(t, dsn)
if got := query.Get("charset"); got != "utf8mb4,utf8" {
t.Fatalf("default charset should fall back from utf8mb4 to utf8, got=%q", got)
}
if got := query.Get("multiStatements"); got != "true" {
t.Fatalf("default multiStatements should remain enabled, got=%q", got)
}
}
func TestMySQLDSN_MapsCommonJDBCParamsWithoutLeakingUnsupportedKeys(t *testing.T) {
t.Parallel()
@@ -294,6 +321,61 @@ func TestMySQLDSN_MapsAllowMultiQueriesTrueWithoutLeakingKey(t *testing.T) {
}
}
func TestBuildMySQLCompatibleConnectPlans_AddsHandshakeFallbackWhenMultiStatementsImplicit(t *testing.T) {
t.Parallel()
plans, err := buildMySQLCompatibleConnectPlans(connection.ConnectionConfig{
Host: "gdb.local",
Port: 1523,
User: "glzc",
Timeout: 30,
}, "tcp", "gdb.local:1523", "")
if err != nil {
t.Fatalf("buildMySQLCompatibleConnectPlans failed: %v", err)
}
if len(plans) != 2 {
t.Fatalf("expected default plan plus compatibility fallback, got %d", len(plans))
}
defaultQuery := parseMySQLDSNQueryForTest(t, plans[0].dsn)
if got := defaultQuery.Get("multiStatements"); got != "true" {
t.Fatalf("default plan should keep multiStatements enabled, got=%q", got)
}
if got := defaultQuery.Get("charset"); got != "utf8mb4,utf8" {
t.Fatalf("default plan should use utf8 fallback charset, got=%q", got)
}
fallbackQuery := parseMySQLDSNQueryForTest(t, plans[1].dsn)
if got := fallbackQuery.Get("multiStatements"); got != "false" {
t.Fatalf("fallback plan should disable multiStatements, got=%q", got)
}
if got := fallbackQuery.Get("charset"); got != "utf8mb4,utf8" {
t.Fatalf("fallback plan should preserve charset fallback, got=%q", got)
}
}
func TestBuildMySQLCompatibleConnectPlans_RespectsExplicitMultiStatementsChoice(t *testing.T) {
t.Parallel()
plans, err := buildMySQLCompatibleConnectPlans(connection.ConnectionConfig{
Host: "db.local",
Port: 3306,
User: "root",
ConnectionParams: "allowMultiQueries=false",
}, "tcp", "db.local:3306", "app")
if err != nil {
t.Fatalf("buildMySQLCompatibleConnectPlans failed: %v", err)
}
if len(plans) != 1 {
t.Fatalf("explicit multiStatements choice should skip compatibility fallback, got %d plans", len(plans))
}
query := parseMySQLDSNQueryForTest(t, plans[0].dsn)
if got := query.Get("multiStatements"); got != "false" {
t.Fatalf("explicit allowMultiQueries=false should be preserved, got=%q", got)
}
}
func TestMySQLDSN_AsiaShanghaiLocationAcceptedByDriver(t *testing.T) {
t.Parallel()

View File

@@ -278,6 +278,53 @@ func mergeMySQLConnectionParams(params url.Values, values url.Values) {
}
}
type mySQLCompatibleDSNOptions struct {
defaultCharset string
defaultMultiStatements *bool
}
type mySQLCompatibleConnectPlan struct {
label string
dsn string
}
const (
mySQLCompatPlanDefaultLabel = "默认兼容参数"
mySQLCompatPlanDisableMultiStatementsLabel = "禁用 multiStatements 兼容重试"
)
func hasMySQLConnectionParam(config connection.ConnectionConfig, names ...string) bool {
if len(names) == 0 {
return false
}
targets := make(map[string]struct{}, len(names))
for _, name := range names {
normalized := strings.ToLower(strings.TrimSpace(name))
if normalized == "" {
continue
}
targets[normalized] = struct{}{}
}
if len(targets) == 0 {
return false
}
hasMatchingKey := func(values url.Values) bool {
for key := range values {
if _, ok := targets[strings.ToLower(strings.TrimSpace(key))]; ok {
return true
}
}
return false
}
if parsed, ok := parseMySQLCompatibleURI(config.URI, "mysql", "mariadb", "doris", "diros", "oceanbase", "starrocks"); ok && hasMatchingKey(parsed.Query()) {
return true
}
return hasMatchingKey(mysqlConnectionParamsFromText(config.ConnectionParams))
}
func resolveMySQLTLSParam(config connection.ConnectionConfig) (string, bool, error) {
mode := resolveMySQLTLSMode(config)
if mode == "false" || !hasTLSCertificatePaths(config) {
@@ -297,14 +344,18 @@ func resolveMySQLTLSParam(config connection.ConnectionConfig) (string, bool, err
return name, normalizeSSLModeValue(config.SSLMode) == sslModePreferred, nil
}
func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, address, database string) (string, error) {
func buildMySQLCompatibleDSNWithOptions(config connection.ConnectionConfig, protocol, address, database string, options mySQLCompatibleDSNOptions) (string, error) {
timeout := getConnectTimeoutSeconds(config)
tlsMode, allowFallbackToPlaintext, err := resolveMySQLTLSParam(config)
if err != nil {
return "", err
}
params := url.Values{}
params.Set("charset", "utf8mb4")
defaultCharset := strings.TrimSpace(options.defaultCharset)
if defaultCharset == "" {
defaultCharset = "utf8mb4,utf8"
}
params.Set("charset", defaultCharset)
params.Set("parseTime", "True")
params.Set("loc", "Local")
params.Set("timeout", fmt.Sprintf("%ds", timeout))
@@ -312,7 +363,11 @@ func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, addre
if allowFallbackToPlaintext {
params.Set("allowFallbackToPlaintext", "true")
}
params.Set("multiStatements", "true")
defaultMultiStatements := true
if options.defaultMultiStatements != nil {
defaultMultiStatements = *options.defaultMultiStatements
}
params.Set("multiStatements", strconv.FormatBool(defaultMultiStatements))
if parsed, ok := parseMySQLCompatibleURI(config.URI, "mysql", "doris", "diros", "oceanbase"); ok {
mergeMySQLConnectionParams(params, parsed.Query())
}
@@ -323,6 +378,46 @@ func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, addre
), nil
}
func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, address, database string) (string, error) {
defaultMultiStatements := true
return buildMySQLCompatibleDSNWithOptions(config, protocol, address, database, mySQLCompatibleDSNOptions{
defaultCharset: "utf8mb4,utf8",
defaultMultiStatements: &defaultMultiStatements,
})
}
func buildMySQLCompatibleConnectPlans(config connection.ConnectionConfig, protocol, address, database string) ([]mySQLCompatibleConnectPlan, error) {
defaultDSN, err := buildMySQLCompatibleDSN(config, protocol, address, database)
if err != nil {
return nil, err
}
plans := []mySQLCompatibleConnectPlan{{
label: mySQLCompatPlanDefaultLabel,
dsn: defaultDSN,
}}
if hasMySQLConnectionParam(config, "multiStatements", "allowMultiQueries") {
return plans, nil
}
disabled := false
fallbackDSN, err := buildMySQLCompatibleDSNWithOptions(config, protocol, address, database, mySQLCompatibleDSNOptions{
defaultCharset: "utf8mb4,utf8",
defaultMultiStatements: &disabled,
})
if err != nil {
return nil, err
}
if fallbackDSN == defaultDSN {
return plans, nil
}
return append(plans, mySQLCompatibleConnectPlan{
label: mySQLCompatPlanDisableMultiStatementsLabel,
dsn: fallbackDSN,
}), nil
}
func normalizeMySQLRawDSNCompatibilityParams(raw string) string {
text := strings.TrimSpace(raw)
queryIndex := strings.Index(text, "?")
@@ -582,20 +677,27 @@ func collectMySQLAddresses(config connection.ConnectionConfig) []string {
return result
}
func (m *MySQLDB) getDSN(config connection.ConnectionConfig) (string, error) {
database := config.Database
func (m *MySQLDB) resolveProtocolAndAddress(config connection.ConnectionConfig) (string, string, error) {
protocol := "tcp"
address := normalizeMySQLAddress(config.Host, config.Port)
if config.UseSSH {
netName, err := ssh.RegisterSSHNetwork(config.SSH)
if err != nil {
return "", fmt.Errorf("创建 SSH 隧道失败:%w", err)
return "", "", fmt.Errorf("创建 SSH 隧道失败:%w", err)
}
protocol = netName
}
return buildMySQLCompatibleDSN(config, protocol, address, database)
return protocol, address, nil
}
func (m *MySQLDB) getDSN(config connection.ConnectionConfig) (string, error) {
protocol, address, err := m.resolveProtocolAndAddress(config)
if err != nil {
return "", err
}
return buildMySQLCompatibleDSN(config, protocol, address, config.Database)
}
func resolveMySQLCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {
@@ -633,30 +735,50 @@ func (m *MySQLDB) Connect(config connection.ConnectionConfig) error {
candidateConfig.Port = port
candidateConfig.User, candidateConfig.Password = resolveMySQLCredential(runConfig, index)
dsn, err := m.getDSN(candidateConfig)
protocol, address, err := m.resolveProtocolAndAddress(candidateConfig)
if err != nil {
errorDetails = append(errorDetails, fmt.Sprintf("%s 生成连接串失败: %v", address, err))
continue
}
db, err := sql.Open("mysql", dsn)
plans, err := buildMySQLCompatibleConnectPlans(candidateConfig, protocol, address, candidateConfig.Database)
if err != nil {
errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败: %v", address, err))
errorDetails = append(errorDetails, fmt.Sprintf("%s 生成连接串失败: %v", address, err))
continue
}
timeout := getConnectTimeout(candidateConfig)
ctx, cancel := utils.ContextWithTimeout(timeout)
pingErr := db.PingContext(ctx)
cancel()
if pingErr != nil {
_ = db.Close()
errorDetails = append(errorDetails, fmt.Sprintf("%s 验证失败: %v", address, pingErr))
continue
}
for _, plan := range plans {
db, err := sql.Open("mysql", plan.dsn)
if err != nil {
if len(plans) > 1 || plan.label != mySQLCompatPlanDefaultLabel {
errorDetails = append(errorDetails, fmt.Sprintf("%s [%s] 打开失败: %v", address, plan.label, err))
} else {
errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败: %v", address, err))
}
continue
}
m.conn = db
m.pingTimeout = timeout
return nil
timeout := getConnectTimeout(candidateConfig)
ctx, cancel := utils.ContextWithTimeout(timeout)
pingErr := db.PingContext(ctx)
cancel()
if pingErr != nil {
_ = db.Close()
if len(plans) > 1 || plan.label != mySQLCompatPlanDefaultLabel {
errorDetails = append(errorDetails, fmt.Sprintf("%s [%s] 验证失败: %v", address, plan.label, pingErr))
} else {
errorDetails = append(errorDetails, fmt.Sprintf("%s 验证失败: %v", address, pingErr))
}
continue
}
if plan.label != mySQLCompatPlanDefaultLabel {
logger.Warnf("MySQL 兼容回退生效:地址=%s 模式=%s", address, plan.label)
}
m.conn = db
m.pingTimeout = timeout
return nil
}
}
if len(errorDetails) == 0 {