PySpark SQL Functions | collect_set method
PySpark SQL Functions' collect_set(~) method returns a unique set of values in a column. Null values are ignored. Use collect_list(~) instead to obtain a list of values that allows duplicates.
Parameters

1. col | string or Column

The column label or a Column object.
Return Value

A PySpark SQL Column object (pyspark.sql.column.Column).
Note that the order of the returned set may be random, since it is affected by shuffle operations.
Examples
Consider the following PySpark DataFrame:
data = [("Alex", "A"), ("Alex", "B"), ("Bob", "A"), ("Cathy", "C"), ("Dave", None)]
df = spark.createDataFrame(data, ["name", "group"])
df.show()

+-----+-----+
| name|group|
+-----+-----+
| Alex|    A|
| Alex|    B|
|  Bob|    A|
|Cathy|    C|
| Dave| null|
+-----+-----+
Getting a set of column values in PySpark
To get the unique set of values in the group column:
Equivalently, you can pass in a Column object to collect_set(~) as well:
Notice how the null value does not appear in the resulting set.
Getting the set as a standard list
To get the set as a standard list:
Here, the PySpark DataFrame's collect() method returns a list of Row objects. This list is guaranteed to have length one due to the nature of collect_set(~). The Row object contains the list, so we need to include another [0].
Getting a set of column values of each group in PySpark
The method collect_set(~) is often used in the context of aggregation. Consider the same PySpark DataFrame as before:
+-----+-----+
| name|group|
+-----+-----+
| Alex|    A|
| Alex|    B|
|  Bob|    A|
|Cathy|    C|
| Dave| null|
+-----+-----+
To flatten the group column into a single set for each name: