🐛 fix(db): 保留多写语句结果并修复 MySQL 字符集参数

- 多条写语句改为逐条返回 affectedRows,避免只显示最后一条结果

- 为写语句结果补充 statementIndex,保持语句与结果映射

- 保留 MySQL charset fallback 逗号,避免驱动解析成 %2C
This commit is contained in:
Syngnat
2026-06-09 14:13:35 +08:00
parent a6105f4807
commit c45961f027
5 changed files with 144 additions and 32 deletions

View File

@@ -827,8 +827,9 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
}
defer closeExecTarget()
// 全部为写操作且驱动支持批量 Exec → 一次性发送,大幅减少网络往返
// 适用于 MySQL/MariaDB/Doris/PostgreSQL/SQLite/DuckDB 等支持多语句 Exec 的驱动
// 单条写语句且驱动支持批量 Exec 时,可复用批量路径。
// 多条写语句必须逐条返回结果;部分驱动对多语句 Exec 仅暴露最后一条 RowsAffected
// 会导致前面语句已成功执行但结果页只剩一个写入结果。
if !allReadOnly {
allWrite := true
containsPLSQLBlock := false
@@ -840,7 +841,7 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
containsPLSQLBlock = true
}
}
if allWrite && !containsPLSQLBlock {
if allWrite && !containsPLSQLBlock && len(statements) == 1 {
batcher := sessionBatchTarget
if batcher == nil {
if fallbackBatcher, ok := dbInst.(db.BatchWriteExecer); ok {
@@ -987,8 +988,9 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID}
}
resultSets = append(resultSets, connection.ResultSetData{
Rows: []map[string]interface{}{{"affectedRows": affected}},
Columns: []string{"affectedRows"},
Rows: []map[string]interface{}{{"affectedRows": affected}},
Columns: []string{"affectedRows"},
StatementIndex: idx + 1,
})
}

View File

@@ -10,18 +10,19 @@ import (
)
type fakeBatchWriteDB struct {
batchCalls int
execCalls int
execQueries []string
lastQuery string
lastCtx context.Context
queryCalls int
queryMap map[string][]map[string]interface{}
fieldMap map[string][]string
messageMap map[string][]string
multiResult map[string][]connection.ResultSetData
queryErr map[string]error
session *fakeBatchWriteSession
batchCalls int
execCalls int
execQueries []string
lastQuery string
lastCtx context.Context
queryCalls int
queryMap map[string][]map[string]interface{}
fieldMap map[string][]string
messageMap map[string][]string
multiResult map[string][]connection.ResultSetData
queryErr map[string]error
execAffected map[string]int64
session *fakeBatchWriteSession
}
func (f *fakeBatchWriteDB) Connect(config connection.ConnectionConfig) error {
@@ -52,6 +53,9 @@ func (f *fakeBatchWriteDB) QueryWithMessages(query string) ([]map[string]interfa
func (f *fakeBatchWriteDB) Exec(query string) (int64, error) {
f.execCalls++
f.execQueries = append(f.execQueries, query)
if affected, ok := f.execAffected[query]; ok {
return affected, nil
}
return 1, nil
}
@@ -91,6 +95,9 @@ func (f *fakeBatchWriteDB) ExecContext(ctx context.Context, query string) (int64
f.lastCtx = ctx
f.execCalls++
f.execQueries = append(f.execQueries, query)
if affected, ok := f.execAffected[query]; ok {
return affected, nil
}
return 1, nil
}
@@ -440,13 +447,20 @@ func TestDBQueryWithCancel_DuckDBQueriesDoNotInheritConnectTimeout(t *testing.T)
}
}
func TestDBQueryMultiUsesBatchWriteExecerForAllWriteStatements(t *testing.T) {
func TestDBQueryMultiPreservesPerStatementResultsForMultipleWriteStatements(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
fakeDB := &fakeBatchWriteDB{}
firstStmt := "DELETE FROM assets_asset"
secondStmt := "DELETE FROM assets_assetcategory"
fakeDB := &fakeBatchWriteDB{
execAffected: map[string]int64{
firstStmt: 5,
secondStmt: 10,
},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
@@ -458,31 +472,37 @@ func TestDBQueryMultiUsesBatchWriteExecerForAllWriteStatements(t *testing.T) {
Port: 1433,
User: "sa",
}
query := "INSERT INTO demo(id) VALUES (1);\nINSERT INTO demo(id) VALUES (2);"
query := firstStmt + ";\n" + secondStmt + ";"
result := app.DBQueryMulti(config, "testdb", query, "batch-write-test")
if !result.Success {
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
}
if fakeDB.batchCalls != 1 {
t.Fatalf("expected batch path to run once, got %d", fakeDB.batchCalls)
if fakeDB.batchCalls != 0 {
t.Fatalf("expected multiple write statements to skip batch path so each result can be preserved, got %d", fakeDB.batchCalls)
}
if fakeDB.execCalls != 0 {
t.Fatalf("expected sequential exec path to be skipped, got execCalls=%d", fakeDB.execCalls)
if fakeDB.execCalls != 2 {
t.Fatalf("expected sequential exec path to run twice, got execCalls=%d", fakeDB.execCalls)
}
if fakeDB.lastQuery != query {
t.Fatalf("expected batch query to stay intact, got %q", fakeDB.lastQuery)
if len(fakeDB.execQueries) != 2 || fakeDB.execQueries[0] != firstStmt || fakeDB.execQueries[1] != secondStmt {
t.Fatalf("expected sequential execs to preserve statement order, got %#v", fakeDB.execQueries)
}
resultSets, ok := result.Data.([]connection.ResultSetData)
if !ok {
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
}
if len(resultSets) != 1 || len(resultSets[0].Rows) != 1 {
t.Fatalf("expected one affectedRows result set, got %#v", resultSets)
if len(resultSets) != 2 {
t.Fatalf("expected one affectedRows result set per statement, got %#v", resultSets)
}
if got := resultSets[0].Rows[0]["affectedRows"]; got != int64(500) {
t.Fatalf("expected affectedRows=500, got %#v", got)
if len(resultSets[0].Rows) != 1 || len(resultSets[1].Rows) != 1 {
t.Fatalf("expected both result sets to contain a single affectedRows row, got %#v", resultSets)
}
if got := resultSets[0].Rows[0]["affectedRows"]; got != int64(5) {
t.Fatalf("expected first affectedRows=5, got %#v", got)
}
if got := resultSets[1].Rows[0]["affectedRows"]; got != int64(10) {
t.Fatalf("expected second affectedRows=10, got %#v", got)
}
}

View File

@@ -111,6 +111,18 @@ func TestNormalizeMySQLRawDSNCompatibilityParamsPreservesExplicitMultiStatements
}
}
func TestNormalizeMySQLRawDSNCompatibilityParamsPreservesCharsetFallbackComma(t *testing.T) {
got := normalizeMySQLRawDSNCompatibilityParams(
"root:pass@tcp(127.0.0.1:3306)/app?charset=utf8mb4,utf8&allowMultiQueries=true",
)
if strings.Contains(got, "%2C") || strings.Contains(got, "%2c") {
t.Fatalf("charset fallback comma should stay unescaped for mysql driver, got %q", got)
}
if !strings.Contains(got, "charset=utf8mb4,utf8") {
t.Fatalf("charset fallback list should be preserved, got %q", got)
}
}
func TestCustomDBOnlyNormalizesBuiltInMySQLDriverDSN(t *testing.T) {
customMySQLDSNRecordingLastDSN = ""
rawDSN := "root:pass@tcp(127.0.0.1:3306)/app?allowMultiQueries=true"

View File

@@ -3,10 +3,13 @@ package db
import (
"database/sql"
"net/url"
"reflect"
"strings"
"testing"
"GoNavi-Wails/internal/connection"
mysql "github.com/go-sql-driver/mysql"
)
func parseMySQLDSNQueryForTest(t *testing.T, dsn string) url.Values {
@@ -22,6 +25,26 @@ func parseMySQLDSNQueryForTest(t *testing.T, dsn string) url.Values {
return values
}
func parseMySQLDriverCharsetsForTest(t *testing.T, dsn string) []string {
t.Helper()
cfg, err := mysql.ParseDSN(dsn)
if err != nil {
t.Fatalf("mysql ParseDSN failed: %v", err)
}
field := reflect.ValueOf(cfg).Elem().FieldByName("charsets")
if !field.IsValid() {
t.Fatal("mysql.Config missing internal charsets field")
}
charsets := make([]string, field.Len())
for i := 0; i < field.Len(); i++ {
charsets[i] = field.Index(i).String()
}
return charsets
}
func TestMySQLDSN_MergesConnectionParamsWithDefaults(t *testing.T) {
t.Parallel()
@@ -398,6 +421,27 @@ func TestMySQLDSN_AsiaShanghaiLocationAcceptedByDriver(t *testing.T) {
_ = db.Close()
}
func TestMySQLDSN_DefaultCharsetFallbackListRemainsDriverCompatible(t *testing.T) {
t.Parallel()
m := &MySQLDB{}
dsn, err := m.getDSN(connection.ConnectionConfig{
Host: "127.0.0.1",
Port: 3306,
User: "root",
Database: "app",
})
if err != nil {
t.Fatalf("getDSN failed: %v", err)
}
got := parseMySQLDriverCharsetsForTest(t, dsn)
want := []string{"utf8mb4", "utf8"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("driver should parse charset fallback list, got=%v want=%v dsn=%q", got, want, dsn)
}
}
func TestMySQLDSN_URIParamsAndExplicitParamsPrecedence(t *testing.T) {
t.Parallel()

View File

@@ -372,12 +372,46 @@ func buildMySQLCompatibleDSNWithOptions(config connection.ConnectionConfig, prot
mergeMySQLConnectionParams(params, parsed.Query())
}
mergeMySQLConnectionParams(params, mysqlConnectionParamsFromText(config.ConnectionParams))
encodedParams := encodeMySQLDSNQuery(params)
return fmt.Sprintf(
"%s:%s@%s(%s)/%s?%s",
config.User, config.Password, protocol, address, database, params.Encode(),
config.User, config.Password, protocol, address, database, encodedParams,
), nil
}
func encodeMySQLDSNQuery(params url.Values) string {
if len(params) == 0 {
return ""
}
keys := make([]string, 0, len(params))
for key := range params {
keys = append(keys, key)
}
sort.Strings(keys)
var builder strings.Builder
for _, key := range keys {
escapedKey := url.QueryEscape(key)
values := params[key]
for _, value := range values {
if builder.Len() > 0 {
builder.WriteByte('&')
}
builder.WriteString(escapedKey)
builder.WriteByte('=')
escapedValue := url.QueryEscape(value)
if strings.EqualFold(strings.TrimSpace(key), "charset") {
escapedValue = strings.ReplaceAll(escapedValue, "%2C", ",")
escapedValue = strings.ReplaceAll(escapedValue, "%2c", ",")
}
builder.WriteString(escapedValue)
}
}
return builder.String()
}
func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, address, database string) (string, error) {
defaultMultiStatements := true
return buildMySQLCompatibleDSNWithOptions(config, protocol, address, database, mySQLCompatibleDSNOptions{
@@ -475,7 +509,7 @@ func normalizeMySQLRawDSNCompatibilityParams(raw string) string {
if !changed {
return raw
}
encoded := values.Encode()
encoded := encodeMySQLDSNQuery(values)
if encoded == "" {
return prefix + suffix
}