From c45961f0277fff76868ac5dc56f24f2c2e90fd28 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Tue, 9 Jun 2026 14:13:35 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(db):=20=E4=BF=9D=E7=95=99?= =?UTF-8?q?=E5=A4=9A=E5=86=99=E8=AF=AD=E5=8F=A5=E7=BB=93=E6=9E=9C=E5=B9=B6?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20MySQL=20=E5=AD=97=E7=AC=A6=E9=9B=86?= =?UTF-8?q?=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 多条写语句改为逐条返回 affectedRows,避免只显示最后一条结果 - 为写语句结果补充 statementIndex,保持语句与结果映射 - 保留 MySQL charset fallback 逗号,避免驱动解析成 %2C --- internal/app/methods_db.go | 12 ++-- internal/app/methods_db_multi_test.go | 70 +++++++++++++-------- internal/db/custom_impl_test.go | 12 ++++ internal/db/mysql_connection_params_test.go | 44 +++++++++++++ internal/db/mysql_impl.go | 38 ++++++++++- 5 files changed, 144 insertions(+), 32 deletions(-) diff --git a/internal/app/methods_db.go b/internal/app/methods_db.go index f0a2b38..033a595 100644 --- a/internal/app/methods_db.go +++ b/internal/app/methods_db.go @@ -827,8 +827,9 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu } defer closeExecTarget() - // 全部为写操作且驱动支持批量 Exec → 一次性发送,大幅减少网络往返 - // 适用于 MySQL/MariaDB/Doris/PostgreSQL/SQLite/DuckDB 等支持多语句 Exec 的驱动 + // 单条写语句且驱动支持批量 Exec 时,可复用批量路径。 + // 多条写语句必须逐条返回结果;部分驱动对多语句 Exec 仅暴露最后一条 RowsAffected, + // 会导致前面语句已成功执行但结果页只剩一个写入结果。 if !allReadOnly { allWrite := true containsPLSQLBlock := false @@ -840,7 +841,7 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu containsPLSQLBlock = true } } - if allWrite && !containsPLSQLBlock { + if allWrite && !containsPLSQLBlock && len(statements) == 1 { batcher := sessionBatchTarget if batcher == nil { if fallbackBatcher, ok := dbInst.(db.BatchWriteExecer); ok { @@ -987,8 +988,9 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID} } resultSets = append(resultSets, connection.ResultSetData{ - Rows: []map[string]interface{}{{"affectedRows": affected}}, - Columns: []string{"affectedRows"}, + Rows: []map[string]interface{}{{"affectedRows": affected}}, + Columns: []string{"affectedRows"}, + StatementIndex: idx + 1, }) } diff --git a/internal/app/methods_db_multi_test.go b/internal/app/methods_db_multi_test.go index 6b1ac6d..012ab2c 100644 --- a/internal/app/methods_db_multi_test.go +++ b/internal/app/methods_db_multi_test.go @@ -10,18 +10,19 @@ import ( ) type fakeBatchWriteDB struct { - batchCalls int - execCalls int - execQueries []string - lastQuery string - lastCtx context.Context - queryCalls int - queryMap map[string][]map[string]interface{} - fieldMap map[string][]string - messageMap map[string][]string - multiResult map[string][]connection.ResultSetData - queryErr map[string]error - session *fakeBatchWriteSession + batchCalls int + execCalls int + execQueries []string + lastQuery string + lastCtx context.Context + queryCalls int + queryMap map[string][]map[string]interface{} + fieldMap map[string][]string + messageMap map[string][]string + multiResult map[string][]connection.ResultSetData + queryErr map[string]error + execAffected map[string]int64 + session *fakeBatchWriteSession } func (f *fakeBatchWriteDB) Connect(config connection.ConnectionConfig) error { @@ -52,6 +53,9 @@ func (f *fakeBatchWriteDB) QueryWithMessages(query string) ([]map[string]interfa func (f *fakeBatchWriteDB) Exec(query string) (int64, error) { f.execCalls++ f.execQueries = append(f.execQueries, query) + if affected, ok := f.execAffected[query]; ok { + return affected, nil + } return 1, nil } @@ -91,6 +95,9 @@ func (f *fakeBatchWriteDB) ExecContext(ctx context.Context, query string) (int64 f.lastCtx = ctx f.execCalls++ f.execQueries = append(f.execQueries, query) + if affected, ok := f.execAffected[query]; ok { + return affected, nil + } return 1, nil } @@ -440,13 +447,20 @@ func TestDBQueryWithCancel_DuckDBQueriesDoNotInheritConnectTimeout(t *testing.T) } } -func TestDBQueryMultiUsesBatchWriteExecerForAllWriteStatements(t *testing.T) { +func TestDBQueryMultiPreservesPerStatementResultsForMultipleWriteStatements(t *testing.T) { originalNewDatabaseFunc := newDatabaseFunc t.Cleanup(func() { newDatabaseFunc = originalNewDatabaseFunc }) - fakeDB := &fakeBatchWriteDB{} + firstStmt := "DELETE FROM assets_asset" + secondStmt := "DELETE FROM assets_assetcategory" + fakeDB := &fakeBatchWriteDB{ + execAffected: map[string]int64{ + firstStmt: 5, + secondStmt: 10, + }, + } newDatabaseFunc = func(dbType string) (db.Database, error) { return fakeDB, nil } @@ -458,31 +472,37 @@ func TestDBQueryMultiUsesBatchWriteExecerForAllWriteStatements(t *testing.T) { Port: 1433, User: "sa", } - query := "INSERT INTO demo(id) VALUES (1);\nINSERT INTO demo(id) VALUES (2);" + query := firstStmt + ";\n" + secondStmt + ";" result := app.DBQueryMulti(config, "testdb", query, "batch-write-test") if !result.Success { t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message) } - if fakeDB.batchCalls != 1 { - t.Fatalf("expected batch path to run once, got %d", fakeDB.batchCalls) + if fakeDB.batchCalls != 0 { + t.Fatalf("expected multiple write statements to skip batch path so each result can be preserved, got %d", fakeDB.batchCalls) } - if fakeDB.execCalls != 0 { - t.Fatalf("expected sequential exec path to be skipped, got execCalls=%d", fakeDB.execCalls) + if fakeDB.execCalls != 2 { + t.Fatalf("expected sequential exec path to run twice, got execCalls=%d", fakeDB.execCalls) } - if fakeDB.lastQuery != query { - t.Fatalf("expected batch query to stay intact, got %q", fakeDB.lastQuery) + if len(fakeDB.execQueries) != 2 || fakeDB.execQueries[0] != firstStmt || fakeDB.execQueries[1] != secondStmt { + t.Fatalf("expected sequential execs to preserve statement order, got %#v", fakeDB.execQueries) } resultSets, ok := result.Data.([]connection.ResultSetData) if !ok { t.Fatalf("expected []connection.ResultSetData, got %T", result.Data) } - if len(resultSets) != 1 || len(resultSets[0].Rows) != 1 { - t.Fatalf("expected one affectedRows result set, got %#v", resultSets) + if len(resultSets) != 2 { + t.Fatalf("expected one affectedRows result set per statement, got %#v", resultSets) } - if got := resultSets[0].Rows[0]["affectedRows"]; got != int64(500) { - t.Fatalf("expected affectedRows=500, got %#v", got) + if len(resultSets[0].Rows) != 1 || len(resultSets[1].Rows) != 1 { + t.Fatalf("expected both result sets to contain a single affectedRows row, got %#v", resultSets) + } + if got := resultSets[0].Rows[0]["affectedRows"]; got != int64(5) { + t.Fatalf("expected first affectedRows=5, got %#v", got) + } + if got := resultSets[1].Rows[0]["affectedRows"]; got != int64(10) { + t.Fatalf("expected second affectedRows=10, got %#v", got) } } diff --git a/internal/db/custom_impl_test.go b/internal/db/custom_impl_test.go index 00e6559..4f983ed 100644 --- a/internal/db/custom_impl_test.go +++ b/internal/db/custom_impl_test.go @@ -111,6 +111,18 @@ func TestNormalizeMySQLRawDSNCompatibilityParamsPreservesExplicitMultiStatements } } +func TestNormalizeMySQLRawDSNCompatibilityParamsPreservesCharsetFallbackComma(t *testing.T) { + got := normalizeMySQLRawDSNCompatibilityParams( + "root:pass@tcp(127.0.0.1:3306)/app?charset=utf8mb4,utf8&allowMultiQueries=true", + ) + if strings.Contains(got, "%2C") || strings.Contains(got, "%2c") { + t.Fatalf("charset fallback comma should stay unescaped for mysql driver, got %q", got) + } + if !strings.Contains(got, "charset=utf8mb4,utf8") { + t.Fatalf("charset fallback list should be preserved, got %q", got) + } +} + func TestCustomDBOnlyNormalizesBuiltInMySQLDriverDSN(t *testing.T) { customMySQLDSNRecordingLastDSN = "" rawDSN := "root:pass@tcp(127.0.0.1:3306)/app?allowMultiQueries=true" diff --git a/internal/db/mysql_connection_params_test.go b/internal/db/mysql_connection_params_test.go index 41453e4..9d852c3 100644 --- a/internal/db/mysql_connection_params_test.go +++ b/internal/db/mysql_connection_params_test.go @@ -3,10 +3,13 @@ package db import ( "database/sql" "net/url" + "reflect" "strings" "testing" "GoNavi-Wails/internal/connection" + + mysql "github.com/go-sql-driver/mysql" ) func parseMySQLDSNQueryForTest(t *testing.T, dsn string) url.Values { @@ -22,6 +25,26 @@ func parseMySQLDSNQueryForTest(t *testing.T, dsn string) url.Values { return values } +func parseMySQLDriverCharsetsForTest(t *testing.T, dsn string) []string { + t.Helper() + + cfg, err := mysql.ParseDSN(dsn) + if err != nil { + t.Fatalf("mysql ParseDSN failed: %v", err) + } + + field := reflect.ValueOf(cfg).Elem().FieldByName("charsets") + if !field.IsValid() { + t.Fatal("mysql.Config missing internal charsets field") + } + + charsets := make([]string, field.Len()) + for i := 0; i < field.Len(); i++ { + charsets[i] = field.Index(i).String() + } + return charsets +} + func TestMySQLDSN_MergesConnectionParamsWithDefaults(t *testing.T) { t.Parallel() @@ -398,6 +421,27 @@ func TestMySQLDSN_AsiaShanghaiLocationAcceptedByDriver(t *testing.T) { _ = db.Close() } +func TestMySQLDSN_DefaultCharsetFallbackListRemainsDriverCompatible(t *testing.T) { + t.Parallel() + + m := &MySQLDB{} + dsn, err := m.getDSN(connection.ConnectionConfig{ + Host: "127.0.0.1", + Port: 3306, + User: "root", + Database: "app", + }) + if err != nil { + t.Fatalf("getDSN failed: %v", err) + } + + got := parseMySQLDriverCharsetsForTest(t, dsn) + want := []string{"utf8mb4", "utf8"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("driver should parse charset fallback list, got=%v want=%v dsn=%q", got, want, dsn) + } +} + func TestMySQLDSN_URIParamsAndExplicitParamsPrecedence(t *testing.T) { t.Parallel() diff --git a/internal/db/mysql_impl.go b/internal/db/mysql_impl.go index d280f41..61e1e1f 100644 --- a/internal/db/mysql_impl.go +++ b/internal/db/mysql_impl.go @@ -372,12 +372,46 @@ func buildMySQLCompatibleDSNWithOptions(config connection.ConnectionConfig, prot mergeMySQLConnectionParams(params, parsed.Query()) } mergeMySQLConnectionParams(params, mysqlConnectionParamsFromText(config.ConnectionParams)) + encodedParams := encodeMySQLDSNQuery(params) return fmt.Sprintf( "%s:%s@%s(%s)/%s?%s", - config.User, config.Password, protocol, address, database, params.Encode(), + config.User, config.Password, protocol, address, database, encodedParams, ), nil } +func encodeMySQLDSNQuery(params url.Values) string { + if len(params) == 0 { + return "" + } + + keys := make([]string, 0, len(params)) + for key := range params { + keys = append(keys, key) + } + sort.Strings(keys) + + var builder strings.Builder + for _, key := range keys { + escapedKey := url.QueryEscape(key) + values := params[key] + for _, value := range values { + if builder.Len() > 0 { + builder.WriteByte('&') + } + builder.WriteString(escapedKey) + builder.WriteByte('=') + escapedValue := url.QueryEscape(value) + if strings.EqualFold(strings.TrimSpace(key), "charset") { + escapedValue = strings.ReplaceAll(escapedValue, "%2C", ",") + escapedValue = strings.ReplaceAll(escapedValue, "%2c", ",") + } + builder.WriteString(escapedValue) + } + } + + return builder.String() +} + func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, address, database string) (string, error) { defaultMultiStatements := true return buildMySQLCompatibleDSNWithOptions(config, protocol, address, database, mySQLCompatibleDSNOptions{ @@ -475,7 +509,7 @@ func normalizeMySQLRawDSNCompatibilityParams(raw string) string { if !changed { return raw } - encoded := values.Encode() + encoded := encodeMySQLDSNQuery(values) if encoded == "" { return prefix + suffix }