🐛 fix(mysql): 兼容 allowMultiQueries 连接参数

- 将 JDBC allowMultiQueries 参数映射为 MySQL driver 支持的 multiStatements

- 修复自定义 MySQL DSN 透传导致旧版本 MySQL 连接失败的问题

- 更新 MySQL 兼容 driver-agent revision

Refs #441
This commit is contained in:
Syngnat
2026-05-24 10:59:52 +08:00
parent cf0a216329
commit 358d799af8
6 changed files with 175 additions and 5 deletions

View File

@@ -23,6 +23,9 @@ func (c *CustomDB) Connect(config connection.ConnectionConfig) error {
if driver == "" || dsn == "" {
return fmt.Errorf("driver and dsn are required for custom connection")
}
if strings.EqualFold(driver, "mysql") {
dsn = normalizeMySQLRawDSNCompatibilityParams(dsn)
}
// Verify driver is registered (implicit check by sql.Open)
// We might not need explicit check, sql.Open will fail or Ping will fail if driver not found.

View File

@@ -1,12 +1,43 @@
package db
import (
"database/sql"
"database/sql/driver"
"strings"
"testing"
"GoNavi-Wails/internal/connection"
)
const customMySQLDSNRecordingDriverName = "custom-mysql-dsn-recording"
var customMySQLDSNRecordingLastDSN string
type customMySQLDSNRecordingDriver struct{}
func (d customMySQLDSNRecordingDriver) Open(name string) (driver.Conn, error) {
customMySQLDSNRecordingLastDSN = name
return customMySQLDSNRecordingConn{}, nil
}
type customMySQLDSNRecordingConn struct{}
func (c customMySQLDSNRecordingConn) Prepare(query string) (driver.Stmt, error) {
return nil, driver.ErrSkip
}
func (c customMySQLDSNRecordingConn) Close() error {
return nil
}
func (c customMySQLDSNRecordingConn) Begin() (driver.Tx, error) {
return nil, driver.ErrSkip
}
func init() {
sql.Register(customMySQLDSNRecordingDriverName, customMySQLDSNRecordingDriver{})
}
func TestCustomDBConnectReportsUnsupportedODBCDriverName(t *testing.T) {
db := &CustomDB{}
@@ -52,3 +83,50 @@ func TestCustomDBConnectReportsUnregisteredGoDriver(t *testing.T) {
}
}
}
func TestNormalizeMySQLRawDSNCompatibilityParamsMapsAllowMultiQueries(t *testing.T) {
got := normalizeMySQLRawDSNCompatibilityParams(
"root:pass@tcp(127.0.0.1:3306)/app?charset=utf8mb4&allowMultiQueries=true#debug",
)
if strings.Contains(got, "allowMultiQueries") {
t.Fatalf("allowMultiQueries should not remain in DSN: %s", got)
}
if !strings.Contains(got, "multiStatements=true") {
t.Fatalf("allowMultiQueries=true should map to multiStatements=true: %s", got)
}
if !strings.HasSuffix(got, "#debug") {
t.Fatalf("fragment should be preserved: %s", got)
}
}
func TestNormalizeMySQLRawDSNCompatibilityParamsPreservesExplicitMultiStatements(t *testing.T) {
got := normalizeMySQLRawDSNCompatibilityParams(
"root:pass@tcp(127.0.0.1:3306)/app?allowMultiQueries=true&multiStatements=false",
)
if strings.Contains(got, "allowMultiQueries") {
t.Fatalf("allowMultiQueries should not remain in DSN: %s", got)
}
if !strings.Contains(got, "multiStatements=false") {
t.Fatalf("explicit multiStatements should win: %s", got)
}
}
func TestCustomDBOnlyNormalizesBuiltInMySQLDriverDSN(t *testing.T) {
customMySQLDSNRecordingLastDSN = ""
rawDSN := "root:pass@tcp(127.0.0.1:3306)/app?allowMultiQueries=true"
db := &CustomDB{}
err := db.Connect(connection.ConnectionConfig{
Driver: customMySQLDSNRecordingDriverName,
DSN: rawDSN,
})
if err != nil {
t.Fatalf("Connect failed: %v", err)
}
t.Cleanup(func() {
_ = db.Close()
})
if customMySQLDSNRecordingLastDSN != rawDSN {
t.Fatalf("non-mysql custom driver DSN should stay untouched, got %q", customMySQLDSNRecordingLastDSN)
}
}

