diff --git a/internal/db/kingbase_impl.go b/internal/db/kingbase_impl.go index f8a8c3b..8afb619 100644 --- a/internal/db/kingbase_impl.go +++ b/internal/db/kingbase_impl.go @@ -304,18 +304,70 @@ func (k *KingbaseDB) Exec(query string) (int64, error) { } func (k *KingbaseDB) GetDatabases() ([]string, error) { - // Postgres/Kingbase style data, _, err := k.Query("SELECT datname FROM pg_database WHERE datistemplate = false") + if err == nil { + dbs := collectKingbaseNames(data, "datname", "database") + if len(dbs) > 0 { + return dbs, nil + } + } + + fallbackData, _, fallbackErr := k.Query("SELECT current_database() AS datname") + if fallbackErr != nil { + if err != nil { + return nil, err + } + return nil, fallbackErr + } + + dbs := collectKingbaseNames(fallbackData, "datname", "database", "current_database", "currentDatabase") + if len(dbs) > 0 { + return dbs, nil + } + if err != nil { return nil, err } - var dbs []string - for _, row := range data { - if val, ok := row["datname"]; ok { - dbs = append(dbs, fmt.Sprintf("%v", val)) + return nil, fmt.Errorf("未获取到可见数据库列表") +} + +func collectKingbaseNames(rows []map[string]interface{}, keys ...string) []string { + result := make([]string, 0, len(rows)) + seen := make(map[string]struct{}, len(rows)) + for _, row := range rows { + name := strings.TrimSpace(getKingbaseNameFromRow(row, keys...)) + if name == "" { + continue + } + if _, exists := seen[name]; exists { + continue + } + seen[name] = struct{}{} + result = append(result, name) + } + return result +} + +func getKingbaseNameFromRow(row map[string]interface{}, keys ...string) string { + if len(row) == 0 { + return "" + } + for _, key := range keys { + if value, ok := row[key]; ok { + return fmt.Sprintf("%v", value) } } - return dbs, nil + for existingKey, value := range row { + for _, key := range keys { + if strings.EqualFold(existingKey, key) { + return fmt.Sprintf("%v", value) + } + } + } + for _, value := range row { + return fmt.Sprintf("%v", value) + } + return "" } func (k *KingbaseDB) GetTables(dbName string) ([]string, error) { diff --git a/internal/db/kingbase_impl_test.go b/internal/db/kingbase_impl_test.go index 8b0d6f5..8a171fd 100644 --- a/internal/db/kingbase_impl_test.go +++ b/internal/db/kingbase_impl_test.go @@ -2,7 +2,39 @@ package db -import "testing" +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "io" + "strings" + "sync" + "testing" +) + +const fakeKingbaseDriverName = "gonavi-fake-kingbase" + +var ( + registerFakeKingbaseDriverOnce sync.Once + fakeKingbaseStateMu sync.Mutex + fakeKingbaseState = struct { + queryErr error + queryResults map[string]fakeKingbaseQueryResult + lastQuery string + queries []string + }{ + lastQuery: "", + queryResults: map[string]fakeKingbaseQueryResult{}, + queries: nil, + } +) + +type fakeKingbaseQueryResult struct { + columns []string + rows [][]driver.Value + err error +} func TestNormalizeKingbaseIdentifier(t *testing.T) { tests := []struct { @@ -115,3 +147,125 @@ func TestSplitKingbaseQualifiedTable(t *testing.T) { }) } } + +func TestKingbaseGetDatabasesFallsBackToCurrentDatabase(t *testing.T) { + registerFakeKingbaseDriverOnce.Do(func() { + sql.Register(fakeKingbaseDriverName, fakeKingbaseDriver{}) + }) + + db, err := sql.Open(fakeKingbaseDriverName, "") + if err != nil { + t.Fatalf("open fake kingbase db failed: %v", err) + } + defer db.Close() + + const listSQL = "SELECT datname FROM pg_database WHERE datistemplate = false" + const fallbackSQL = "SELECT current_database() AS datname" + + fakeKingbaseStateMu.Lock() + fakeKingbaseState.queryErr = nil + fakeKingbaseState.queryResults = map[string]fakeKingbaseQueryResult{ + listSQL: { + err: errors.New("permission denied for relation pg_database"), + }, + fallbackSQL: { + columns: []string{"datname"}, + rows: [][]driver.Value{ + {"demo"}, + }, + }, + } + fakeKingbaseState.lastQuery = "" + fakeKingbaseState.queries = nil + fakeKingbaseStateMu.Unlock() + + client := &KingbaseDB{conn: db} + databases, err := client.GetDatabases() + if err != nil { + t.Fatalf("expected GetDatabases to fallback, got err=%v", err) + } + if len(databases) != 1 || databases[0] != "demo" { + t.Fatalf("expected fallback database list, got %v", databases) + } + + fakeKingbaseStateMu.Lock() + queries := append([]string(nil), fakeKingbaseState.queries...) + fakeKingbaseStateMu.Unlock() + if len(queries) != 2 { + t.Fatalf("expected two queries, got %v", queries) + } + if queries[0] != listSQL || queries[1] != fallbackSQL { + t.Fatalf("unexpected query order: %v", queries) + } +} + +type fakeKingbaseDriver struct{} + +func (fakeKingbaseDriver) Open(name string) (driver.Conn, error) { + return fakeKingbaseConn{}, nil +} + +type fakeKingbaseConn struct{} + +func (fakeKingbaseConn) Prepare(query string) (driver.Stmt, error) { + return nil, errors.New("prepare not implemented") +} + +func (fakeKingbaseConn) Close() error { + return nil +} + +func (fakeKingbaseConn) Begin() (driver.Tx, error) { + return nil, errors.New("transactions not implemented") +} + +func (fakeKingbaseConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + fakeKingbaseStateMu.Lock() + defer fakeKingbaseStateMu.Unlock() + fakeKingbaseState.lastQuery = query + fakeKingbaseState.queries = append(fakeKingbaseState.queries, query) + if result, ok := fakeKingbaseState.queryResults[query]; ok { + if result.err != nil { + return nil, result.err + } + return &fakeKingbaseRows{columns: result.columns, rows: result.rows}, nil + } + if fakeKingbaseState.queryErr != nil { + return nil, fakeKingbaseState.queryErr + } + return &fakeKingbaseRows{}, nil +} + +type fakeKingbaseRows struct { + columns []string + rows [][]driver.Value + index int +} + +func (r *fakeKingbaseRows) Columns() []string { + if len(r.columns) > 0 { + return r.columns + } + return []string{"datname"} +} + +func (r *fakeKingbaseRows) Close() error { + return nil +} + +func (r *fakeKingbaseRows) Next(dest []driver.Value) error { + if r.index < len(r.rows) { + row := r.rows[r.index] + for idx := range dest { + if idx < len(row) { + dest[idx] = row[idx] + } + } + r.index++ + return nil + } + if len(dest) > 0 { + dest[0] = strings.TrimSpace("demo") + } + return io.EOF +}