From dda8bbb6e3a361ebc10c931206944a6bb151c6e6 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Sun, 7 Jun 2026 14:50:42 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(mysql):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=20GDB=20=E8=BF=9E=E6=8E=A5=E5=8F=82=E6=95=B0=E4=B8=8D=E5=85=BC?= =?UTF-8?q?=E5=AE=B9=E5=AF=BC=E8=87=B4=E7=9A=84=E6=8F=A1=E6=89=8B=E5=A4=B1?= =?UTF-8?q?=E8=B4=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 优化 MySQL 兼容 DSN 默认参数 - 在连接验证阶段增加 multiStatements 兼容回退 - 补充相关单元测试覆盖 Refs #543 --- internal/db/mysql_connection_params_test.go | 82 ++++++++++ internal/db/mysql_impl.go | 166 +++++++++++++++++--- 2 files changed, 226 insertions(+), 22 deletions(-) diff --git a/internal/db/mysql_connection_params_test.go b/internal/db/mysql_connection_params_test.go index 97fdd97..41453e4 100644 --- a/internal/db/mysql_connection_params_test.go +++ b/internal/db/mysql_connection_params_test.go @@ -54,6 +54,33 @@ func TestMySQLDSN_MergesConnectionParamsWithDefaults(t *testing.T) { } } +func TestMySQLDSN_EmptyDatabaseUsesCompatibilityDefaults(t *testing.T) { + t.Parallel() + + m := &MySQLDB{} + dsn, err := m.getDSN(connection.ConnectionConfig{ + Host: "gdb.local", + Port: 1523, + User: "glzc", + Password: "secret", + Timeout: 30, + }) + if err != nil { + t.Fatalf("getDSN failed: %v", err) + } + if !strings.Contains(dsn, "@tcp(gdb.local:1523)/?") { + t.Fatalf("empty database should still keep DSN slash separator, got=%q", dsn) + } + + query := parseMySQLDSNQueryForTest(t, dsn) + if got := query.Get("charset"); got != "utf8mb4,utf8" { + t.Fatalf("default charset should fall back from utf8mb4 to utf8, got=%q", got) + } + if got := query.Get("multiStatements"); got != "true" { + t.Fatalf("default multiStatements should remain enabled, got=%q", got) + } +} + func TestMySQLDSN_MapsCommonJDBCParamsWithoutLeakingUnsupportedKeys(t *testing.T) { t.Parallel() @@ -294,6 +321,61 @@ func TestMySQLDSN_MapsAllowMultiQueriesTrueWithoutLeakingKey(t *testing.T) { } } +func TestBuildMySQLCompatibleConnectPlans_AddsHandshakeFallbackWhenMultiStatementsImplicit(t *testing.T) { + t.Parallel() + + plans, err := buildMySQLCompatibleConnectPlans(connection.ConnectionConfig{ + Host: "gdb.local", + Port: 1523, + User: "glzc", + Timeout: 30, + }, "tcp", "gdb.local:1523", "") + if err != nil { + t.Fatalf("buildMySQLCompatibleConnectPlans failed: %v", err) + } + if len(plans) != 2 { + t.Fatalf("expected default plan plus compatibility fallback, got %d", len(plans)) + } + + defaultQuery := parseMySQLDSNQueryForTest(t, plans[0].dsn) + if got := defaultQuery.Get("multiStatements"); got != "true" { + t.Fatalf("default plan should keep multiStatements enabled, got=%q", got) + } + if got := defaultQuery.Get("charset"); got != "utf8mb4,utf8" { + t.Fatalf("default plan should use utf8 fallback charset, got=%q", got) + } + + fallbackQuery := parseMySQLDSNQueryForTest(t, plans[1].dsn) + if got := fallbackQuery.Get("multiStatements"); got != "false" { + t.Fatalf("fallback plan should disable multiStatements, got=%q", got) + } + if got := fallbackQuery.Get("charset"); got != "utf8mb4,utf8" { + t.Fatalf("fallback plan should preserve charset fallback, got=%q", got) + } +} + +func TestBuildMySQLCompatibleConnectPlans_RespectsExplicitMultiStatementsChoice(t *testing.T) { + t.Parallel() + + plans, err := buildMySQLCompatibleConnectPlans(connection.ConnectionConfig{ + Host: "db.local", + Port: 3306, + User: "root", + ConnectionParams: "allowMultiQueries=false", + }, "tcp", "db.local:3306", "app") + if err != nil { + t.Fatalf("buildMySQLCompatibleConnectPlans failed: %v", err) + } + if len(plans) != 1 { + t.Fatalf("explicit multiStatements choice should skip compatibility fallback, got %d plans", len(plans)) + } + + query := parseMySQLDSNQueryForTest(t, plans[0].dsn) + if got := query.Get("multiStatements"); got != "false" { + t.Fatalf("explicit allowMultiQueries=false should be preserved, got=%q", got) + } +} + func TestMySQLDSN_AsiaShanghaiLocationAcceptedByDriver(t *testing.T) { t.Parallel() diff --git a/internal/db/mysql_impl.go b/internal/db/mysql_impl.go index ccbca70..d280f41 100644 --- a/internal/db/mysql_impl.go +++ b/internal/db/mysql_impl.go @@ -278,6 +278,53 @@ func mergeMySQLConnectionParams(params url.Values, values url.Values) { } } +type mySQLCompatibleDSNOptions struct { + defaultCharset string + defaultMultiStatements *bool +} + +type mySQLCompatibleConnectPlan struct { + label string + dsn string +} + +const ( + mySQLCompatPlanDefaultLabel = "默认兼容参数" + mySQLCompatPlanDisableMultiStatementsLabel = "禁用 multiStatements 兼容重试" +) + +func hasMySQLConnectionParam(config connection.ConnectionConfig, names ...string) bool { + if len(names) == 0 { + return false + } + + targets := make(map[string]struct{}, len(names)) + for _, name := range names { + normalized := strings.ToLower(strings.TrimSpace(name)) + if normalized == "" { + continue + } + targets[normalized] = struct{}{} + } + if len(targets) == 0 { + return false + } + + hasMatchingKey := func(values url.Values) bool { + for key := range values { + if _, ok := targets[strings.ToLower(strings.TrimSpace(key))]; ok { + return true + } + } + return false + } + + if parsed, ok := parseMySQLCompatibleURI(config.URI, "mysql", "mariadb", "doris", "diros", "oceanbase", "starrocks"); ok && hasMatchingKey(parsed.Query()) { + return true + } + return hasMatchingKey(mysqlConnectionParamsFromText(config.ConnectionParams)) +} + func resolveMySQLTLSParam(config connection.ConnectionConfig) (string, bool, error) { mode := resolveMySQLTLSMode(config) if mode == "false" || !hasTLSCertificatePaths(config) { @@ -297,14 +344,18 @@ func resolveMySQLTLSParam(config connection.ConnectionConfig) (string, bool, err return name, normalizeSSLModeValue(config.SSLMode) == sslModePreferred, nil } -func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, address, database string) (string, error) { +func buildMySQLCompatibleDSNWithOptions(config connection.ConnectionConfig, protocol, address, database string, options mySQLCompatibleDSNOptions) (string, error) { timeout := getConnectTimeoutSeconds(config) tlsMode, allowFallbackToPlaintext, err := resolveMySQLTLSParam(config) if err != nil { return "", err } params := url.Values{} - params.Set("charset", "utf8mb4") + defaultCharset := strings.TrimSpace(options.defaultCharset) + if defaultCharset == "" { + defaultCharset = "utf8mb4,utf8" + } + params.Set("charset", defaultCharset) params.Set("parseTime", "True") params.Set("loc", "Local") params.Set("timeout", fmt.Sprintf("%ds", timeout)) @@ -312,7 +363,11 @@ func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, addre if allowFallbackToPlaintext { params.Set("allowFallbackToPlaintext", "true") } - params.Set("multiStatements", "true") + defaultMultiStatements := true + if options.defaultMultiStatements != nil { + defaultMultiStatements = *options.defaultMultiStatements + } + params.Set("multiStatements", strconv.FormatBool(defaultMultiStatements)) if parsed, ok := parseMySQLCompatibleURI(config.URI, "mysql", "doris", "diros", "oceanbase"); ok { mergeMySQLConnectionParams(params, parsed.Query()) } @@ -323,6 +378,46 @@ func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, addre ), nil } +func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, address, database string) (string, error) { + defaultMultiStatements := true + return buildMySQLCompatibleDSNWithOptions(config, protocol, address, database, mySQLCompatibleDSNOptions{ + defaultCharset: "utf8mb4,utf8", + defaultMultiStatements: &defaultMultiStatements, + }) +} + +func buildMySQLCompatibleConnectPlans(config connection.ConnectionConfig, protocol, address, database string) ([]mySQLCompatibleConnectPlan, error) { + defaultDSN, err := buildMySQLCompatibleDSN(config, protocol, address, database) + if err != nil { + return nil, err + } + plans := []mySQLCompatibleConnectPlan{{ + label: mySQLCompatPlanDefaultLabel, + dsn: defaultDSN, + }} + + if hasMySQLConnectionParam(config, "multiStatements", "allowMultiQueries") { + return plans, nil + } + + disabled := false + fallbackDSN, err := buildMySQLCompatibleDSNWithOptions(config, protocol, address, database, mySQLCompatibleDSNOptions{ + defaultCharset: "utf8mb4,utf8", + defaultMultiStatements: &disabled, + }) + if err != nil { + return nil, err + } + if fallbackDSN == defaultDSN { + return plans, nil + } + + return append(plans, mySQLCompatibleConnectPlan{ + label: mySQLCompatPlanDisableMultiStatementsLabel, + dsn: fallbackDSN, + }), nil +} + func normalizeMySQLRawDSNCompatibilityParams(raw string) string { text := strings.TrimSpace(raw) queryIndex := strings.Index(text, "?") @@ -582,20 +677,27 @@ func collectMySQLAddresses(config connection.ConnectionConfig) []string { return result } -func (m *MySQLDB) getDSN(config connection.ConnectionConfig) (string, error) { - database := config.Database +func (m *MySQLDB) resolveProtocolAndAddress(config connection.ConnectionConfig) (string, string, error) { protocol := "tcp" address := normalizeMySQLAddress(config.Host, config.Port) if config.UseSSH { netName, err := ssh.RegisterSSHNetwork(config.SSH) if err != nil { - return "", fmt.Errorf("创建 SSH 隧道失败:%w", err) + return "", "", fmt.Errorf("创建 SSH 隧道失败:%w", err) } protocol = netName } - return buildMySQLCompatibleDSN(config, protocol, address, database) + return protocol, address, nil +} + +func (m *MySQLDB) getDSN(config connection.ConnectionConfig) (string, error) { + protocol, address, err := m.resolveProtocolAndAddress(config) + if err != nil { + return "", err + } + return buildMySQLCompatibleDSN(config, protocol, address, config.Database) } func resolveMySQLCredential(config connection.ConnectionConfig, addressIndex int) (string, string) { @@ -633,30 +735,50 @@ func (m *MySQLDB) Connect(config connection.ConnectionConfig) error { candidateConfig.Port = port candidateConfig.User, candidateConfig.Password = resolveMySQLCredential(runConfig, index) - dsn, err := m.getDSN(candidateConfig) + protocol, address, err := m.resolveProtocolAndAddress(candidateConfig) if err != nil { errorDetails = append(errorDetails, fmt.Sprintf("%s 生成连接串失败: %v", address, err)) continue } - db, err := sql.Open("mysql", dsn) + plans, err := buildMySQLCompatibleConnectPlans(candidateConfig, protocol, address, candidateConfig.Database) if err != nil { - errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败: %v", address, err)) + errorDetails = append(errorDetails, fmt.Sprintf("%s 生成连接串失败: %v", address, err)) continue } - timeout := getConnectTimeout(candidateConfig) - ctx, cancel := utils.ContextWithTimeout(timeout) - pingErr := db.PingContext(ctx) - cancel() - if pingErr != nil { - _ = db.Close() - errorDetails = append(errorDetails, fmt.Sprintf("%s 验证失败: %v", address, pingErr)) - continue - } + for _, plan := range plans { + db, err := sql.Open("mysql", plan.dsn) + if err != nil { + if len(plans) > 1 || plan.label != mySQLCompatPlanDefaultLabel { + errorDetails = append(errorDetails, fmt.Sprintf("%s [%s] 打开失败: %v", address, plan.label, err)) + } else { + errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败: %v", address, err)) + } + continue + } - m.conn = db - m.pingTimeout = timeout - return nil + timeout := getConnectTimeout(candidateConfig) + ctx, cancel := utils.ContextWithTimeout(timeout) + pingErr := db.PingContext(ctx) + cancel() + if pingErr != nil { + _ = db.Close() + if len(plans) > 1 || plan.label != mySQLCompatPlanDefaultLabel { + errorDetails = append(errorDetails, fmt.Sprintf("%s [%s] 验证失败: %v", address, plan.label, pingErr)) + } else { + errorDetails = append(errorDetails, fmt.Sprintf("%s 验证失败: %v", address, pingErr)) + } + continue + } + + if plan.label != mySQLCompatPlanDefaultLabel { + logger.Warnf("MySQL 兼容回退生效:地址=%s 模式=%s", address, plan.label) + } + + m.conn = db + m.pingTimeout = timeout + return nil + } } if len(errorDetails) == 0 {