mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-06 20:03:05 +08:00
@@ -304,18 +304,70 @@ func (k *KingbaseDB) Exec(query string) (int64, error) {
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) GetDatabases() ([]string, error) {
|
||||
// Postgres/Kingbase style
|
||||
data, _, err := k.Query("SELECT datname FROM pg_database WHERE datistemplate = false")
|
||||
if err == nil {
|
||||
dbs := collectKingbaseNames(data, "datname", "database")
|
||||
if len(dbs) > 0 {
|
||||
return dbs, nil
|
||||
}
|
||||
}
|
||||
|
||||
fallbackData, _, fallbackErr := k.Query("SELECT current_database() AS datname")
|
||||
if fallbackErr != nil {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, fallbackErr
|
||||
}
|
||||
|
||||
dbs := collectKingbaseNames(fallbackData, "datname", "database", "current_database", "currentDatabase")
|
||||
if len(dbs) > 0 {
|
||||
return dbs, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var dbs []string
|
||||
for _, row := range data {
|
||||
if val, ok := row["datname"]; ok {
|
||||
dbs = append(dbs, fmt.Sprintf("%v", val))
|
||||
return nil, fmt.Errorf("未获取到可见数据库列表")
|
||||
}
|
||||
|
||||
func collectKingbaseNames(rows []map[string]interface{}, keys ...string) []string {
|
||||
result := make([]string, 0, len(rows))
|
||||
seen := make(map[string]struct{}, len(rows))
|
||||
for _, row := range rows {
|
||||
name := strings.TrimSpace(getKingbaseNameFromRow(row, keys...))
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[name]; exists {
|
||||
continue
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
result = append(result, name)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func getKingbaseNameFromRow(row map[string]interface{}, keys ...string) string {
|
||||
if len(row) == 0 {
|
||||
return ""
|
||||
}
|
||||
for _, key := range keys {
|
||||
if value, ok := row[key]; ok {
|
||||
return fmt.Sprintf("%v", value)
|
||||
}
|
||||
}
|
||||
return dbs, nil
|
||||
for existingKey, value := range row {
|
||||
for _, key := range keys {
|
||||
if strings.EqualFold(existingKey, key) {
|
||||
return fmt.Sprintf("%v", value)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, value := range row {
|
||||
return fmt.Sprintf("%v", value)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) GetTables(dbName string) ([]string, error) {
|
||||
|
||||
@@ -2,7 +2,39 @@
|
||||
|
||||
package db
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const fakeKingbaseDriverName = "gonavi-fake-kingbase"
|
||||
|
||||
var (
|
||||
registerFakeKingbaseDriverOnce sync.Once
|
||||
fakeKingbaseStateMu sync.Mutex
|
||||
fakeKingbaseState = struct {
|
||||
queryErr error
|
||||
queryResults map[string]fakeKingbaseQueryResult
|
||||
lastQuery string
|
||||
queries []string
|
||||
}{
|
||||
lastQuery: "",
|
||||
queryResults: map[string]fakeKingbaseQueryResult{},
|
||||
queries: nil,
|
||||
}
|
||||
)
|
||||
|
||||
type fakeKingbaseQueryResult struct {
|
||||
columns []string
|
||||
rows [][]driver.Value
|
||||
err error
|
||||
}
|
||||
|
||||
func TestNormalizeKingbaseIdentifier(t *testing.T) {
|
||||
tests := []struct {
|
||||
@@ -115,3 +147,125 @@ func TestSplitKingbaseQualifiedTable(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKingbaseGetDatabasesFallsBackToCurrentDatabase(t *testing.T) {
|
||||
registerFakeKingbaseDriverOnce.Do(func() {
|
||||
sql.Register(fakeKingbaseDriverName, fakeKingbaseDriver{})
|
||||
})
|
||||
|
||||
db, err := sql.Open(fakeKingbaseDriverName, "")
|
||||
if err != nil {
|
||||
t.Fatalf("open fake kingbase db failed: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
const listSQL = "SELECT datname FROM pg_database WHERE datistemplate = false"
|
||||
const fallbackSQL = "SELECT current_database() AS datname"
|
||||
|
||||
fakeKingbaseStateMu.Lock()
|
||||
fakeKingbaseState.queryErr = nil
|
||||
fakeKingbaseState.queryResults = map[string]fakeKingbaseQueryResult{
|
||||
listSQL: {
|
||||
err: errors.New("permission denied for relation pg_database"),
|
||||
},
|
||||
fallbackSQL: {
|
||||
columns: []string{"datname"},
|
||||
rows: [][]driver.Value{
|
||||
{"demo"},
|
||||
},
|
||||
},
|
||||
}
|
||||
fakeKingbaseState.lastQuery = ""
|
||||
fakeKingbaseState.queries = nil
|
||||
fakeKingbaseStateMu.Unlock()
|
||||
|
||||
client := &KingbaseDB{conn: db}
|
||||
databases, err := client.GetDatabases()
|
||||
if err != nil {
|
||||
t.Fatalf("expected GetDatabases to fallback, got err=%v", err)
|
||||
}
|
||||
if len(databases) != 1 || databases[0] != "demo" {
|
||||
t.Fatalf("expected fallback database list, got %v", databases)
|
||||
}
|
||||
|
||||
fakeKingbaseStateMu.Lock()
|
||||
queries := append([]string(nil), fakeKingbaseState.queries...)
|
||||
fakeKingbaseStateMu.Unlock()
|
||||
if len(queries) != 2 {
|
||||
t.Fatalf("expected two queries, got %v", queries)
|
||||
}
|
||||
if queries[0] != listSQL || queries[1] != fallbackSQL {
|
||||
t.Fatalf("unexpected query order: %v", queries)
|
||||
}
|
||||
}
|
||||
|
||||
type fakeKingbaseDriver struct{}
|
||||
|
||||
func (fakeKingbaseDriver) Open(name string) (driver.Conn, error) {
|
||||
return fakeKingbaseConn{}, nil
|
||||
}
|
||||
|
||||
type fakeKingbaseConn struct{}
|
||||
|
||||
func (fakeKingbaseConn) Prepare(query string) (driver.Stmt, error) {
|
||||
return nil, errors.New("prepare not implemented")
|
||||
}
|
||||
|
||||
func (fakeKingbaseConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fakeKingbaseConn) Begin() (driver.Tx, error) {
|
||||
return nil, errors.New("transactions not implemented")
|
||||
}
|
||||
|
||||
func (fakeKingbaseConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
||||
fakeKingbaseStateMu.Lock()
|
||||
defer fakeKingbaseStateMu.Unlock()
|
||||
fakeKingbaseState.lastQuery = query
|
||||
fakeKingbaseState.queries = append(fakeKingbaseState.queries, query)
|
||||
if result, ok := fakeKingbaseState.queryResults[query]; ok {
|
||||
if result.err != nil {
|
||||
return nil, result.err
|
||||
}
|
||||
return &fakeKingbaseRows{columns: result.columns, rows: result.rows}, nil
|
||||
}
|
||||
if fakeKingbaseState.queryErr != nil {
|
||||
return nil, fakeKingbaseState.queryErr
|
||||
}
|
||||
return &fakeKingbaseRows{}, nil
|
||||
}
|
||||
|
||||
type fakeKingbaseRows struct {
|
||||
columns []string
|
||||
rows [][]driver.Value
|
||||
index int
|
||||
}
|
||||
|
||||
func (r *fakeKingbaseRows) Columns() []string {
|
||||
if len(r.columns) > 0 {
|
||||
return r.columns
|
||||
}
|
||||
return []string{"datname"}
|
||||
}
|
||||
|
||||
func (r *fakeKingbaseRows) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *fakeKingbaseRows) Next(dest []driver.Value) error {
|
||||
if r.index < len(r.rows) {
|
||||
row := r.rows[r.index]
|
||||
for idx := range dest {
|
||||
if idx < len(row) {
|
||||
dest[idx] = row[idx]
|
||||
}
|
||||
}
|
||||
r.index++
|
||||
return nil
|
||||
}
|
||||
if len(dest) > 0 {
|
||||
dest[0] = strings.TrimSpace("demo")
|
||||
}
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user