🐛 fix(clickhouse): 获取数据库列表失败时回退当前库

Fixes #308
This commit is contained in:
Syngnat
2026-04-11 21:53:51 +08:00
parent 8297829be6
commit 5d86ee7c76
2 changed files with 137 additions and 13 deletions

View File

@@ -367,22 +367,58 @@ func (c *ClickHouseDB) Exec(query string) (int64, error) {
func (c *ClickHouseDB) GetDatabases() ([]string, error) {
data, _, err := c.Query("SELECT name FROM system.databases ORDER BY name")
if err != nil {
return nil, err
if err == nil {
result := make([]string, 0, len(data))
for _, row := range data {
if val, ok := getClickHouseValueFromRow(row, "name", "database"); ok {
result = append(result, fmt.Sprintf("%v", val))
continue
}
for _, value := range row {
result = append(result, fmt.Sprintf("%v", value))
break
}
}
if len(result) > 0 {
return result, nil
}
}
result := make([]string, 0, len(data))
for _, row := range data {
if val, ok := getClickHouseValueFromRow(row, "name", "database"); ok {
result = append(result, fmt.Sprintf("%v", val))
fallbackData, _, fallbackErr := c.Query("SELECT currentDatabase() AS name")
if fallbackErr != nil {
if err != nil {
return nil, err
}
return nil, fallbackErr
}
result := make([]string, 0, len(fallbackData))
for _, row := range fallbackData {
if val, ok := getClickHouseValueFromRow(row, "name", "database", "currentDatabase"); ok {
name := strings.TrimSpace(fmt.Sprintf("%v", val))
if name != "" {
result = append(result, name)
}
continue
}
for _, value := range row {
result = append(result, fmt.Sprintf("%v", value))
name := strings.TrimSpace(fmt.Sprintf("%v", value))
if name != "" {
result = append(result, name)
}
break
}
}
return result, nil
if len(result) > 0 {
return result, nil
}
if current := strings.TrimSpace(c.database); current != "" {
return []string{current}, nil
}
if err != nil {
return nil, err
}
return nil, fmt.Errorf("未获取到 ClickHouse 数据库列表")
}
func (c *ClickHouseDB) GetTables(dbName string) ([]string, error) {

View File

@@ -20,14 +20,24 @@ var (
registerFakeClickHouseDriverOnce sync.Once
fakeClickHouseStateMu sync.Mutex
fakeClickHouseState = struct {
pingErr error
queryErr error
lastQuery string
pingErr error
queryErr error
queryResults map[string]fakeClickHouseQueryResult
lastQuery string
queries []string
}{
lastQuery: "",
lastQuery: "",
queryResults: map[string]fakeClickHouseQueryResult{},
queries: nil,
}
)
type fakeClickHouseQueryResult struct {
columns []string
rows [][]driver.Value
err error
}
func TestClickHousePingValidatesQueryPath(t *testing.T) {
registerFakeClickHouseDriverOnce.Do(func() {
sql.Register(fakeClickHouseDriverName, fakeClickHouseDriver{})
@@ -42,7 +52,9 @@ func TestClickHousePingValidatesQueryPath(t *testing.T) {
fakeClickHouseStateMu.Lock()
fakeClickHouseState.pingErr = nil
fakeClickHouseState.queryErr = errors.New("query path failed")
fakeClickHouseState.queryResults = map[string]fakeClickHouseQueryResult{}
fakeClickHouseState.lastQuery = ""
fakeClickHouseState.queries = nil
fakeClickHouseStateMu.Unlock()
client := &ClickHouseDB{
@@ -65,6 +77,58 @@ func TestClickHousePingValidatesQueryPath(t *testing.T) {
}
}
func TestClickHouseGetDatabasesFallsBackToCurrentDatabase(t *testing.T) {
registerFakeClickHouseDriverOnce.Do(func() {
sql.Register(fakeClickHouseDriverName, fakeClickHouseDriver{})
})
db, err := sql.Open(fakeClickHouseDriverName, "")
if err != nil {
t.Fatalf("open fake clickhouse db failed: %v", err)
}
defer db.Close()
const listSQL = "SELECT name FROM system.databases ORDER BY name"
const fallbackSQL = "SELECT currentDatabase() AS name"
fakeClickHouseStateMu.Lock()
fakeClickHouseState.pingErr = nil
fakeClickHouseState.queryErr = nil
fakeClickHouseState.queryResults = map[string]fakeClickHouseQueryResult{
listSQL: {
err: errors.New("access denied to system.databases"),
},
fallbackSQL: {
columns: []string{"name"},
rows: [][]driver.Value{
{"analytics"},
},
},
}
fakeClickHouseState.lastQuery = ""
fakeClickHouseState.queries = nil
fakeClickHouseStateMu.Unlock()
client := &ClickHouseDB{conn: db}
databases, err := client.GetDatabases()
if err != nil {
t.Fatalf("expected GetDatabases to fallback, got err=%v", err)
}
if len(databases) != 1 || databases[0] != "analytics" {
t.Fatalf("expected fallback database list, got %v", databases)
}
fakeClickHouseStateMu.Lock()
queries := append([]string(nil), fakeClickHouseState.queries...)
fakeClickHouseStateMu.Unlock()
if len(queries) != 2 {
t.Fatalf("expected two queries, got %v", queries)
}
if queries[0] != listSQL || queries[1] != fallbackSQL {
t.Fatalf("unexpected query order: %v", queries)
}
}
type fakeClickHouseDriver struct{}
func (fakeClickHouseDriver) Open(name string) (driver.Conn, error) {
@@ -95,15 +159,29 @@ func (fakeClickHouseConn) QueryContext(ctx context.Context, query string, args [
fakeClickHouseStateMu.Lock()
defer fakeClickHouseStateMu.Unlock()
fakeClickHouseState.lastQuery = query
fakeClickHouseState.queries = append(fakeClickHouseState.queries, query)
if result, ok := fakeClickHouseState.queryResults[query]; ok {
if result.err != nil {
return nil, result.err
}
return &fakeClickHouseRows{columns: result.columns, rows: result.rows}, nil
}
if fakeClickHouseState.queryErr != nil {
return nil, fakeClickHouseState.queryErr
}
return &fakeClickHouseRows{}, nil
}
type fakeClickHouseRows struct{}
type fakeClickHouseRows struct {
columns []string
rows [][]driver.Value
index int
}
func (r *fakeClickHouseRows) Columns() []string {
if len(r.columns) > 0 {
return r.columns
}
return []string{"currentDatabase"}
}
@@ -112,6 +190,16 @@ func (r *fakeClickHouseRows) Close() error {
}
func (r *fakeClickHouseRows) Next(dest []driver.Value) error {
if r.index < len(r.rows) {
row := r.rows[r.index]
for idx := range dest {
if idx < len(row) {
dest[idx] = row[idx]
}
}
r.index++
return nil
}
if len(dest) > 0 {
dest[0] = "default"
}