From 5d86ee7c76abfaba55d497d860668cfcfa5e70da Mon Sep 17 00:00:00 2001 From: Syngnat Date: Sat, 11 Apr 2026 21:53:51 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(clickhouse):=20=E8=8E=B7?= =?UTF-8?q?=E5=8F=96=E6=95=B0=E6=8D=AE=E5=BA=93=E5=88=97=E8=A1=A8=E5=A4=B1?= =?UTF-8?q?=E8=B4=A5=E6=97=B6=E5=9B=9E=E9=80=80=E5=BD=93=E5=89=8D=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #308 --- internal/db/clickhouse_impl.go | 52 ++++++++++++--- internal/db/clickhouse_impl_test.go | 98 +++++++++++++++++++++++++++-- 2 files changed, 137 insertions(+), 13 deletions(-) diff --git a/internal/db/clickhouse_impl.go b/internal/db/clickhouse_impl.go index c6be5d0..aad3011 100644 --- a/internal/db/clickhouse_impl.go +++ b/internal/db/clickhouse_impl.go @@ -367,22 +367,58 @@ func (c *ClickHouseDB) Exec(query string) (int64, error) { func (c *ClickHouseDB) GetDatabases() ([]string, error) { data, _, err := c.Query("SELECT name FROM system.databases ORDER BY name") - if err != nil { - return nil, err + if err == nil { + result := make([]string, 0, len(data)) + for _, row := range data { + if val, ok := getClickHouseValueFromRow(row, "name", "database"); ok { + result = append(result, fmt.Sprintf("%v", val)) + continue + } + for _, value := range row { + result = append(result, fmt.Sprintf("%v", value)) + break + } + } + if len(result) > 0 { + return result, nil + } } - result := make([]string, 0, len(data)) - for _, row := range data { - if val, ok := getClickHouseValueFromRow(row, "name", "database"); ok { - result = append(result, fmt.Sprintf("%v", val)) + fallbackData, _, fallbackErr := c.Query("SELECT currentDatabase() AS name") + if fallbackErr != nil { + if err != nil { + return nil, err + } + return nil, fallbackErr + } + + result := make([]string, 0, len(fallbackData)) + for _, row := range fallbackData { + if val, ok := getClickHouseValueFromRow(row, "name", "database", "currentDatabase"); ok { + name := strings.TrimSpace(fmt.Sprintf("%v", val)) + if name != "" { + result = append(result, name) + } continue } for _, value := range row { - result = append(result, fmt.Sprintf("%v", value)) + name := strings.TrimSpace(fmt.Sprintf("%v", value)) + if name != "" { + result = append(result, name) + } break } } - return result, nil + if len(result) > 0 { + return result, nil + } + if current := strings.TrimSpace(c.database); current != "" { + return []string{current}, nil + } + if err != nil { + return nil, err + } + return nil, fmt.Errorf("未获取到 ClickHouse 数据库列表") } func (c *ClickHouseDB) GetTables(dbName string) ([]string, error) { diff --git a/internal/db/clickhouse_impl_test.go b/internal/db/clickhouse_impl_test.go index a2fd40a..cd7c7b5 100644 --- a/internal/db/clickhouse_impl_test.go +++ b/internal/db/clickhouse_impl_test.go @@ -20,14 +20,24 @@ var ( registerFakeClickHouseDriverOnce sync.Once fakeClickHouseStateMu sync.Mutex fakeClickHouseState = struct { - pingErr error - queryErr error - lastQuery string + pingErr error + queryErr error + queryResults map[string]fakeClickHouseQueryResult + lastQuery string + queries []string }{ - lastQuery: "", + lastQuery: "", + queryResults: map[string]fakeClickHouseQueryResult{}, + queries: nil, } ) +type fakeClickHouseQueryResult struct { + columns []string + rows [][]driver.Value + err error +} + func TestClickHousePingValidatesQueryPath(t *testing.T) { registerFakeClickHouseDriverOnce.Do(func() { sql.Register(fakeClickHouseDriverName, fakeClickHouseDriver{}) @@ -42,7 +52,9 @@ func TestClickHousePingValidatesQueryPath(t *testing.T) { fakeClickHouseStateMu.Lock() fakeClickHouseState.pingErr = nil fakeClickHouseState.queryErr = errors.New("query path failed") + fakeClickHouseState.queryResults = map[string]fakeClickHouseQueryResult{} fakeClickHouseState.lastQuery = "" + fakeClickHouseState.queries = nil fakeClickHouseStateMu.Unlock() client := &ClickHouseDB{ @@ -65,6 +77,58 @@ func TestClickHousePingValidatesQueryPath(t *testing.T) { } } +func TestClickHouseGetDatabasesFallsBackToCurrentDatabase(t *testing.T) { + registerFakeClickHouseDriverOnce.Do(func() { + sql.Register(fakeClickHouseDriverName, fakeClickHouseDriver{}) + }) + + db, err := sql.Open(fakeClickHouseDriverName, "") + if err != nil { + t.Fatalf("open fake clickhouse db failed: %v", err) + } + defer db.Close() + + const listSQL = "SELECT name FROM system.databases ORDER BY name" + const fallbackSQL = "SELECT currentDatabase() AS name" + + fakeClickHouseStateMu.Lock() + fakeClickHouseState.pingErr = nil + fakeClickHouseState.queryErr = nil + fakeClickHouseState.queryResults = map[string]fakeClickHouseQueryResult{ + listSQL: { + err: errors.New("access denied to system.databases"), + }, + fallbackSQL: { + columns: []string{"name"}, + rows: [][]driver.Value{ + {"analytics"}, + }, + }, + } + fakeClickHouseState.lastQuery = "" + fakeClickHouseState.queries = nil + fakeClickHouseStateMu.Unlock() + + client := &ClickHouseDB{conn: db} + databases, err := client.GetDatabases() + if err != nil { + t.Fatalf("expected GetDatabases to fallback, got err=%v", err) + } + if len(databases) != 1 || databases[0] != "analytics" { + t.Fatalf("expected fallback database list, got %v", databases) + } + + fakeClickHouseStateMu.Lock() + queries := append([]string(nil), fakeClickHouseState.queries...) + fakeClickHouseStateMu.Unlock() + if len(queries) != 2 { + t.Fatalf("expected two queries, got %v", queries) + } + if queries[0] != listSQL || queries[1] != fallbackSQL { + t.Fatalf("unexpected query order: %v", queries) + } +} + type fakeClickHouseDriver struct{} func (fakeClickHouseDriver) Open(name string) (driver.Conn, error) { @@ -95,15 +159,29 @@ func (fakeClickHouseConn) QueryContext(ctx context.Context, query string, args [ fakeClickHouseStateMu.Lock() defer fakeClickHouseStateMu.Unlock() fakeClickHouseState.lastQuery = query + fakeClickHouseState.queries = append(fakeClickHouseState.queries, query) + if result, ok := fakeClickHouseState.queryResults[query]; ok { + if result.err != nil { + return nil, result.err + } + return &fakeClickHouseRows{columns: result.columns, rows: result.rows}, nil + } if fakeClickHouseState.queryErr != nil { return nil, fakeClickHouseState.queryErr } return &fakeClickHouseRows{}, nil } -type fakeClickHouseRows struct{} +type fakeClickHouseRows struct { + columns []string + rows [][]driver.Value + index int +} func (r *fakeClickHouseRows) Columns() []string { + if len(r.columns) > 0 { + return r.columns + } return []string{"currentDatabase"} } @@ -112,6 +190,16 @@ func (r *fakeClickHouseRows) Close() error { } func (r *fakeClickHouseRows) Next(dest []driver.Value) error { + if r.index < len(r.rows) { + row := r.rows[r.index] + for idx := range dest { + if idx < len(row) { + dest[idx] = row[idx] + } + } + r.index++ + return nil + } if len(dest) > 0 { dest[0] = "default" }