🐛 fix(kingbase): 回退当前数据库元数据查询

Fixes #316
This commit is contained in:
Syngnat
2026-04-11 21:53:52 +08:00
parent 5038ae5c9b
commit aa1bb5b886
2 changed files with 213 additions and 7 deletions

View File

@@ -304,18 +304,70 @@ func (k *KingbaseDB) Exec(query string) (int64, error) {
}
func (k *KingbaseDB) GetDatabases() ([]string, error) {
// Postgres/Kingbase style
data, _, err := k.Query("SELECT datname FROM pg_database WHERE datistemplate = false")
if err == nil {
dbs := collectKingbaseNames(data, "datname", "database")
if len(dbs) > 0 {
return dbs, nil
}
}
fallbackData, _, fallbackErr := k.Query("SELECT current_database() AS datname")
if fallbackErr != nil {
if err != nil {
return nil, err
}
return nil, fallbackErr
}
dbs := collectKingbaseNames(fallbackData, "datname", "database", "current_database", "currentDatabase")
if len(dbs) > 0 {
return dbs, nil
}
if err != nil {
return nil, err
}
var dbs []string
for _, row := range data {
if val, ok := row["datname"]; ok {
dbs = append(dbs, fmt.Sprintf("%v", val))
return nil, fmt.Errorf("未获取到可见数据库列表")
}
func collectKingbaseNames(rows []map[string]interface{}, keys ...string) []string {
result := make([]string, 0, len(rows))
seen := make(map[string]struct{}, len(rows))
for _, row := range rows {
name := strings.TrimSpace(getKingbaseNameFromRow(row, keys...))
if name == "" {
continue
}
if _, exists := seen[name]; exists {
continue
}
seen[name] = struct{}{}
result = append(result, name)
}
return result
}
func getKingbaseNameFromRow(row map[string]interface{}, keys ...string) string {
if len(row) == 0 {
return ""
}
for _, key := range keys {
if value, ok := row[key]; ok {
return fmt.Sprintf("%v", value)
}
}
return dbs, nil
for existingKey, value := range row {
for _, key := range keys {
if strings.EqualFold(existingKey, key) {
return fmt.Sprintf("%v", value)
}
}
}
for _, value := range row {
return fmt.Sprintf("%v", value)
}
return ""
}
func (k *KingbaseDB) GetTables(dbName string) ([]string, error) {

View File

@@ -2,7 +2,39 @@
package db
import "testing"
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"io"
"strings"
"sync"
"testing"
)
const fakeKingbaseDriverName = "gonavi-fake-kingbase"
var (
registerFakeKingbaseDriverOnce sync.Once
fakeKingbaseStateMu sync.Mutex
fakeKingbaseState = struct {
queryErr error
queryResults map[string]fakeKingbaseQueryResult
lastQuery string
queries []string
}{
lastQuery: "",
queryResults: map[string]fakeKingbaseQueryResult{},
queries: nil,
}
)
type fakeKingbaseQueryResult struct {
columns []string
rows [][]driver.Value
err error
}
func TestNormalizeKingbaseIdentifier(t *testing.T) {
tests := []struct {
@@ -115,3 +147,125 @@ func TestSplitKingbaseQualifiedTable(t *testing.T) {
})
}
}
func TestKingbaseGetDatabasesFallsBackToCurrentDatabase(t *testing.T) {
registerFakeKingbaseDriverOnce.Do(func() {
sql.Register(fakeKingbaseDriverName, fakeKingbaseDriver{})
})
db, err := sql.Open(fakeKingbaseDriverName, "")
if err != nil {
t.Fatalf("open fake kingbase db failed: %v", err)
}
defer db.Close()
const listSQL = "SELECT datname FROM pg_database WHERE datistemplate = false"
const fallbackSQL = "SELECT current_database() AS datname"
fakeKingbaseStateMu.Lock()
fakeKingbaseState.queryErr = nil
fakeKingbaseState.queryResults = map[string]fakeKingbaseQueryResult{
listSQL: {
err: errors.New("permission denied for relation pg_database"),
},
fallbackSQL: {
columns: []string{"datname"},
rows: [][]driver.Value{
{"demo"},
},
},
}
fakeKingbaseState.lastQuery = ""
fakeKingbaseState.queries = nil
fakeKingbaseStateMu.Unlock()
client := &KingbaseDB{conn: db}
databases, err := client.GetDatabases()
if err != nil {
t.Fatalf("expected GetDatabases to fallback, got err=%v", err)
}
if len(databases) != 1 || databases[0] != "demo" {
t.Fatalf("expected fallback database list, got %v", databases)
}
fakeKingbaseStateMu.Lock()
queries := append([]string(nil), fakeKingbaseState.queries...)
fakeKingbaseStateMu.Unlock()
if len(queries) != 2 {
t.Fatalf("expected two queries, got %v", queries)
}
if queries[0] != listSQL || queries[1] != fallbackSQL {
t.Fatalf("unexpected query order: %v", queries)
}
}
type fakeKingbaseDriver struct{}
func (fakeKingbaseDriver) Open(name string) (driver.Conn, error) {
return fakeKingbaseConn{}, nil
}
type fakeKingbaseConn struct{}
func (fakeKingbaseConn) Prepare(query string) (driver.Stmt, error) {
return nil, errors.New("prepare not implemented")
}
func (fakeKingbaseConn) Close() error {
return nil
}
func (fakeKingbaseConn) Begin() (driver.Tx, error) {
return nil, errors.New("transactions not implemented")
}
func (fakeKingbaseConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
fakeKingbaseStateMu.Lock()
defer fakeKingbaseStateMu.Unlock()
fakeKingbaseState.lastQuery = query
fakeKingbaseState.queries = append(fakeKingbaseState.queries, query)
if result, ok := fakeKingbaseState.queryResults[query]; ok {
if result.err != nil {
return nil, result.err
}
return &fakeKingbaseRows{columns: result.columns, rows: result.rows}, nil
}
if fakeKingbaseState.queryErr != nil {
return nil, fakeKingbaseState.queryErr
}
return &fakeKingbaseRows{}, nil
}
type fakeKingbaseRows struct {
columns []string
rows [][]driver.Value
index int
}
func (r *fakeKingbaseRows) Columns() []string {
if len(r.columns) > 0 {
return r.columns
}
return []string{"datname"}
}
func (r *fakeKingbaseRows) Close() error {
return nil
}
func (r *fakeKingbaseRows) Next(dest []driver.Value) error {
if r.index < len(r.rows) {
row := r.rows[r.index]
for idx := range dest {
if idx < len(row) {
dest[idx] = row[idx]
}
}
r.index++
return nil
}
if len(dest) > 0 {
dest[0] = strings.TrimSpace("demo")
}
return io.EOF
}