mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-10 00:19:40 +08:00
@@ -74,6 +74,62 @@ func normalizeMySQLAddress(host string, port int) string {
|
||||
return fmt.Sprintf("%s:%d", h, p)
|
||||
}
|
||||
|
||||
var mysqlDatabaseQueries = []string{
|
||||
"SHOW DATABASES",
|
||||
"SELECT DATABASE() AS `Database`",
|
||||
}
|
||||
|
||||
func collectMySQLDatabaseNames(queryFn func(string) ([]map[string]interface{}, []string, error)) ([]string, error) {
|
||||
if queryFn == nil {
|
||||
return nil, fmt.Errorf("查询函数为空")
|
||||
}
|
||||
|
||||
names := make([]string, 0, 8)
|
||||
seen := make(map[string]struct{}, 8)
|
||||
var lastErr error
|
||||
|
||||
appendNames := func(rows []map[string]interface{}) {
|
||||
for _, row := range rows {
|
||||
for _, key := range []string{"Database", "database"} {
|
||||
val, ok := row[key]
|
||||
if !ok || val == nil {
|
||||
continue
|
||||
}
|
||||
name := strings.TrimSpace(fmt.Sprintf("%v", val))
|
||||
if name == "" || strings.EqualFold(name, "<nil>") {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[name]; exists {
|
||||
continue
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
names = append(names, name)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, sqlText := range mysqlDatabaseQueries {
|
||||
rows, _, err := queryFn(sqlText)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
appendNames(rows)
|
||||
if len(names) > 0 {
|
||||
return names, nil
|
||||
}
|
||||
}
|
||||
|
||||
if len(names) > 0 {
|
||||
return names, nil
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, fmt.Errorf("未获取到可用数据库")
|
||||
}
|
||||
|
||||
func applyMySQLURI(config connection.ConnectionConfig) connection.ConnectionConfig {
|
||||
uriText := strings.TrimSpace(config.URI)
|
||||
if uriText == "" {
|
||||
@@ -364,19 +420,7 @@ func (m *MySQLDB) Exec(query string) (int64, error) {
|
||||
}
|
||||
|
||||
func (m *MySQLDB) GetDatabases() ([]string, error) {
|
||||
data, _, err := m.Query("SHOW DATABASES")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var dbs []string
|
||||
for _, row := range data {
|
||||
if val, ok := row["Database"]; ok {
|
||||
dbs = append(dbs, fmt.Sprintf("%v", val))
|
||||
} else if val, ok := row["database"]; ok {
|
||||
dbs = append(dbs, fmt.Sprintf("%v", val))
|
||||
}
|
||||
}
|
||||
return dbs, nil
|
||||
return collectMySQLDatabaseNames(m.Query)
|
||||
}
|
||||
|
||||
func (m *MySQLDB) GetTables(dbName string) ([]string, error) {
|
||||
|
||||
84
internal/db/mysql_metadata_test.go
Normal file
84
internal/db/mysql_metadata_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCollectMySQLDatabaseNames_FallsBackToCurrentDatabase(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := collectMySQLDatabaseNames(func(query string) ([]map[string]interface{}, []string, error) {
|
||||
switch query {
|
||||
case mysqlDatabaseQueries[0]:
|
||||
return nil, nil, errors.New("Error 1227 (42000): Access denied; you need (at least one of) the SHOW DATABASES privilege(s) for this operation")
|
||||
case mysqlDatabaseQueries[1]:
|
||||
return []map[string]interface{}{
|
||||
{"Database": "biz_app"},
|
||||
}, nil, nil
|
||||
default:
|
||||
return nil, nil, errors.New("unexpected query")
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("collectMySQLDatabaseNames 返回错误: %v", err)
|
||||
}
|
||||
|
||||
want := []string{"biz_app"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("unexpected database names, got=%v want=%v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectMySQLDatabaseNames_PrefersShowDatabasesWhenAvailable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := collectMySQLDatabaseNames(func(query string) ([]map[string]interface{}, []string, error) {
|
||||
switch query {
|
||||
case mysqlDatabaseQueries[0]:
|
||||
return []map[string]interface{}{
|
||||
{"Database": "analytics"},
|
||||
{"database": "audit"},
|
||||
}, nil, nil
|
||||
case mysqlDatabaseQueries[1]:
|
||||
return []map[string]interface{}{
|
||||
{"Database": "should_not_be_used"},
|
||||
}, nil, nil
|
||||
default:
|
||||
return nil, nil, errors.New("unexpected query")
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("collectMySQLDatabaseNames 返回错误: %v", err)
|
||||
}
|
||||
|
||||
want := []string{"analytics", "audit"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("unexpected database names, got=%v want=%v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectMySQLDatabaseNames_ReturnsOriginalErrorWhenNoDatabaseResolved(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expectErr := errors.New("show databases denied")
|
||||
got, err := collectMySQLDatabaseNames(func(query string) ([]map[string]interface{}, []string, error) {
|
||||
switch query {
|
||||
case mysqlDatabaseQueries[0]:
|
||||
return nil, nil, expectErr
|
||||
case mysqlDatabaseQueries[1]:
|
||||
return []map[string]interface{}{
|
||||
{"Database": nil},
|
||||
}, nil, nil
|
||||
default:
|
||||
return nil, nil, errors.New("unexpected query")
|
||||
}
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("期望返回错误,实际 got=%v", got)
|
||||
}
|
||||
if !errors.Is(err, expectErr) {
|
||||
t.Fatalf("错误不符合预期: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user