diff --git a/internal/db/oracle_impl.go b/internal/db/oracle_impl.go index 9efb1b6..6f8947f 100644 --- a/internal/db/oracle_impl.go +++ b/internal/db/oracle_impl.go @@ -263,16 +263,31 @@ func (o *OracleDB) GetCreateStatement(dbName, tableName string) (string, error) } func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { - query := fmt.Sprintf(`SELECT column_name, data_type, nullable, data_default - FROM all_tab_columns - WHERE owner = '%s' AND table_name = '%s' - ORDER BY column_id`, strings.ToUpper(dbName), strings.ToUpper(tableName)) + query := fmt.Sprintf(`SELECT c.column_name, c.data_type, c.nullable, c.data_default, + CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS column_key + FROM all_tab_columns c + LEFT JOIN ( + SELECT cols.owner, cols.table_name, cols.column_name + FROM all_constraints cons + JOIN all_cons_columns cols + ON cons.owner = cols.owner AND cons.constraint_name = cols.constraint_name + WHERE cons.constraint_type = 'P' + ) pk ON c.owner = pk.owner AND c.table_name = pk.table_name AND c.column_name = pk.column_name + WHERE c.owner = '%s' AND c.table_name = '%s' + ORDER BY c.column_id`, strings.ToUpper(dbName), strings.ToUpper(tableName)) if dbName == "" { - query = fmt.Sprintf(`SELECT column_name, data_type, nullable, data_default - FROM user_tab_columns - WHERE table_name = '%s' - ORDER BY column_id`, strings.ToUpper(tableName)) + query = fmt.Sprintf(`SELECT c.column_name, c.data_type, c.nullable, c.data_default, + CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS column_key + FROM user_tab_columns c + LEFT JOIN ( + SELECT cols.table_name, cols.column_name + FROM user_constraints cons + JOIN user_cons_columns cols USING (constraint_name) + WHERE cons.constraint_type = 'P' + ) pk ON c.table_name = pk.table_name AND c.column_name = pk.column_name + WHERE c.table_name = '%s' + ORDER BY c.column_id`, strings.ToUpper(tableName)) } data, _, err := o.Query(query) @@ -286,6 +301,7 @@ func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefi Name: fmt.Sprintf("%v", row["COLUMN_NAME"]), Type: fmt.Sprintf("%v", row["DATA_TYPE"]), Nullable: fmt.Sprintf("%v", row["NULLABLE"]), + Key: fmt.Sprintf("%v", row["COLUMN_KEY"]), } if row["DATA_DEFAULT"] != nil {