From 83fe3d4ed98cc28a7abe3d5f150078d264447e60 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Sat, 11 Apr 2026 21:53:51 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(driver):=20=E6=8F=90?= =?UTF-8?q?=E5=8D=87=E6=89=B9=E9=87=8F=20INSERT=20=E6=89=A7=E8=A1=8C?= =?UTF-8?q?=E6=95=88=E7=8E=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #311 --- internal/app/methods_db_multi_test.go | 128 ++++++++++++++++++++++ internal/db/batch_write_interface_test.go | 35 ++++++ internal/db/highgo_impl.go | 11 ++ internal/db/kingbase_impl.go | 11 ++ internal/db/sqlserver_impl.go | 11 ++ internal/db/vastbase_impl.go | 11 ++ 6 files changed, 207 insertions(+) create mode 100644 internal/app/methods_db_multi_test.go create mode 100644 internal/db/batch_write_interface_test.go diff --git a/internal/app/methods_db_multi_test.go b/internal/app/methods_db_multi_test.go new file mode 100644 index 0000000..1f0af6d --- /dev/null +++ b/internal/app/methods_db_multi_test.go @@ -0,0 +1,128 @@ +package app + +import ( + "context" + "testing" + + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/db" + "GoNavi-Wails/internal/secretstore" +) + +type fakeBatchWriteDB struct { + batchCalls int + execCalls int + lastQuery string +} + +func (f *fakeBatchWriteDB) Connect(config connection.ConnectionConfig) error { + return nil +} + +func (f *fakeBatchWriteDB) Close() error { + return nil +} + +func (f *fakeBatchWriteDB) Ping() error { + return nil +} + +func (f *fakeBatchWriteDB) Query(query string) ([]map[string]interface{}, []string, error) { + return nil, nil, nil +} + +func (f *fakeBatchWriteDB) Exec(query string) (int64, error) { + f.execCalls++ + return 1, nil +} + +func (f *fakeBatchWriteDB) GetDatabases() ([]string, error) { + return nil, nil +} + +func (f *fakeBatchWriteDB) GetTables(dbName string) ([]string, error) { + return nil, nil +} + +func (f *fakeBatchWriteDB) GetCreateStatement(dbName, tableName string) (string, error) { + return "", nil +} + +func (f *fakeBatchWriteDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { + return nil, nil +} + +func (f *fakeBatchWriteDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { + return nil, nil +} + +func (f *fakeBatchWriteDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) { + return nil, nil +} + +func (f *fakeBatchWriteDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) { + return nil, nil +} + +func (f *fakeBatchWriteDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) { + return nil, nil +} + +func (f *fakeBatchWriteDB) ExecContext(ctx context.Context, query string) (int64, error) { + f.execCalls++ + return 1, nil +} + +func (f *fakeBatchWriteDB) ExecBatchContext(ctx context.Context, query string) (int64, error) { + f.batchCalls++ + f.lastQuery = query + return 500, nil +} + +var _ db.BatchWriteExecer = (*fakeBatchWriteDB)(nil) + +func TestDBQueryMultiUsesBatchWriteExecerForAllWriteStatements(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + fakeDB := &fakeBatchWriteDB{} + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{ + Type: "mysql", + Host: "127.0.0.1", + Port: 1433, + User: "sa", + } + query := "INSERT INTO demo(id) VALUES (1);\nINSERT INTO demo(id) VALUES (2);" + + result := app.DBQueryMulti(config, "testdb", query, "batch-write-test") + if !result.Success { + t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message) + } + if fakeDB.batchCalls != 1 { + t.Fatalf("expected batch path to run once, got %d", fakeDB.batchCalls) + } + if fakeDB.execCalls != 0 { + t.Fatalf("expected sequential exec path to be skipped, got execCalls=%d", fakeDB.execCalls) + } + if fakeDB.lastQuery != query { + t.Fatalf("expected batch query to stay intact, got %q", fakeDB.lastQuery) + } + + resultSets, ok := result.Data.([]connection.ResultSetData) + if !ok { + t.Fatalf("expected []connection.ResultSetData, got %T", result.Data) + } + if len(resultSets) != 1 || len(resultSets[0].Rows) != 1 { + t.Fatalf("expected one affectedRows result set, got %#v", resultSets) + } + if got := resultSets[0].Rows[0]["affectedRows"]; got != int64(500) { + t.Fatalf("expected affectedRows=500, got %#v", got) + } +} diff --git a/internal/db/batch_write_interface_test.go b/internal/db/batch_write_interface_test.go new file mode 100644 index 0000000..1ec3ccf --- /dev/null +++ b/internal/db/batch_write_interface_test.go @@ -0,0 +1,35 @@ +//go:build gonavi_full_drivers || gonavi_sqlserver_driver || gonavi_kingbase_driver || gonavi_highgo_driver || gonavi_vastbase_driver + +package db + +import "testing" + +func TestBatchWriteDriverCoverage(t *testing.T) { + t.Run("sqlserver", func(t *testing.T) { + var driver BatchWriteExecer = (*SqlServerDB)(nil) + if driver == nil { + t.Fatal("expected SqlServerDB to implement BatchWriteExecer") + } + }) + + t.Run("kingbase", func(t *testing.T) { + var driver BatchWriteExecer = (*KingbaseDB)(nil) + if driver == nil { + t.Fatal("expected KingbaseDB to implement BatchWriteExecer") + } + }) + + t.Run("highgo", func(t *testing.T) { + var driver BatchWriteExecer = (*HighGoDB)(nil) + if driver == nil { + t.Fatal("expected HighGoDB to implement BatchWriteExecer") + } + }) + + t.Run("vastbase", func(t *testing.T) { + var driver BatchWriteExecer = (*VastbaseDB)(nil) + if driver == nil { + t.Fatal("expected VastbaseDB to implement BatchWriteExecer") + } + }) +} diff --git a/internal/db/highgo_impl.go b/internal/db/highgo_impl.go index a07027f..87c9f4b 100644 --- a/internal/db/highgo_impl.go +++ b/internal/db/highgo_impl.go @@ -174,6 +174,17 @@ func (h *HighGoDB) ExecContext(ctx context.Context, query string) (int64, error) return res.RowsAffected() } +func (h *HighGoDB) ExecBatchContext(ctx context.Context, query string) (int64, error) { + if h.conn == nil { + return 0, fmt.Errorf("连接未打开") + } + res, err := h.conn.ExecContext(ctx, query) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + func (h *HighGoDB) Exec(query string) (int64, error) { if h.conn == nil { return 0, fmt.Errorf("连接未打开") diff --git a/internal/db/kingbase_impl.go b/internal/db/kingbase_impl.go index 9a913d1..f8a8c3b 100644 --- a/internal/db/kingbase_impl.go +++ b/internal/db/kingbase_impl.go @@ -281,6 +281,17 @@ func (k *KingbaseDB) ExecContext(ctx context.Context, query string) (int64, erro return res.RowsAffected() } +func (k *KingbaseDB) ExecBatchContext(ctx context.Context, query string) (int64, error) { + if k.conn == nil { + return 0, fmt.Errorf("连接未打开") + } + res, err := k.conn.ExecContext(ctx, query) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + func (k *KingbaseDB) Exec(query string) (int64, error) { if k.conn == nil { return 0, fmt.Errorf("连接未打开") diff --git a/internal/db/sqlserver_impl.go b/internal/db/sqlserver_impl.go index 6b6996e..ed06fd5 100644 --- a/internal/db/sqlserver_impl.go +++ b/internal/db/sqlserver_impl.go @@ -190,6 +190,17 @@ func (s *SqlServerDB) ExecContext(ctx context.Context, query string) (int64, err return res.RowsAffected() } +func (s *SqlServerDB) ExecBatchContext(ctx context.Context, query string) (int64, error) { + if s.conn == nil { + return 0, fmt.Errorf("连接未打开") + } + res, err := s.conn.ExecContext(ctx, query) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + func (s *SqlServerDB) Exec(query string) (int64, error) { if s.conn == nil { return 0, fmt.Errorf("连接未打开") diff --git a/internal/db/vastbase_impl.go b/internal/db/vastbase_impl.go index 8c6a4ed..bb23713 100644 --- a/internal/db/vastbase_impl.go +++ b/internal/db/vastbase_impl.go @@ -173,6 +173,17 @@ func (v *VastbaseDB) ExecContext(ctx context.Context, query string) (int64, erro return res.RowsAffected() } +func (v *VastbaseDB) ExecBatchContext(ctx context.Context, query string) (int64, error) { + if v.conn == nil { + return 0, fmt.Errorf("连接未打开") + } + res, err := v.conn.ExecContext(ctx, query) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + func (v *VastbaseDB) Exec(query string) (int64, error) { if v.conn == nil { return 0, fmt.Errorf("连接未打开")