From 358d799af8b888912583cb7e29c783173bcbbfca Mon Sep 17 00:00:00 2001 From: Syngnat Date: Sun, 24 May 2026 10:59:52 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(mysql):=20=E5=85=BC=E5=AE=B9?= =?UTF-8?q?=20allowMultiQueries=20=E8=BF=9E=E6=8E=A5=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将 JDBC allowMultiQueries 参数映射为 MySQL driver 支持的 multiStatements - 修复自定义 MySQL DSN 透传导致旧版本 MySQL 连接失败的问题 - 更新 MySQL 兼容 driver-agent revision Refs #441 --- internal/db/custom_impl.go | 3 + internal/db/custom_impl_test.go | 78 +++++++++++++++++++++ internal/db/driver_agent_revisions_gen.go | 10 +-- internal/db/mysql_connection_params_test.go | 24 +++++++ internal/db/mysql_impl.go | 64 +++++++++++++++++ tools/generate-driver-agent-revisions.sh | 1 + 6 files changed, 175 insertions(+), 5 deletions(-) diff --git a/internal/db/custom_impl.go b/internal/db/custom_impl.go index dc48398..58badeb 100644 --- a/internal/db/custom_impl.go +++ b/internal/db/custom_impl.go @@ -23,6 +23,9 @@ func (c *CustomDB) Connect(config connection.ConnectionConfig) error { if driver == "" || dsn == "" { return fmt.Errorf("driver and dsn are required for custom connection") } + if strings.EqualFold(driver, "mysql") { + dsn = normalizeMySQLRawDSNCompatibilityParams(dsn) + } // Verify driver is registered (implicit check by sql.Open) // We might not need explicit check, sql.Open will fail or Ping will fail if driver not found. diff --git a/internal/db/custom_impl_test.go b/internal/db/custom_impl_test.go index 3634869..00e6559 100644 --- a/internal/db/custom_impl_test.go +++ b/internal/db/custom_impl_test.go @@ -1,12 +1,43 @@ package db import ( + "database/sql" + "database/sql/driver" "strings" "testing" "GoNavi-Wails/internal/connection" ) +const customMySQLDSNRecordingDriverName = "custom-mysql-dsn-recording" + +var customMySQLDSNRecordingLastDSN string + +type customMySQLDSNRecordingDriver struct{} + +func (d customMySQLDSNRecordingDriver) Open(name string) (driver.Conn, error) { + customMySQLDSNRecordingLastDSN = name + return customMySQLDSNRecordingConn{}, nil +} + +type customMySQLDSNRecordingConn struct{} + +func (c customMySQLDSNRecordingConn) Prepare(query string) (driver.Stmt, error) { + return nil, driver.ErrSkip +} + +func (c customMySQLDSNRecordingConn) Close() error { + return nil +} + +func (c customMySQLDSNRecordingConn) Begin() (driver.Tx, error) { + return nil, driver.ErrSkip +} + +func init() { + sql.Register(customMySQLDSNRecordingDriverName, customMySQLDSNRecordingDriver{}) +} + func TestCustomDBConnectReportsUnsupportedODBCDriverName(t *testing.T) { db := &CustomDB{} @@ -52,3 +83,50 @@ func TestCustomDBConnectReportsUnregisteredGoDriver(t *testing.T) { } } } + +func TestNormalizeMySQLRawDSNCompatibilityParamsMapsAllowMultiQueries(t *testing.T) { + got := normalizeMySQLRawDSNCompatibilityParams( + "root:pass@tcp(127.0.0.1:3306)/app?charset=utf8mb4&allowMultiQueries=true#debug", + ) + if strings.Contains(got, "allowMultiQueries") { + t.Fatalf("allowMultiQueries should not remain in DSN: %s", got) + } + if !strings.Contains(got, "multiStatements=true") { + t.Fatalf("allowMultiQueries=true should map to multiStatements=true: %s", got) + } + if !strings.HasSuffix(got, "#debug") { + t.Fatalf("fragment should be preserved: %s", got) + } +} + +func TestNormalizeMySQLRawDSNCompatibilityParamsPreservesExplicitMultiStatements(t *testing.T) { + got := normalizeMySQLRawDSNCompatibilityParams( + "root:pass@tcp(127.0.0.1:3306)/app?allowMultiQueries=true&multiStatements=false", + ) + if strings.Contains(got, "allowMultiQueries") { + t.Fatalf("allowMultiQueries should not remain in DSN: %s", got) + } + if !strings.Contains(got, "multiStatements=false") { + t.Fatalf("explicit multiStatements should win: %s", got) + } +} + +func TestCustomDBOnlyNormalizesBuiltInMySQLDriverDSN(t *testing.T) { + customMySQLDSNRecordingLastDSN = "" + rawDSN := "root:pass@tcp(127.0.0.1:3306)/app?allowMultiQueries=true" + + db := &CustomDB{} + err := db.Connect(connection.ConnectionConfig{ + Driver: customMySQLDSNRecordingDriverName, + DSN: rawDSN, + }) + if err != nil { + t.Fatalf("Connect failed: %v", err) + } + t.Cleanup(func() { + _ = db.Close() + }) + if customMySQLDSNRecordingLastDSN != rawDSN { + t.Fatalf("non-mysql custom driver DSN should stay untouched, got %q", customMySQLDSNRecordingLastDSN) + } +} diff --git a/internal/db/driver_agent_revisions_gen.go b/internal/db/driver_agent_revisions_gen.go index 65360a8..264e525 100644 --- a/internal/db/driver_agent_revisions_gen.go +++ b/internal/db/driver_agent_revisions_gen.go @@ -4,11 +4,11 @@ package db func init() { optionalDriverAgentRevisions = map[string]string{ - "mariadb": "src-4e1ec648c70c87ea", - "oceanbase": "src-8e445fc4899d850f", - "diros": "src-74927b3809258666", - "starrocks": "src-4ea05ce44321c17b", - "sphinx": "src-269bd60a34df47d3", + "mariadb": "src-0a4176f4b5743323", + "oceanbase": "src-e996325fd6d52648", + "diros": "src-cc11b882e28fa5d4", + "starrocks": "src-83a6d81c91c7f5c8", + "sphinx": "src-a70c2cd4d223dac2", "sqlserver": "src-84553484c72e7253", "sqlite": "src-762863d48f653b89", "duckdb": "src-3e551d777ae96d8d", diff --git a/internal/db/mysql_connection_params_test.go b/internal/db/mysql_connection_params_test.go index 7c457d2..97fdd97 100644 --- a/internal/db/mysql_connection_params_test.go +++ b/internal/db/mysql_connection_params_test.go @@ -270,6 +270,30 @@ func TestMySQLDSN_MapsAdditionalJDBCAliases(t *testing.T) { } } +func TestMySQLDSN_MapsAllowMultiQueriesTrueWithoutLeakingKey(t *testing.T) { + t.Parallel() + + m := &MySQLDB{} + dsn, err := m.getDSN(connection.ConnectionConfig{ + Host: "db.local", + Port: 3306, + User: "root", + Database: "app", + ConnectionParams: "allowMultiQueries=true", + }) + if err != nil { + t.Fatalf("getDSN failed: %v", err) + } + + query := parseMySQLDSNQueryForTest(t, dsn) + if got := query.Get("multiStatements"); got != "true" { + t.Fatalf("allowMultiQueries=true should map to multiStatements=true, got=%q; query=%v", got, query) + } + if _, exists := query["allowMultiQueries"]; exists { + t.Fatalf("allowMultiQueries should not be passed to Go MySQL driver: %v", query) + } +} + func TestMySQLDSN_AsiaShanghaiLocationAcceptedByDriver(t *testing.T) { t.Parallel() diff --git a/internal/db/mysql_impl.go b/internal/db/mysql_impl.go index aefc0e5..70eb6f8 100644 --- a/internal/db/mysql_impl.go +++ b/internal/db/mysql_impl.go @@ -319,6 +319,70 @@ func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, addre ), nil } +func normalizeMySQLRawDSNCompatibilityParams(raw string) string { + text := strings.TrimSpace(raw) + queryIndex := strings.Index(text, "?") + if text == "" || queryIndex < 0 { + return raw + } + + prefix := text[:queryIndex] + queryText := text[queryIndex+1:] + suffix := "" + if fragmentIndex := strings.Index(queryText, "#"); fragmentIndex >= 0 { + suffix = queryText[fragmentIndex:] + queryText = queryText[:fragmentIndex] + } + values, err := url.ParseQuery(queryText) + if err != nil { + return raw + } + + changed := false + explicitMultiStatements := "" + hasExplicitMultiStatements := false + allowMultiQueries := "" + hasAllowMultiQueries := false + + for key, items := range values { + switch strings.ToLower(strings.TrimSpace(key)) { + case "multistatements": + delete(values, key) + changed = true + for _, item := range items { + if enabled, ok := parseMySQLBoolParam(item); ok { + explicitMultiStatements = strconv.FormatBool(enabled) + hasExplicitMultiStatements = true + } + } + case "allowmultiqueries": + delete(values, key) + changed = true + for _, item := range items { + if enabled, ok := parseMySQLBoolParam(item); ok { + allowMultiQueries = strconv.FormatBool(enabled) + hasAllowMultiQueries = true + } + } + } + } + + if hasExplicitMultiStatements { + values.Set("multiStatements", explicitMultiStatements) + } else if hasAllowMultiQueries { + values.Set("multiStatements", allowMultiQueries) + } + + if !changed { + return raw + } + encoded := values.Encode() + if encoded == "" { + return prefix + suffix + } + return prefix + "?" + encoded + suffix +} + func parseHostPortWithDefault(raw string, defaultPort int) (string, int, bool) { text := strings.TrimSpace(raw) if text == "" { diff --git a/tools/generate-driver-agent-revisions.sh b/tools/generate-driver-agent-revisions.sh index fda338b..9fccef5 100755 --- a/tools/generate-driver-agent-revisions.sh +++ b/tools/generate-driver-agent-revisions.sh @@ -102,6 +102,7 @@ internal/db/timeout.go) case "$driver:$identity" in mariadb:internal/db/mariadb_impl.go|\ +mariadb:internal/db/mysql_impl.go|\ oceanbase:internal/db/oceanbase_impl.go|\ oceanbase:internal/db/oracle_impl.go|\ oceanbase:internal/db/mysql_impl.go|\