diff --git a/internal/db/sqlserver_impl.go b/internal/db/sqlserver_impl.go index b0ec9f5..1cf6684 100644 --- a/internal/db/sqlserver_impl.go +++ b/internal/db/sqlserver_impl.go @@ -33,7 +33,7 @@ type sqlServerSessionExecer struct { func scanSQLServerRowsWithMessages(ctx context.Context, rows *sql.Rows, retmsg *sqlexp.ReturnMessage) ([]connection.ResultSetData, []string, error) { if rows == nil { - return []connection.ResultSetData{{Rows: []map[string]interface{}{}, Columns: []string{}}}, nil, nil + return []connection.ResultSetData{emptySQLServerRowsResultSet()}, nil, nil } if ctx == nil { ctx = context.Background() @@ -95,10 +95,11 @@ func scanSQLServerRowsWithMessages(ctx context.Context, rows *sql.Rows, retmsg * }) } if len(resultSets) == 0 { - resultSets = []connection.ResultSetData{{ - Rows: []map[string]interface{}{}, - Columns: []string{}, - }} + fallbackResult, err := scanSQLServerFallbackResultSet(rows) + if err != nil { + return resultSets, allMessages, err + } + resultSets = []connection.ResultSetData{fallbackResult} } if err := rows.Err(); err != nil { return resultSets, allMessages, err @@ -106,6 +107,30 @@ func scanSQLServerRowsWithMessages(ctx context.Context, rows *sql.Rows, retmsg * return resultSets, allMessages, nil } +func emptySQLServerRowsResultSet() connection.ResultSetData { + return connection.ResultSetData{ + Rows: []map[string]interface{}{}, + Columns: []string{}, + } +} + +func scanSQLServerFallbackResultSet(rows *sql.Rows) (connection.ResultSetData, error) { + data, columns, err := scanRows(rows) + if err != nil { + return emptySQLServerRowsResultSet(), err + } + if data == nil { + data = []map[string]interface{}{} + } + if columns == nil { + columns = []string{} + } + return connection.ResultSetData{ + Rows: data, + Columns: columns, + }, nil +} + // quoteBracket escapes ] in identifiers for safe use in SQL Server [bracket] notation func quoteBracket(name string) string { return strings.ReplaceAll(name, "]", "]]") diff --git a/internal/db/sqlserver_impl_test.go b/internal/db/sqlserver_impl_test.go index 3d62ce4..ac31a63 100644 --- a/internal/db/sqlserver_impl_test.go +++ b/internal/db/sqlserver_impl_test.go @@ -3,12 +3,16 @@ package db import ( + "database/sql" "errors" "os" + "reflect" "strings" "testing" "GoNavi-Wails/shared/i18n" + + _ "modernc.org/sqlite" ) var rawSQLServerTableNameRequiredText = string([]rune{0x8868, 0x540d, 0x4e0d, 0x80fd, 0x4e3a, 0x7a7a}) @@ -72,6 +76,60 @@ func TestSQLServerRowsAffectedDoesNotHideDMLRowsAffectedErrors(t *testing.T) { } } +func TestScanSQLServerFallbackResultSetPreservesRowsWhenMessageLoopYieldsNoResult(t *testing.T) { + dbConn, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + t.Cleanup(func() { + _ = dbConn.Close() + }) + + rows, err := dbConn.Query("SELECT 'config:roomType:add' AS menuName") + if err != nil { + t.Fatalf("query rows: %v", err) + } + defer rows.Close() + + resultSet, err := scanSQLServerFallbackResultSet(rows) + if err != nil { + t.Fatalf("scanSQLServerFallbackResultSet returned error: %v", err) + } + if !reflect.DeepEqual(resultSet.Columns, []string{"menuName"}) { + t.Fatalf("expected SELECT columns to be preserved, got %#v", resultSet.Columns) + } + if len(resultSet.Rows) != 1 || resultSet.Rows[0]["menuName"] != "config:roomType:add" { + t.Fatalf("expected SELECT rows to be preserved, got %#v", resultSet.Rows) + } +} + +func TestScanSQLServerFallbackResultSetPreservesColumnsWhenResultHasNoRows(t *testing.T) { + dbConn, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + t.Cleanup(func() { + _ = dbConn.Close() + }) + + rows, err := dbConn.Query("SELECT 1 AS menuName WHERE 1 = 0") + if err != nil { + t.Fatalf("query empty rows: %v", err) + } + defer rows.Close() + + resultSet, err := scanSQLServerFallbackResultSet(rows) + if err != nil { + t.Fatalf("scanSQLServerFallbackResultSet returned error: %v", err) + } + if len(resultSet.Rows) != 0 { + t.Fatalf("expected empty rows, got %#v", resultSet.Rows) + } + if !reflect.DeepEqual(resultSet.Columns, []string{"menuName"}) { + t.Fatalf("expected empty SELECT columns to be preserved, got %#v", resultSet.Columns) + } +} + func TestSQLServerMetadataErrorsUseCurrentLanguage(t *testing.T) { SetBackendLanguage(i18n.LanguageEnUS) t.Cleanup(func() {