View File

@@ -4,11 +4,11 @@ package db
func init() {
optionalDriverAgentRevisions = map[string]string{
"mariadb": "src-4e1ec648c70c87ea",
"oceanbase": "src-8e445fc4899d850f",
"diros": "src-74927b3809258666",
"starrocks": "src-4ea05ce44321c17b",
"sphinx": "src-269bd60a34df47d3",
"mariadb": "src-0a4176f4b5743323",
"oceanbase": "src-e996325fd6d52648",
"diros": "src-cc11b882e28fa5d4",
"starrocks": "src-83a6d81c91c7f5c8",
"sphinx": "src-a70c2cd4d223dac2",
"sqlserver": "src-84553484c72e7253",
"sqlite": "src-762863d48f653b89",
"duckdb": "src-3e551d777ae96d8d",

View File

@@ -270,6 +270,30 @@ func TestMySQLDSN_MapsAdditionalJDBCAliases(t *testing.T) {
}
}
func TestMySQLDSN_MapsAllowMultiQueriesTrueWithoutLeakingKey(t *testing.T) {
t.Parallel()
m := &MySQLDB{}
dsn, err := m.getDSN(connection.ConnectionConfig{
Host: "db.local",
Port: 3306,
User: "root",
Database: "app",
ConnectionParams: "allowMultiQueries=true",
})
if err != nil {
t.Fatalf("getDSN failed: %v", err)
}
query := parseMySQLDSNQueryForTest(t, dsn)
if got := query.Get("multiStatements"); got != "true" {
t.Fatalf("allowMultiQueries=true should map to multiStatements=true, got=%q; query=%v", got, query)
}
if _, exists := query["allowMultiQueries"]; exists {
t.Fatalf("allowMultiQueries should not be passed to Go MySQL driver: %v", query)
}
}
func TestMySQLDSN_AsiaShanghaiLocationAcceptedByDriver(t *testing.T) {
t.Parallel()

View File

@@ -319,6 +319,70 @@ func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, addre
), nil
}
func normalizeMySQLRawDSNCompatibilityParams(raw string) string {
text := strings.TrimSpace(raw)
queryIndex := strings.Index(text, "?")
if text == "" || queryIndex < 0 {
return raw
}
prefix := text[:queryIndex]
queryText := text[queryIndex+1:]
suffix := ""
if fragmentIndex := strings.Index(queryText, "#"); fragmentIndex >= 0 {
suffix = queryText[fragmentIndex:]
queryText = queryText[:fragmentIndex]
}
values, err := url.ParseQuery(queryText)
if err != nil {
return raw
}
changed := false
explicitMultiStatements := ""
hasExplicitMultiStatements := false
allowMultiQueries := ""
hasAllowMultiQueries := false
for key, items := range values {
switch strings.ToLower(strings.TrimSpace(key)) {
case "multistatements":
delete(values, key)
changed = true
for _, item := range items {
if enabled, ok := parseMySQLBoolParam(item); ok {
explicitMultiStatements = strconv.FormatBool(enabled)
hasExplicitMultiStatements = true
}
}
case "allowmultiqueries":
delete(values, key)
changed = true
for _, item := range items {
if enabled, ok := parseMySQLBoolParam(item); ok {
allowMultiQueries = strconv.FormatBool(enabled)
hasAllowMultiQueries = true
}
}
}
}
if hasExplicitMultiStatements {
values.Set("multiStatements", explicitMultiStatements)
} else if hasAllowMultiQueries {
values.Set("multiStatements", allowMultiQueries)
}
if !changed {
return raw
}
encoded := values.Encode()
if encoded == "" {
return prefix + suffix
}
return prefix + "?" + encoded + suffix
}
func parseHostPortWithDefault(raw string, defaultPort int) (string, int, bool) {
text := strings.TrimSpace(raw)
if text == "" {

View File

@@ -102,6 +102,7 @@ internal/db/timeout.go)
case "$driver:$identity" in
mariadb:internal/db/mariadb_impl.go|\
mariadb:internal/db/mysql_impl.go|\
oceanbase:internal/db/oceanbase_impl.go|\
oceanbase:internal/db/oracle_impl.go|\
oceanbase:internal/db/mysql_impl.go|\