🐛 fix(mysql): 回退当前数据库列表查询

Fixes #327
This commit is contained in:
Syngnat
2026-04-11 21:53:52 +08:00
parent 89d79ff10c
commit fb500ee33b
3 changed files with 161 additions and 14 deletions

View File

@@ -74,6 +74,62 @@ func normalizeMySQLAddress(host string, port int) string {
return fmt.Sprintf("%s:%d", h, p)
}
var mysqlDatabaseQueries = []string{
"SHOW DATABASES",
"SELECT DATABASE() AS `Database`",
}
func collectMySQLDatabaseNames(queryFn func(string) ([]map[string]interface{}, []string, error)) ([]string, error) {
if queryFn == nil {
return nil, fmt.Errorf("查询函数为空")
}
names := make([]string, 0, 8)
seen := make(map[string]struct{}, 8)
var lastErr error
appendNames := func(rows []map[string]interface{}) {
for _, row := range rows {
for _, key := range []string{"Database", "database"} {
val, ok := row[key]
if !ok || val == nil {
continue
}
name := strings.TrimSpace(fmt.Sprintf("%v", val))
if name == "" || strings.EqualFold(name, "<nil>") {
continue
}
if _, exists := seen[name]; exists {
continue
}
seen[name] = struct{}{}
names = append(names, name)
break
}
}
}
for _, sqlText := range mysqlDatabaseQueries {
rows, _, err := queryFn(sqlText)
if err != nil {
lastErr = err
continue
}
appendNames(rows)
if len(names) > 0 {
return names, nil
}
}
if len(names) > 0 {
return names, nil
}
if lastErr != nil {
return nil, lastErr
}
return nil, fmt.Errorf("未获取到可用数据库")
}
func applyMySQLURI(config connection.ConnectionConfig) connection.ConnectionConfig {
uriText := strings.TrimSpace(config.URI)
if uriText == "" {
@@ -364,19 +420,7 @@ func (m *MySQLDB) Exec(query string) (int64, error) {
}
func (m *MySQLDB) GetDatabases() ([]string, error) {
data, _, err := m.Query("SHOW DATABASES")
if err != nil {
return nil, err
}
var dbs []string
for _, row := range data {
if val, ok := row["Database"]; ok {
dbs = append(dbs, fmt.Sprintf("%v", val))
} else if val, ok := row["database"]; ok {
dbs = append(dbs, fmt.Sprintf("%v", val))
}
}
return dbs, nil
return collectMySQLDatabaseNames(m.Query)
}
func (m *MySQLDB) GetTables(dbName string) ([]string, error) {

View File

@@ -0,0 +1,84 @@
package db
import (
"errors"
"reflect"
"testing"
)
func TestCollectMySQLDatabaseNames_FallsBackToCurrentDatabase(t *testing.T) {
t.Parallel()
got, err := collectMySQLDatabaseNames(func(query string) ([]map[string]interface{}, []string, error) {
switch query {
case mysqlDatabaseQueries[0]:
return nil, nil, errors.New("Error 1227 (42000): Access denied; you need (at least one of) the SHOW DATABASES privilege(s) for this operation")
case mysqlDatabaseQueries[1]:
return []map[string]interface{}{
{"Database": "biz_app"},
}, nil, nil
default:
return nil, nil, errors.New("unexpected query")
}
})
if err != nil {
t.Fatalf("collectMySQLDatabaseNames 返回错误: %v", err)
}
want := []string{"biz_app"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("unexpected database names, got=%v want=%v", got, want)
}
}
func TestCollectMySQLDatabaseNames_PrefersShowDatabasesWhenAvailable(t *testing.T) {
t.Parallel()
got, err := collectMySQLDatabaseNames(func(query string) ([]map[string]interface{}, []string, error) {
switch query {
case mysqlDatabaseQueries[0]:
return []map[string]interface{}{
{"Database": "analytics"},
{"database": "audit"},
}, nil, nil
case mysqlDatabaseQueries[1]:
return []map[string]interface{}{
{"Database": "should_not_be_used"},
}, nil, nil
default:
return nil, nil, errors.New("unexpected query")
}
})
if err != nil {
t.Fatalf("collectMySQLDatabaseNames 返回错误: %v", err)
}
want := []string{"analytics", "audit"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("unexpected database names, got=%v want=%v", got, want)
}
}
func TestCollectMySQLDatabaseNames_ReturnsOriginalErrorWhenNoDatabaseResolved(t *testing.T) {
t.Parallel()
expectErr := errors.New("show databases denied")
got, err := collectMySQLDatabaseNames(func(query string) ([]map[string]interface{}, []string, error) {
switch query {
case mysqlDatabaseQueries[0]:
return nil, nil, expectErr
case mysqlDatabaseQueries[1]:
return []map[string]interface{}{
{"Database": nil},
}, nil, nil
default:
return nil, nil, errors.New("unexpected query")
}
})
if err == nil {
t.Fatalf("期望返回错误,实际 got=%v", got)
}
if !errors.Is(err, expectErr) {
t.Fatalf("错误不符合预期: %v", err)
}
}