mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-14 10:29:52 +08:00
🐛 fix(db): 保留多写语句结果并修复 MySQL 字符集参数
- 多条写语句改为逐条返回 affectedRows,避免只显示最后一条结果 - 为写语句结果补充 statementIndex,保持语句与结果映射 - 保留 MySQL charset fallback 逗号,避免驱动解析成 %2C
This commit is contained in:
@@ -827,8 +827,9 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
|
||||
}
|
||||
defer closeExecTarget()
|
||||
|
||||
// 全部为写操作且驱动支持批量 Exec → 一次性发送,大幅减少网络往返
|
||||
// 适用于 MySQL/MariaDB/Doris/PostgreSQL/SQLite/DuckDB 等支持多语句 Exec 的驱动
|
||||
// 单条写语句且驱动支持批量 Exec 时,可复用批量路径。
|
||||
// 多条写语句必须逐条返回结果;部分驱动对多语句 Exec 仅暴露最后一条 RowsAffected,
|
||||
// 会导致前面语句已成功执行但结果页只剩一个写入结果。
|
||||
if !allReadOnly {
|
||||
allWrite := true
|
||||
containsPLSQLBlock := false
|
||||
@@ -840,7 +841,7 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
|
||||
containsPLSQLBlock = true
|
||||
}
|
||||
}
|
||||
if allWrite && !containsPLSQLBlock {
|
||||
if allWrite && !containsPLSQLBlock && len(statements) == 1 {
|
||||
batcher := sessionBatchTarget
|
||||
if batcher == nil {
|
||||
if fallbackBatcher, ok := dbInst.(db.BatchWriteExecer); ok {
|
||||
@@ -987,8 +988,9 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
|
||||
return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID}
|
||||
}
|
||||
resultSets = append(resultSets, connection.ResultSetData{
|
||||
Rows: []map[string]interface{}{{"affectedRows": affected}},
|
||||
Columns: []string{"affectedRows"},
|
||||
Rows: []map[string]interface{}{{"affectedRows": affected}},
|
||||
Columns: []string{"affectedRows"},
|
||||
StatementIndex: idx + 1,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -10,18 +10,19 @@ import (
|
||||
)
|
||||
|
||||
type fakeBatchWriteDB struct {
|
||||
batchCalls int
|
||||
execCalls int
|
||||
execQueries []string
|
||||
lastQuery string
|
||||
lastCtx context.Context
|
||||
queryCalls int
|
||||
queryMap map[string][]map[string]interface{}
|
||||
fieldMap map[string][]string
|
||||
messageMap map[string][]string
|
||||
multiResult map[string][]connection.ResultSetData
|
||||
queryErr map[string]error
|
||||
session *fakeBatchWriteSession
|
||||
batchCalls int
|
||||
execCalls int
|
||||
execQueries []string
|
||||
lastQuery string
|
||||
lastCtx context.Context
|
||||
queryCalls int
|
||||
queryMap map[string][]map[string]interface{}
|
||||
fieldMap map[string][]string
|
||||
messageMap map[string][]string
|
||||
multiResult map[string][]connection.ResultSetData
|
||||
queryErr map[string]error
|
||||
execAffected map[string]int64
|
||||
session *fakeBatchWriteSession
|
||||
}
|
||||
|
||||
func (f *fakeBatchWriteDB) Connect(config connection.ConnectionConfig) error {
|
||||
@@ -52,6 +53,9 @@ func (f *fakeBatchWriteDB) QueryWithMessages(query string) ([]map[string]interfa
|
||||
func (f *fakeBatchWriteDB) Exec(query string) (int64, error) {
|
||||
f.execCalls++
|
||||
f.execQueries = append(f.execQueries, query)
|
||||
if affected, ok := f.execAffected[query]; ok {
|
||||
return affected, nil
|
||||
}
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
@@ -91,6 +95,9 @@ func (f *fakeBatchWriteDB) ExecContext(ctx context.Context, query string) (int64
|
||||
f.lastCtx = ctx
|
||||
f.execCalls++
|
||||
f.execQueries = append(f.execQueries, query)
|
||||
if affected, ok := f.execAffected[query]; ok {
|
||||
return affected, nil
|
||||
}
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
@@ -440,13 +447,20 @@ func TestDBQueryWithCancel_DuckDBQueriesDoNotInheritConnectTimeout(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiUsesBatchWriteExecerForAllWriteStatements(t *testing.T) {
|
||||
func TestDBQueryMultiPreservesPerStatementResultsForMultipleWriteStatements(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
fakeDB := &fakeBatchWriteDB{}
|
||||
firstStmt := "DELETE FROM assets_asset"
|
||||
secondStmt := "DELETE FROM assets_assetcategory"
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
execAffected: map[string]int64{
|
||||
firstStmt: 5,
|
||||
secondStmt: 10,
|
||||
},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
@@ -458,31 +472,37 @@ func TestDBQueryMultiUsesBatchWriteExecerForAllWriteStatements(t *testing.T) {
|
||||
Port: 1433,
|
||||
User: "sa",
|
||||
}
|
||||
query := "INSERT INTO demo(id) VALUES (1);\nINSERT INTO demo(id) VALUES (2);"
|
||||
query := firstStmt + ";\n" + secondStmt + ";"
|
||||
|
||||
result := app.DBQueryMulti(config, "testdb", query, "batch-write-test")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
|
||||
}
|
||||
if fakeDB.batchCalls != 1 {
|
||||
t.Fatalf("expected batch path to run once, got %d", fakeDB.batchCalls)
|
||||
if fakeDB.batchCalls != 0 {
|
||||
t.Fatalf("expected multiple write statements to skip batch path so each result can be preserved, got %d", fakeDB.batchCalls)
|
||||
}
|
||||
if fakeDB.execCalls != 0 {
|
||||
t.Fatalf("expected sequential exec path to be skipped, got execCalls=%d", fakeDB.execCalls)
|
||||
if fakeDB.execCalls != 2 {
|
||||
t.Fatalf("expected sequential exec path to run twice, got execCalls=%d", fakeDB.execCalls)
|
||||
}
|
||||
if fakeDB.lastQuery != query {
|
||||
t.Fatalf("expected batch query to stay intact, got %q", fakeDB.lastQuery)
|
||||
if len(fakeDB.execQueries) != 2 || fakeDB.execQueries[0] != firstStmt || fakeDB.execQueries[1] != secondStmt {
|
||||
t.Fatalf("expected sequential execs to preserve statement order, got %#v", fakeDB.execQueries)
|
||||
}
|
||||
|
||||
resultSets, ok := result.Data.([]connection.ResultSetData)
|
||||
if !ok {
|
||||
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
|
||||
}
|
||||
if len(resultSets) != 1 || len(resultSets[0].Rows) != 1 {
|
||||
t.Fatalf("expected one affectedRows result set, got %#v", resultSets)
|
||||
if len(resultSets) != 2 {
|
||||
t.Fatalf("expected one affectedRows result set per statement, got %#v", resultSets)
|
||||
}
|
||||
if got := resultSets[0].Rows[0]["affectedRows"]; got != int64(500) {
|
||||
t.Fatalf("expected affectedRows=500, got %#v", got)
|
||||
if len(resultSets[0].Rows) != 1 || len(resultSets[1].Rows) != 1 {
|
||||
t.Fatalf("expected both result sets to contain a single affectedRows row, got %#v", resultSets)
|
||||
}
|
||||
if got := resultSets[0].Rows[0]["affectedRows"]; got != int64(5) {
|
||||
t.Fatalf("expected first affectedRows=5, got %#v", got)
|
||||
}
|
||||
if got := resultSets[1].Rows[0]["affectedRows"]; got != int64(10) {
|
||||
t.Fatalf("expected second affectedRows=10, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -111,6 +111,18 @@ func TestNormalizeMySQLRawDSNCompatibilityParamsPreservesExplicitMultiStatements
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeMySQLRawDSNCompatibilityParamsPreservesCharsetFallbackComma(t *testing.T) {
|
||||
got := normalizeMySQLRawDSNCompatibilityParams(
|
||||
"root:pass@tcp(127.0.0.1:3306)/app?charset=utf8mb4,utf8&allowMultiQueries=true",
|
||||
)
|
||||
if strings.Contains(got, "%2C") || strings.Contains(got, "%2c") {
|
||||
t.Fatalf("charset fallback comma should stay unescaped for mysql driver, got %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "charset=utf8mb4,utf8") {
|
||||
t.Fatalf("charset fallback list should be preserved, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomDBOnlyNormalizesBuiltInMySQLDriverDSN(t *testing.T) {
|
||||
customMySQLDSNRecordingLastDSN = ""
|
||||
rawDSN := "root:pass@tcp(127.0.0.1:3306)/app?allowMultiQueries=true"
|
||||
|
||||
@@ -3,10 +3,13 @@ package db
|
||||
import (
|
||||
"database/sql"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
|
||||
mysql "github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
func parseMySQLDSNQueryForTest(t *testing.T, dsn string) url.Values {
|
||||
@@ -22,6 +25,26 @@ func parseMySQLDSNQueryForTest(t *testing.T, dsn string) url.Values {
|
||||
return values
|
||||
}
|
||||
|
||||
func parseMySQLDriverCharsetsForTest(t *testing.T, dsn string) []string {
|
||||
t.Helper()
|
||||
|
||||
cfg, err := mysql.ParseDSN(dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("mysql ParseDSN failed: %v", err)
|
||||
}
|
||||
|
||||
field := reflect.ValueOf(cfg).Elem().FieldByName("charsets")
|
||||
if !field.IsValid() {
|
||||
t.Fatal("mysql.Config missing internal charsets field")
|
||||
}
|
||||
|
||||
charsets := make([]string, field.Len())
|
||||
for i := 0; i < field.Len(); i++ {
|
||||
charsets[i] = field.Index(i).String()
|
||||
}
|
||||
return charsets
|
||||
}
|
||||
|
||||
func TestMySQLDSN_MergesConnectionParamsWithDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -398,6 +421,27 @@ func TestMySQLDSN_AsiaShanghaiLocationAcceptedByDriver(t *testing.T) {
|
||||
_ = db.Close()
|
||||
}
|
||||
|
||||
func TestMySQLDSN_DefaultCharsetFallbackListRemainsDriverCompatible(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := &MySQLDB{}
|
||||
dsn, err := m.getDSN(connection.ConnectionConfig{
|
||||
Host: "127.0.0.1",
|
||||
Port: 3306,
|
||||
User: "root",
|
||||
Database: "app",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("getDSN failed: %v", err)
|
||||
}
|
||||
|
||||
got := parseMySQLDriverCharsetsForTest(t, dsn)
|
||||
want := []string{"utf8mb4", "utf8"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("driver should parse charset fallback list, got=%v want=%v dsn=%q", got, want, dsn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQLDSN_URIParamsAndExplicitParamsPrecedence(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -372,12 +372,46 @@ func buildMySQLCompatibleDSNWithOptions(config connection.ConnectionConfig, prot
|
||||
mergeMySQLConnectionParams(params, parsed.Query())
|
||||
}
|
||||
mergeMySQLConnectionParams(params, mysqlConnectionParamsFromText(config.ConnectionParams))
|
||||
encodedParams := encodeMySQLDSNQuery(params)
|
||||
return fmt.Sprintf(
|
||||
"%s:%s@%s(%s)/%s?%s",
|
||||
config.User, config.Password, protocol, address, database, params.Encode(),
|
||||
config.User, config.Password, protocol, address, database, encodedParams,
|
||||
), nil
|
||||
}
|
||||
|
||||
func encodeMySQLDSNQuery(params url.Values) string {
|
||||
if len(params) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(params))
|
||||
for key := range params {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
var builder strings.Builder
|
||||
for _, key := range keys {
|
||||
escapedKey := url.QueryEscape(key)
|
||||
values := params[key]
|
||||
for _, value := range values {
|
||||
if builder.Len() > 0 {
|
||||
builder.WriteByte('&')
|
||||
}
|
||||
builder.WriteString(escapedKey)
|
||||
builder.WriteByte('=')
|
||||
escapedValue := url.QueryEscape(value)
|
||||
if strings.EqualFold(strings.TrimSpace(key), "charset") {
|
||||
escapedValue = strings.ReplaceAll(escapedValue, "%2C", ",")
|
||||
escapedValue = strings.ReplaceAll(escapedValue, "%2c", ",")
|
||||
}
|
||||
builder.WriteString(escapedValue)
|
||||
}
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, address, database string) (string, error) {
|
||||
defaultMultiStatements := true
|
||||
return buildMySQLCompatibleDSNWithOptions(config, protocol, address, database, mySQLCompatibleDSNOptions{
|
||||
@@ -475,7 +509,7 @@ func normalizeMySQLRawDSNCompatibilityParams(raw string) string {
|
||||
if !changed {
|
||||
return raw
|
||||
}
|
||||
encoded := values.Encode()
|
||||
encoded := encodeMySQLDSNQuery(values)
|
||||
if encoded == "" {
|
||||
return prefix + suffix
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user