diff --git a/internal/db/driver_agent_revisions_gen.go b/internal/db/driver_agent_revisions_gen.go index f46359e..3c89e7a 100644 --- a/internal/db/driver_agent_revisions_gen.go +++ b/internal/db/driver_agent_revisions_gen.go @@ -7,7 +7,7 @@ func init() { "mariadb": "src-4e1ec648c70c87ea", "oceanbase": "src-8e445fc4899d850f", "diros": "src-74927b3809258666", - "starrocks": "src-3b5aad8a32f79b61", + "starrocks": "src-4ea05ce44321c17b", "sphinx": "src-269bd60a34df47d3", "sqlserver": "src-84553484c72e7253", "sqlite": "src-762863d48f653b89", diff --git a/internal/db/starrocks_impl.go b/internal/db/starrocks_impl.go index 64286ef..f25167a 100644 --- a/internal/db/starrocks_impl.go +++ b/internal/db/starrocks_impl.go @@ -126,6 +126,81 @@ func collectStarRocksAddresses(config connection.ConnectionConfig) []string { return result } +func starRocksMetadataLiteral(value string) string { + return "'" + strings.ReplaceAll(value, "'", "''") + "'" +} + +func buildStarRocksColumnsQuery(dbName, tableName string) string { + schemaPredicate := "TABLE_SCHEMA = DATABASE()" + if strings.TrimSpace(dbName) != "" { + schemaPredicate = fmt.Sprintf("TABLE_SCHEMA = %s", starRocksMetadataLiteral(strings.TrimSpace(dbName))) + } + + return fmt.Sprintf(`SELECT + COLUMN_NAME, + COLUMN_TYPE, + IS_NULLABLE, + COLUMN_KEY, + COLUMN_DEFAULT, + EXTRA, + COLUMN_COMMENT +FROM information_schema.columns +WHERE %s AND TABLE_NAME = %s +ORDER BY ORDINAL_POSITION`, schemaPredicate, starRocksMetadataLiteral(strings.TrimSpace(tableName))) +} + +func getStarRocksRowValue(row map[string]interface{}, keys ...string) (interface{}, bool) { + if len(row) == 0 { + return nil, false + } + for _, key := range keys { + for k, v := range row { + if !strings.EqualFold(strings.TrimSpace(k), strings.TrimSpace(key)) { + continue + } + return v, true + } + } + return nil, false +} + +func getStarRocksRowString(row map[string]interface{}, keys ...string) string { + v, ok := getStarRocksRowValue(row, keys...) + if !ok || v == nil { + return "" + } + text := strings.TrimSpace(fmt.Sprintf("%v", v)) + if text == "" || strings.EqualFold(text, "") { + return "" + } + return text +} + +func buildStarRocksColumnDefinitions(data []map[string]interface{}) []connection.ColumnDefinition { + columns := make([]connection.ColumnDefinition, 0, len(data)) + for _, row := range data { + col := connection.ColumnDefinition{ + Name: getStarRocksRowString(row, "Field", "COLUMN_NAME"), + Type: getStarRocksRowString(row, "Type", "COLUMN_TYPE"), + Nullable: getStarRocksRowString(row, "Null", "IS_NULLABLE"), + Key: strings.ToUpper(getStarRocksRowString(row, "Key", "COLUMN_KEY")), + Extra: getStarRocksRowString(row, "Extra", "EXTRA"), + Comment: getStarRocksRowString(row, "Comment", "COLUMN_COMMENT"), + } + + if rawDefault, ok := getStarRocksRowValue(row, "Default", "COLUMN_DEFAULT"); ok && rawDefault != nil { + def := fmt.Sprintf("%v", rawDefault) + if strings.EqualFold(def, "") { + def = "" + } + col.Default = &def + } + + columns = append(columns, col) + } + return columns +} + func (s *StarRocksDB) getDSN(config connection.ConnectionConfig) (string, error) { database := config.Database protocol := "tcp" @@ -159,6 +234,14 @@ func resolveStarRocksCredential(config connection.ConnectionConfig, addressIndex return config.User, primaryPassword } +func (s *StarRocksDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { + data, _, err := s.Query(buildStarRocksColumnsQuery(dbName, tableName)) + if err != nil { + return nil, err + } + return buildStarRocksColumnDefinitions(data), nil +} + func (s *StarRocksDB) Connect(config connection.ConnectionConfig) error { runConfig := applyStarRocksURI(config) addresses := collectStarRocksAddresses(runConfig) diff --git a/internal/db/starrocks_metadata_test.go b/internal/db/starrocks_metadata_test.go new file mode 100644 index 0000000..209c925 --- /dev/null +++ b/internal/db/starrocks_metadata_test.go @@ -0,0 +1,81 @@ +//go:build gonavi_full_drivers || gonavi_starrocks_driver + +package db + +import ( + "strings" + "testing" +) + +func TestBuildStarRocksColumnsQuery_UsesInformationSchemaColumnKey(t *testing.T) { + t.Parallel() + + query := buildStarRocksColumnsQuery("test_db", "cross_border_erp_erp_sales_order") + + if !strings.Contains(query, "FROM information_schema.columns") { + t.Fatalf("StarRocks columns query should use information_schema.columns, got=%s", query) + } + if !strings.Contains(query, "COLUMN_KEY") { + t.Fatalf("StarRocks columns query should expose COLUMN_KEY as Key, got=%s", query) + } + if !strings.Contains(query, "TABLE_SCHEMA = 'test_db'") { + t.Fatalf("StarRocks columns query should filter by schema, got=%s", query) + } + if !strings.Contains(query, "TABLE_NAME = 'cross_border_erp_erp_sales_order'") { + t.Fatalf("StarRocks columns query should filter by table name, got=%s", query) + } +} + +func TestBuildStarRocksColumnsQuery_UsesCurrentDatabaseWhenDbNameEmpty(t *testing.T) { + t.Parallel() + + query := buildStarRocksColumnsQuery("", "orders") + + if !strings.Contains(query, "TABLE_SCHEMA = DATABASE()") { + t.Fatalf("StarRocks columns query should fall back to current database, got=%s", query) + } + if !strings.Contains(query, "TABLE_NAME = 'orders'") { + t.Fatalf("StarRocks columns query should filter by table name, got=%s", query) + } +} + +func TestBuildStarRocksColumnDefinitions_MarksPrimaryKeyColumns(t *testing.T) { + t.Parallel() + + columns := buildStarRocksColumnDefinitions([]map[string]interface{}{ + { + "Field": "id", + "Type": "bigint", + "Null": "NO", + "Key": "pri", + "Default": nil, + "Extra": "", + "Comment": "订单ID", + }, + { + "Field": "order_no", + "Type": "varchar(64)", + "Null": "YES", + "Key": "", + "Default": "", + "Extra": "", + "Comment": "订单号", + }, + }) + + if len(columns) != 2 { + t.Fatalf("unexpected column count: %d", len(columns)) + } + if columns[0].Name != "id" || columns[0].Key != "PRI" { + t.Fatalf("StarRocks primary key column was not marked as PRI: %+v", columns[0]) + } + if columns[1].Name != "order_no" || columns[1].Key != "" { + t.Fatalf("StarRocks non-primary column key should stay empty: %+v", columns[1]) + } + if columns[0].Default != nil { + t.Fatalf("nil default should remain nil: %+v", columns[0]) + } + if columns[1].Default == nil || *columns[1].Default != "" { + t.Fatalf("empty string default should be preserved: %+v", columns[1]) + } +}