🐛 fix(shardingsphere): 修复代理分片表展示为物理表

- 元数据取表接入 ShardingSphere 逻辑表规则

- 兼容 PostgreSQL、MySQL、MariaDB 协议入口

- 补充分片表折叠和降级测试

Refs #410
This commit is contained in:
Syngnat
2026-05-24 12:00:48 +08:00
parent 85a0f9d007
commit d414a38877
5 changed files with 325 additions and 12 deletions

View File

@@ -210,7 +210,7 @@ func (m *MariaDB) GetTables(dbName string) ([]string, error) {
break
}
}
return tables, nil
return resolveShardingSphereLogicalTables(tables, m.Query), nil
}
func (m *MariaDB) GetCreateStatement(dbName, tableName string) (string, error) {

View File

@@ -801,7 +801,7 @@ func (m *MySQLDB) GetTables(dbName string) ([]string, error) {
break
}
}
return tables, nil
return resolveShardingSphereLogicalTables(tables, m.Query), nil
}
func (m *MySQLDB) GetCreateStatement(dbName, tableName string) (string, error) {

View File

@@ -297,25 +297,57 @@ func (p *PostgresDB) GetDatabases() ([]string, error) {
}
func (p *PostgresDB) GetTables(dbName string) ([]string, error) {
query := "SELECT schemaname, tablename FROM pg_catalog.pg_tables WHERE schemaname != 'information_schema' AND schemaname NOT LIKE 'pg|_%' ESCAPE '|' ORDER BY schemaname, tablename"
query := buildPostgresTablesQuery()
data, _, err := p.Query(query)
if err != nil {
return nil, err
data, _, err = p.Query(buildPostgresLegacyTablesQuery())
if err != nil {
return nil, err
}
}
var tables []string
tables := parsePostgresTableNames(data)
return resolveShardingSphereLogicalTables(tables, p.Query), nil
}
func buildPostgresTablesQuery() string {
return `
SELECT DISTINCT
n.nspname AS schemaname,
c.relname AS tablename
FROM pg_catalog.pg_class c
JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE c.relkind IN ('r', 'p')
AND n.nspname != 'information_schema'
AND n.nspname NOT LIKE 'pg|_%' ESCAPE '|'
ORDER BY n.nspname, c.relname`
}
func buildPostgresLegacyTablesQuery() string {
return "SELECT schemaname, tablename FROM pg_catalog.pg_tables WHERE schemaname != 'information_schema' AND schemaname NOT LIKE 'pg|_%' ESCAPE '|' ORDER BY schemaname, tablename"
}
func parsePostgresTableNames(data []map[string]interface{}) []string {
tables := make([]string, 0, len(data))
seen := make(map[string]struct{}, len(data))
for _, row := range data {
schema, okSchema := row["schemaname"]
name, okName := row["tablename"]
if okSchema && okName {
tables = append(tables, fmt.Sprintf("%v.%v", schema, name))
schema := getCaseInsensitiveRowString(row, "schemaname", "schema_name", "schema", "nspname")
name := getCaseInsensitiveRowString(row, "tablename", "table_name", "relname", "name")
if name == "" {
continue
}
if okName {
tables = append(tables, fmt.Sprintf("%v", name))
table := name
if schema != "" {
table = fmt.Sprintf("%s.%s", schema, name)
}
key := strings.ToLower(table)
if _, exists := seen[key]; exists {
continue
}
seen[key] = struct{}{}
tables = append(tables, table)
}
return tables, nil
return tables
}
func (p *PostgresDB) GetCreateStatement(dbName, tableName string) (string, error) {

View File

@@ -0,0 +1,165 @@
package db
import (
"fmt"
"strings"
)
const shardingSphereTableRulesQuery = "SHOW SHARDING TABLE RULES"
type tableMetadataQueryFunc func(string) ([]map[string]interface{}, []string, error)
func resolveShardingSphereLogicalTables(tables []string, query tableMetadataQueryFunc) []string {
if len(tables) == 0 || query == nil || !hasNumericShardTableCandidates(tables) {
return tables
}
rulesData, _, err := query(shardingSphereTableRulesQuery)
if err != nil {
return tables
}
return mergeShardingSphereLogicalTables(tables, rulesData)
}
func getCaseInsensitiveRowString(row map[string]interface{}, keys ...string) string {
if len(row) == 0 {
return ""
}
values := make(map[string]interface{}, len(row))
for key, value := range row {
values[strings.ToLower(key)] = value
}
for _, key := range keys {
value, ok := values[strings.ToLower(key)]
if !ok || value == nil {
continue
}
text := strings.TrimSpace(fmt.Sprintf("%v", value))
if text != "" && !strings.EqualFold(text, "<nil>") && !strings.EqualFold(text, "null") {
return text
}
}
return ""
}
func hasNumericShardTableCandidates(tables []string) bool {
countByBase := make(map[string]int)
for _, table := range tables {
schema, name := splitQualifiedTableName(table)
base := trimNumericShardSuffix(name)
if base == "" || base == name {
continue
}
key := strings.ToLower(schema + "." + base)
countByBase[key]++
if countByBase[key] > 1 {
return true
}
}
return false
}
func mergeShardingSphereLogicalTables(tables []string, rulesData []map[string]interface{}) []string {
logicalTables := make([]string, 0, len(rulesData))
for _, row := range rulesData {
logical := getCaseInsensitiveRowString(row, "table", "table_name", "logic_table", "logical_table", "logical_table_name")
if logical == "" {
continue
}
logicalTables = append(logicalTables, logical)
}
if len(logicalTables) == 0 {
return tables
}
result := make([]string, 0, len(tables))
seen := make(map[string]struct{}, len(tables))
add := func(table string) {
if table == "" {
return
}
key := strings.ToLower(table)
if _, exists := seen[key]; exists {
return
}
seen[key] = struct{}{}
result = append(result, table)
}
for _, table := range tables {
schema, name := splitQualifiedTableName(table)
replacement := ""
for _, logical := range logicalTables {
logicalSchema, logicalName := splitQualifiedTableName(logical)
if logicalName == "" {
continue
}
if logicalSchema != "" && schema != "" && !strings.EqualFold(schema, logicalSchema) {
continue
}
if strings.EqualFold(name, logicalName) || isPhysicalShardOfLogicalTable(name, logicalName) {
if logicalSchema != "" {
replacement = logical
} else if schema != "" {
replacement = fmt.Sprintf("%s.%s", schema, logicalName)
} else {
replacement = logicalName
}
break
}
}
if replacement != "" {
add(replacement)
continue
}
add(table)
}
return result
}
func splitQualifiedTableName(table string) (string, string) {
raw := strings.TrimSpace(table)
if raw == "" {
return "", ""
}
idx := strings.LastIndex(raw, ".")
if idx <= 0 || idx >= len(raw)-1 {
return "", raw
}
return strings.TrimSpace(raw[:idx]), strings.TrimSpace(raw[idx+1:])
}
func isPhysicalShardOfLogicalTable(name, logicalName string) bool {
name = strings.TrimSpace(name)
logicalName = strings.TrimSpace(logicalName)
if name == "" || logicalName == "" {
return false
}
if !strings.HasPrefix(strings.ToLower(name), strings.ToLower(logicalName)+"_") {
return false
}
suffix := name[len(logicalName)+1:]
if suffix == "" {
return false
}
for _, r := range suffix {
if r < '0' || r > '9' {
return false
}
}
return true
}
func trimNumericShardSuffix(name string) string {
trimmed := strings.TrimSpace(name)
idx := strings.LastIndex(trimmed, "_")
if idx <= 0 || idx >= len(trimmed)-1 {
return trimmed
}
for _, r := range trimmed[idx+1:] {
if r < '0' || r > '9' {
return trimmed
}
}
return trimmed[:idx]
}

View File

@@ -0,0 +1,116 @@
package db
import (
"errors"
"reflect"
"testing"
)
func TestMergeShardingSphereLogicalTablesCollapsesNumericPhysicalTables(t *testing.T) {
t.Parallel()
tables := []string{
"public.apply_or_report_file_0",
"public.apply_or_report_file_1",
"public.apply_or_report_filesystem",
"public.ai_result_0",
"public.ai_result_1",
}
rules := []map[string]interface{}{
{"table": "apply_or_report_file"},
{"table": "ai_result"},
}
got := mergeShardingSphereLogicalTables(tables, rules)
want := []string{
"public.apply_or_report_file",
"public.apply_or_report_filesystem",
"public.ai_result",
}
if !reflect.DeepEqual(got, want) {
t.Fatalf("merged tables = %v, want %v", got, want)
}
}
func TestMergeShardingSphereLogicalTablesPreservesQualifiedLogicalRule(t *testing.T) {
t.Parallel()
tables := []string{"public.orders_0", "public.orders_1", "archive.orders_0", "archive.orders_1"}
rules := []map[string]interface{}{
{"logical_table_name": "archive.orders"},
}
got := mergeShardingSphereLogicalTables(tables, rules)
want := []string{"public.orders_0", "public.orders_1", "archive.orders"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("merged tables = %v, want %v", got, want)
}
}
func TestResolveShardingSphereLogicalTablesSkipsDistSQLWithoutShardCandidates(t *testing.T) {
t.Parallel()
called := false
tables := []string{"orders_0", "audit_log", "users"}
got := resolveShardingSphereLogicalTables(tables, func(string) ([]map[string]interface{}, []string, error) {
called = true
return nil, nil, nil
})
if called {
t.Fatalf("DistSQL should not run without multiple numeric shard candidates")
}
if !reflect.DeepEqual(got, tables) {
t.Fatalf("tables = %v, want %v", got, tables)
}
}
func TestResolveShardingSphereLogicalTablesUsesRulesOnlyWhenAvailable(t *testing.T) {
t.Parallel()
tables := []string{"orders_0", "orders_1"}
got := resolveShardingSphereLogicalTables(tables, func(query string) ([]map[string]interface{}, []string, error) {
if query != shardingSphereTableRulesQuery {
t.Fatalf("query = %q, want %q", query, shardingSphereTableRulesQuery)
}
return nil, nil, errors.New("not a ShardingSphere proxy")
})
if !reflect.DeepEqual(got, tables) {
t.Fatalf("tables should be preserved when DistSQL fails, got %v", got)
}
}
func TestResolveShardingSphereLogicalTablesCollapsesFromDistSQL(t *testing.T) {
t.Parallel()
tables := []string{"orders_0", "orders_1", "users"}
got := resolveShardingSphereLogicalTables(tables, func(query string) ([]map[string]interface{}, []string, error) {
return []map[string]interface{}{
{"TABLE": "orders"},
}, []string{"TABLE"}, nil
})
want := []string{"orders", "users"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("tables = %v, want %v", got, want)
}
}
func TestParsePostgresTableNamesUsesCaseInsensitiveColumns(t *testing.T) {
t.Parallel()
got := parsePostgresTableNames([]map[string]interface{}{
{"SCHEMANAME": "public", "TABLENAME": "orders"},
{"schema_name": "archive", "table_name": "orders"},
{"schema_name": "archive", "table_name": "orders"},
})
want := []string{"public.orders", "archive.orders"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("parsed tables = %v, want %v", got, want)
}
}