mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-14 18:39:54 +08:00
🐛 fix(mysql): 修复 GDB 连接参数不兼容导致的握手失败
- 优化 MySQL 兼容 DSN 默认参数 - 在连接验证阶段增加 multiStatements 兼容回退 - 补充相关单元测试覆盖 Refs #543
This commit is contained in:
@@ -54,6 +54,33 @@ func TestMySQLDSN_MergesConnectionParamsWithDefaults(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQLDSN_EmptyDatabaseUsesCompatibilityDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := &MySQLDB{}
|
||||
dsn, err := m.getDSN(connection.ConnectionConfig{
|
||||
Host: "gdb.local",
|
||||
Port: 1523,
|
||||
User: "glzc",
|
||||
Password: "secret",
|
||||
Timeout: 30,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("getDSN failed: %v", err)
|
||||
}
|
||||
if !strings.Contains(dsn, "@tcp(gdb.local:1523)/?") {
|
||||
t.Fatalf("empty database should still keep DSN slash separator, got=%q", dsn)
|
||||
}
|
||||
|
||||
query := parseMySQLDSNQueryForTest(t, dsn)
|
||||
if got := query.Get("charset"); got != "utf8mb4,utf8" {
|
||||
t.Fatalf("default charset should fall back from utf8mb4 to utf8, got=%q", got)
|
||||
}
|
||||
if got := query.Get("multiStatements"); got != "true" {
|
||||
t.Fatalf("default multiStatements should remain enabled, got=%q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQLDSN_MapsCommonJDBCParamsWithoutLeakingUnsupportedKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -294,6 +321,61 @@ func TestMySQLDSN_MapsAllowMultiQueriesTrueWithoutLeakingKey(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMySQLCompatibleConnectPlans_AddsHandshakeFallbackWhenMultiStatementsImplicit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
plans, err := buildMySQLCompatibleConnectPlans(connection.ConnectionConfig{
|
||||
Host: "gdb.local",
|
||||
Port: 1523,
|
||||
User: "glzc",
|
||||
Timeout: 30,
|
||||
}, "tcp", "gdb.local:1523", "")
|
||||
if err != nil {
|
||||
t.Fatalf("buildMySQLCompatibleConnectPlans failed: %v", err)
|
||||
}
|
||||
if len(plans) != 2 {
|
||||
t.Fatalf("expected default plan plus compatibility fallback, got %d", len(plans))
|
||||
}
|
||||
|
||||
defaultQuery := parseMySQLDSNQueryForTest(t, plans[0].dsn)
|
||||
if got := defaultQuery.Get("multiStatements"); got != "true" {
|
||||
t.Fatalf("default plan should keep multiStatements enabled, got=%q", got)
|
||||
}
|
||||
if got := defaultQuery.Get("charset"); got != "utf8mb4,utf8" {
|
||||
t.Fatalf("default plan should use utf8 fallback charset, got=%q", got)
|
||||
}
|
||||
|
||||
fallbackQuery := parseMySQLDSNQueryForTest(t, plans[1].dsn)
|
||||
if got := fallbackQuery.Get("multiStatements"); got != "false" {
|
||||
t.Fatalf("fallback plan should disable multiStatements, got=%q", got)
|
||||
}
|
||||
if got := fallbackQuery.Get("charset"); got != "utf8mb4,utf8" {
|
||||
t.Fatalf("fallback plan should preserve charset fallback, got=%q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMySQLCompatibleConnectPlans_RespectsExplicitMultiStatementsChoice(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
plans, err := buildMySQLCompatibleConnectPlans(connection.ConnectionConfig{
|
||||
Host: "db.local",
|
||||
Port: 3306,
|
||||
User: "root",
|
||||
ConnectionParams: "allowMultiQueries=false",
|
||||
}, "tcp", "db.local:3306", "app")
|
||||
if err != nil {
|
||||
t.Fatalf("buildMySQLCompatibleConnectPlans failed: %v", err)
|
||||
}
|
||||
if len(plans) != 1 {
|
||||
t.Fatalf("explicit multiStatements choice should skip compatibility fallback, got %d plans", len(plans))
|
||||
}
|
||||
|
||||
query := parseMySQLDSNQueryForTest(t, plans[0].dsn)
|
||||
if got := query.Get("multiStatements"); got != "false" {
|
||||
t.Fatalf("explicit allowMultiQueries=false should be preserved, got=%q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQLDSN_AsiaShanghaiLocationAcceptedByDriver(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -278,6 +278,53 @@ func mergeMySQLConnectionParams(params url.Values, values url.Values) {
|
||||
}
|
||||
}
|
||||
|
||||
type mySQLCompatibleDSNOptions struct {
|
||||
defaultCharset string
|
||||
defaultMultiStatements *bool
|
||||
}
|
||||
|
||||
type mySQLCompatibleConnectPlan struct {
|
||||
label string
|
||||
dsn string
|
||||
}
|
||||
|
||||
const (
|
||||
mySQLCompatPlanDefaultLabel = "默认兼容参数"
|
||||
mySQLCompatPlanDisableMultiStatementsLabel = "禁用 multiStatements 兼容重试"
|
||||
)
|
||||
|
||||
func hasMySQLConnectionParam(config connection.ConnectionConfig, names ...string) bool {
|
||||
if len(names) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
targets := make(map[string]struct{}, len(names))
|
||||
for _, name := range names {
|
||||
normalized := strings.ToLower(strings.TrimSpace(name))
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
targets[normalized] = struct{}{}
|
||||
}
|
||||
if len(targets) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
hasMatchingKey := func(values url.Values) bool {
|
||||
for key := range values {
|
||||
if _, ok := targets[strings.ToLower(strings.TrimSpace(key))]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if parsed, ok := parseMySQLCompatibleURI(config.URI, "mysql", "mariadb", "doris", "diros", "oceanbase", "starrocks"); ok && hasMatchingKey(parsed.Query()) {
|
||||
return true
|
||||
}
|
||||
return hasMatchingKey(mysqlConnectionParamsFromText(config.ConnectionParams))
|
||||
}
|
||||
|
||||
func resolveMySQLTLSParam(config connection.ConnectionConfig) (string, bool, error) {
|
||||
mode := resolveMySQLTLSMode(config)
|
||||
if mode == "false" || !hasTLSCertificatePaths(config) {
|
||||
@@ -297,14 +344,18 @@ func resolveMySQLTLSParam(config connection.ConnectionConfig) (string, bool, err
|
||||
return name, normalizeSSLModeValue(config.SSLMode) == sslModePreferred, nil
|
||||
}
|
||||
|
||||
func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, address, database string) (string, error) {
|
||||
func buildMySQLCompatibleDSNWithOptions(config connection.ConnectionConfig, protocol, address, database string, options mySQLCompatibleDSNOptions) (string, error) {
|
||||
timeout := getConnectTimeoutSeconds(config)
|
||||
tlsMode, allowFallbackToPlaintext, err := resolveMySQLTLSParam(config)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
params := url.Values{}
|
||||
params.Set("charset", "utf8mb4")
|
||||
defaultCharset := strings.TrimSpace(options.defaultCharset)
|
||||
if defaultCharset == "" {
|
||||
defaultCharset = "utf8mb4,utf8"
|
||||
}
|
||||
params.Set("charset", defaultCharset)
|
||||
params.Set("parseTime", "True")
|
||||
params.Set("loc", "Local")
|
||||
params.Set("timeout", fmt.Sprintf("%ds", timeout))
|
||||
@@ -312,7 +363,11 @@ func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, addre
|
||||
if allowFallbackToPlaintext {
|
||||
params.Set("allowFallbackToPlaintext", "true")
|
||||
}
|
||||
params.Set("multiStatements", "true")
|
||||
defaultMultiStatements := true
|
||||
if options.defaultMultiStatements != nil {
|
||||
defaultMultiStatements = *options.defaultMultiStatements
|
||||
}
|
||||
params.Set("multiStatements", strconv.FormatBool(defaultMultiStatements))
|
||||
if parsed, ok := parseMySQLCompatibleURI(config.URI, "mysql", "doris", "diros", "oceanbase"); ok {
|
||||
mergeMySQLConnectionParams(params, parsed.Query())
|
||||
}
|
||||
@@ -323,6 +378,46 @@ func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, addre
|
||||
), nil
|
||||
}
|
||||
|
||||
func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, address, database string) (string, error) {
|
||||
defaultMultiStatements := true
|
||||
return buildMySQLCompatibleDSNWithOptions(config, protocol, address, database, mySQLCompatibleDSNOptions{
|
||||
defaultCharset: "utf8mb4,utf8",
|
||||
defaultMultiStatements: &defaultMultiStatements,
|
||||
})
|
||||
}
|
||||
|
||||
func buildMySQLCompatibleConnectPlans(config connection.ConnectionConfig, protocol, address, database string) ([]mySQLCompatibleConnectPlan, error) {
|
||||
defaultDSN, err := buildMySQLCompatibleDSN(config, protocol, address, database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plans := []mySQLCompatibleConnectPlan{{
|
||||
label: mySQLCompatPlanDefaultLabel,
|
||||
dsn: defaultDSN,
|
||||
}}
|
||||
|
||||
if hasMySQLConnectionParam(config, "multiStatements", "allowMultiQueries") {
|
||||
return plans, nil
|
||||
}
|
||||
|
||||
disabled := false
|
||||
fallbackDSN, err := buildMySQLCompatibleDSNWithOptions(config, protocol, address, database, mySQLCompatibleDSNOptions{
|
||||
defaultCharset: "utf8mb4,utf8",
|
||||
defaultMultiStatements: &disabled,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if fallbackDSN == defaultDSN {
|
||||
return plans, nil
|
||||
}
|
||||
|
||||
return append(plans, mySQLCompatibleConnectPlan{
|
||||
label: mySQLCompatPlanDisableMultiStatementsLabel,
|
||||
dsn: fallbackDSN,
|
||||
}), nil
|
||||
}
|
||||
|
||||
func normalizeMySQLRawDSNCompatibilityParams(raw string) string {
|
||||
text := strings.TrimSpace(raw)
|
||||
queryIndex := strings.Index(text, "?")
|
||||
@@ -582,20 +677,27 @@ func collectMySQLAddresses(config connection.ConnectionConfig) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *MySQLDB) getDSN(config connection.ConnectionConfig) (string, error) {
|
||||
database := config.Database
|
||||
func (m *MySQLDB) resolveProtocolAndAddress(config connection.ConnectionConfig) (string, string, error) {
|
||||
protocol := "tcp"
|
||||
address := normalizeMySQLAddress(config.Host, config.Port)
|
||||
|
||||
if config.UseSSH {
|
||||
netName, err := ssh.RegisterSSHNetwork(config.SSH)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||||
return "", "", fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||||
}
|
||||
protocol = netName
|
||||
}
|
||||
|
||||
return buildMySQLCompatibleDSN(config, protocol, address, database)
|
||||
return protocol, address, nil
|
||||
}
|
||||
|
||||
func (m *MySQLDB) getDSN(config connection.ConnectionConfig) (string, error) {
|
||||
protocol, address, err := m.resolveProtocolAndAddress(config)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buildMySQLCompatibleDSN(config, protocol, address, config.Database)
|
||||
}
|
||||
|
||||
func resolveMySQLCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {
|
||||
@@ -633,30 +735,50 @@ func (m *MySQLDB) Connect(config connection.ConnectionConfig) error {
|
||||
candidateConfig.Port = port
|
||||
candidateConfig.User, candidateConfig.Password = resolveMySQLCredential(runConfig, index)
|
||||
|
||||
dsn, err := m.getDSN(candidateConfig)
|
||||
protocol, address, err := m.resolveProtocolAndAddress(candidateConfig)
|
||||
if err != nil {
|
||||
errorDetails = append(errorDetails, fmt.Sprintf("%s 生成连接串失败: %v", address, err))
|
||||
continue
|
||||
}
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
plans, err := buildMySQLCompatibleConnectPlans(candidateConfig, protocol, address, candidateConfig.Database)
|
||||
if err != nil {
|
||||
errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败: %v", address, err))
|
||||
errorDetails = append(errorDetails, fmt.Sprintf("%s 生成连接串失败: %v", address, err))
|
||||
continue
|
||||
}
|
||||
|
||||
timeout := getConnectTimeout(candidateConfig)
|
||||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||||
pingErr := db.PingContext(ctx)
|
||||
cancel()
|
||||
if pingErr != nil {
|
||||
_ = db.Close()
|
||||
errorDetails = append(errorDetails, fmt.Sprintf("%s 验证失败: %v", address, pingErr))
|
||||
continue
|
||||
}
|
||||
for _, plan := range plans {
|
||||
db, err := sql.Open("mysql", plan.dsn)
|
||||
if err != nil {
|
||||
if len(plans) > 1 || plan.label != mySQLCompatPlanDefaultLabel {
|
||||
errorDetails = append(errorDetails, fmt.Sprintf("%s [%s] 打开失败: %v", address, plan.label, err))
|
||||
} else {
|
||||
errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败: %v", address, err))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
m.conn = db
|
||||
m.pingTimeout = timeout
|
||||
return nil
|
||||
timeout := getConnectTimeout(candidateConfig)
|
||||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||||
pingErr := db.PingContext(ctx)
|
||||
cancel()
|
||||
if pingErr != nil {
|
||||
_ = db.Close()
|
||||
if len(plans) > 1 || plan.label != mySQLCompatPlanDefaultLabel {
|
||||
errorDetails = append(errorDetails, fmt.Sprintf("%s [%s] 验证失败: %v", address, plan.label, pingErr))
|
||||
} else {
|
||||
errorDetails = append(errorDetails, fmt.Sprintf("%s 验证失败: %v", address, pingErr))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if plan.label != mySQLCompatPlanDefaultLabel {
|
||||
logger.Warnf("MySQL 兼容回退生效:地址=%s 模式=%s", address, plan.label)
|
||||
}
|
||||
|
||||
m.conn = db
|
||||
m.pingTimeout = timeout
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if len(errorDetails) == 0 {
|
||||
|
||||
Reference in New Issue
Block a user