From d414a3887754b3317f9fbc517e2b690eeda140f3 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Sun, 24 May 2026 12:00:48 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(shardingsphere):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E4=BB=A3=E7=90=86=E5=88=86=E7=89=87=E8=A1=A8=E5=B1=95?= =?UTF-8?q?=E7=A4=BA=E4=B8=BA=E7=89=A9=E7=90=86=E8=A1=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 元数据取表接入 ShardingSphere 逻辑表规则 - 兼容 PostgreSQL、MySQL、MariaDB 协议入口 - 补充分片表折叠和降级测试 Refs #410 --- internal/db/mariadb_impl.go | 2 +- internal/db/mysql_impl.go | 2 +- internal/db/postgres_impl.go | 52 +++++-- internal/db/shardingsphere_tables.go | 165 ++++++++++++++++++++++ internal/db/shardingsphere_tables_test.go | 116 +++++++++++++++ 5 files changed, 325 insertions(+), 12 deletions(-) create mode 100644 internal/db/shardingsphere_tables.go create mode 100644 internal/db/shardingsphere_tables_test.go diff --git a/internal/db/mariadb_impl.go b/internal/db/mariadb_impl.go index 5e1d29d..e9df19b 100644 --- a/internal/db/mariadb_impl.go +++ b/internal/db/mariadb_impl.go @@ -210,7 +210,7 @@ func (m *MariaDB) GetTables(dbName string) ([]string, error) { break } } - return tables, nil + return resolveShardingSphereLogicalTables(tables, m.Query), nil } func (m *MariaDB) GetCreateStatement(dbName, tableName string) (string, error) { diff --git a/internal/db/mysql_impl.go b/internal/db/mysql_impl.go index 70eb6f8..d859ebd 100644 --- a/internal/db/mysql_impl.go +++ b/internal/db/mysql_impl.go @@ -801,7 +801,7 @@ func (m *MySQLDB) GetTables(dbName string) ([]string, error) { break } } - return tables, nil + return resolveShardingSphereLogicalTables(tables, m.Query), nil } func (m *MySQLDB) GetCreateStatement(dbName, tableName string) (string, error) { diff --git a/internal/db/postgres_impl.go b/internal/db/postgres_impl.go index 989a3d1..ad418e4 100644 --- a/internal/db/postgres_impl.go +++ b/internal/db/postgres_impl.go @@ -297,25 +297,57 @@ func (p *PostgresDB) GetDatabases() ([]string, error) { } func (p *PostgresDB) GetTables(dbName string) ([]string, error) { - query := "SELECT schemaname, tablename FROM pg_catalog.pg_tables WHERE schemaname != 'information_schema' AND schemaname NOT LIKE 'pg|_%' ESCAPE '|' ORDER BY schemaname, tablename" + query := buildPostgresTablesQuery() data, _, err := p.Query(query) if err != nil { - return nil, err + data, _, err = p.Query(buildPostgresLegacyTablesQuery()) + if err != nil { + return nil, err + } } - var tables []string + tables := parsePostgresTableNames(data) + return resolveShardingSphereLogicalTables(tables, p.Query), nil +} + +func buildPostgresTablesQuery() string { + return ` +SELECT DISTINCT + n.nspname AS schemaname, + c.relname AS tablename +FROM pg_catalog.pg_class c +JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace +WHERE c.relkind IN ('r', 'p') + AND n.nspname != 'information_schema' + AND n.nspname NOT LIKE 'pg|_%' ESCAPE '|' +ORDER BY n.nspname, c.relname` +} + +func buildPostgresLegacyTablesQuery() string { + return "SELECT schemaname, tablename FROM pg_catalog.pg_tables WHERE schemaname != 'information_schema' AND schemaname NOT LIKE 'pg|_%' ESCAPE '|' ORDER BY schemaname, tablename" +} + +func parsePostgresTableNames(data []map[string]interface{}) []string { + tables := make([]string, 0, len(data)) + seen := make(map[string]struct{}, len(data)) for _, row := range data { - schema, okSchema := row["schemaname"] - name, okName := row["tablename"] - if okSchema && okName { - tables = append(tables, fmt.Sprintf("%v.%v", schema, name)) + schema := getCaseInsensitiveRowString(row, "schemaname", "schema_name", "schema", "nspname") + name := getCaseInsensitiveRowString(row, "tablename", "table_name", "relname", "name") + if name == "" { continue } - if okName { - tables = append(tables, fmt.Sprintf("%v", name)) + table := name + if schema != "" { + table = fmt.Sprintf("%s.%s", schema, name) } + key := strings.ToLower(table) + if _, exists := seen[key]; exists { + continue + } + seen[key] = struct{}{} + tables = append(tables, table) } - return tables, nil + return tables } func (p *PostgresDB) GetCreateStatement(dbName, tableName string) (string, error) { diff --git a/internal/db/shardingsphere_tables.go b/internal/db/shardingsphere_tables.go new file mode 100644 index 0000000..371475f --- /dev/null +++ b/internal/db/shardingsphere_tables.go @@ -0,0 +1,165 @@ +package db + +import ( + "fmt" + "strings" +) + +const shardingSphereTableRulesQuery = "SHOW SHARDING TABLE RULES" + +type tableMetadataQueryFunc func(string) ([]map[string]interface{}, []string, error) + +func resolveShardingSphereLogicalTables(tables []string, query tableMetadataQueryFunc) []string { + if len(tables) == 0 || query == nil || !hasNumericShardTableCandidates(tables) { + return tables + } + + rulesData, _, err := query(shardingSphereTableRulesQuery) + if err != nil { + return tables + } + return mergeShardingSphereLogicalTables(tables, rulesData) +} + +func getCaseInsensitiveRowString(row map[string]interface{}, keys ...string) string { + if len(row) == 0 { + return "" + } + values := make(map[string]interface{}, len(row)) + for key, value := range row { + values[strings.ToLower(key)] = value + } + for _, key := range keys { + value, ok := values[strings.ToLower(key)] + if !ok || value == nil { + continue + } + text := strings.TrimSpace(fmt.Sprintf("%v", value)) + if text != "" && !strings.EqualFold(text, "") && !strings.EqualFold(text, "null") { + return text + } + } + return "" +} + +func hasNumericShardTableCandidates(tables []string) bool { + countByBase := make(map[string]int) + for _, table := range tables { + schema, name := splitQualifiedTableName(table) + base := trimNumericShardSuffix(name) + if base == "" || base == name { + continue + } + key := strings.ToLower(schema + "." + base) + countByBase[key]++ + if countByBase[key] > 1 { + return true + } + } + return false +} + +func mergeShardingSphereLogicalTables(tables []string, rulesData []map[string]interface{}) []string { + logicalTables := make([]string, 0, len(rulesData)) + for _, row := range rulesData { + logical := getCaseInsensitiveRowString(row, "table", "table_name", "logic_table", "logical_table", "logical_table_name") + if logical == "" { + continue + } + logicalTables = append(logicalTables, logical) + } + if len(logicalTables) == 0 { + return tables + } + + result := make([]string, 0, len(tables)) + seen := make(map[string]struct{}, len(tables)) + add := func(table string) { + if table == "" { + return + } + key := strings.ToLower(table) + if _, exists := seen[key]; exists { + return + } + seen[key] = struct{}{} + result = append(result, table) + } + + for _, table := range tables { + schema, name := splitQualifiedTableName(table) + replacement := "" + for _, logical := range logicalTables { + logicalSchema, logicalName := splitQualifiedTableName(logical) + if logicalName == "" { + continue + } + if logicalSchema != "" && schema != "" && !strings.EqualFold(schema, logicalSchema) { + continue + } + if strings.EqualFold(name, logicalName) || isPhysicalShardOfLogicalTable(name, logicalName) { + if logicalSchema != "" { + replacement = logical + } else if schema != "" { + replacement = fmt.Sprintf("%s.%s", schema, logicalName) + } else { + replacement = logicalName + } + break + } + } + if replacement != "" { + add(replacement) + continue + } + add(table) + } + return result +} + +func splitQualifiedTableName(table string) (string, string) { + raw := strings.TrimSpace(table) + if raw == "" { + return "", "" + } + idx := strings.LastIndex(raw, ".") + if idx <= 0 || idx >= len(raw)-1 { + return "", raw + } + return strings.TrimSpace(raw[:idx]), strings.TrimSpace(raw[idx+1:]) +} + +func isPhysicalShardOfLogicalTable(name, logicalName string) bool { + name = strings.TrimSpace(name) + logicalName = strings.TrimSpace(logicalName) + if name == "" || logicalName == "" { + return false + } + if !strings.HasPrefix(strings.ToLower(name), strings.ToLower(logicalName)+"_") { + return false + } + suffix := name[len(logicalName)+1:] + if suffix == "" { + return false + } + for _, r := range suffix { + if r < '0' || r > '9' { + return false + } + } + return true +} + +func trimNumericShardSuffix(name string) string { + trimmed := strings.TrimSpace(name) + idx := strings.LastIndex(trimmed, "_") + if idx <= 0 || idx >= len(trimmed)-1 { + return trimmed + } + for _, r := range trimmed[idx+1:] { + if r < '0' || r > '9' { + return trimmed + } + } + return trimmed[:idx] +} diff --git a/internal/db/shardingsphere_tables_test.go b/internal/db/shardingsphere_tables_test.go new file mode 100644 index 0000000..2667e8a --- /dev/null +++ b/internal/db/shardingsphere_tables_test.go @@ -0,0 +1,116 @@ +package db + +import ( + "errors" + "reflect" + "testing" +) + +func TestMergeShardingSphereLogicalTablesCollapsesNumericPhysicalTables(t *testing.T) { + t.Parallel() + + tables := []string{ + "public.apply_or_report_file_0", + "public.apply_or_report_file_1", + "public.apply_or_report_filesystem", + "public.ai_result_0", + "public.ai_result_1", + } + rules := []map[string]interface{}{ + {"table": "apply_or_report_file"}, + {"table": "ai_result"}, + } + + got := mergeShardingSphereLogicalTables(tables, rules) + want := []string{ + "public.apply_or_report_file", + "public.apply_or_report_filesystem", + "public.ai_result", + } + + if !reflect.DeepEqual(got, want) { + t.Fatalf("merged tables = %v, want %v", got, want) + } +} + +func TestMergeShardingSphereLogicalTablesPreservesQualifiedLogicalRule(t *testing.T) { + t.Parallel() + + tables := []string{"public.orders_0", "public.orders_1", "archive.orders_0", "archive.orders_1"} + rules := []map[string]interface{}{ + {"logical_table_name": "archive.orders"}, + } + + got := mergeShardingSphereLogicalTables(tables, rules) + want := []string{"public.orders_0", "public.orders_1", "archive.orders"} + + if !reflect.DeepEqual(got, want) { + t.Fatalf("merged tables = %v, want %v", got, want) + } +} + +func TestResolveShardingSphereLogicalTablesSkipsDistSQLWithoutShardCandidates(t *testing.T) { + t.Parallel() + + called := false + tables := []string{"orders_0", "audit_log", "users"} + + got := resolveShardingSphereLogicalTables(tables, func(string) ([]map[string]interface{}, []string, error) { + called = true + return nil, nil, nil + }) + + if called { + t.Fatalf("DistSQL should not run without multiple numeric shard candidates") + } + if !reflect.DeepEqual(got, tables) { + t.Fatalf("tables = %v, want %v", got, tables) + } +} + +func TestResolveShardingSphereLogicalTablesUsesRulesOnlyWhenAvailable(t *testing.T) { + t.Parallel() + + tables := []string{"orders_0", "orders_1"} + got := resolveShardingSphereLogicalTables(tables, func(query string) ([]map[string]interface{}, []string, error) { + if query != shardingSphereTableRulesQuery { + t.Fatalf("query = %q, want %q", query, shardingSphereTableRulesQuery) + } + return nil, nil, errors.New("not a ShardingSphere proxy") + }) + + if !reflect.DeepEqual(got, tables) { + t.Fatalf("tables should be preserved when DistSQL fails, got %v", got) + } +} + +func TestResolveShardingSphereLogicalTablesCollapsesFromDistSQL(t *testing.T) { + t.Parallel() + + tables := []string{"orders_0", "orders_1", "users"} + got := resolveShardingSphereLogicalTables(tables, func(query string) ([]map[string]interface{}, []string, error) { + return []map[string]interface{}{ + {"TABLE": "orders"}, + }, []string{"TABLE"}, nil + }) + + want := []string{"orders", "users"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("tables = %v, want %v", got, want) + } +} + +func TestParsePostgresTableNamesUsesCaseInsensitiveColumns(t *testing.T) { + t.Parallel() + + got := parsePostgresTableNames([]map[string]interface{}{ + {"SCHEMANAME": "public", "TABLENAME": "orders"}, + {"schema_name": "archive", "table_name": "orders"}, + {"schema_name": "archive", "table_name": "orders"}, + }) + want := []string{"public.orders", "archive.orders"} + + if !reflect.DeepEqual(got, want) { + t.Fatalf("parsed tables = %v, want %v", got, want) + } +}