Iterating over each row of a PySpark DataFrame
Iterating over a PySpark DataFrame is tricky because of its distributed nature - the data of a PySpark DataFrame is typically scattered across multiple worker nodes. This guide explores three solutions for iterating over each row, but I recommend opting for the first solution!
Using the map method of RDD to iterate over the rows of PySpark DataFrame
All Spark DataFrames are internally represented using Spark's built-in data structure called RDD (resilient distributed dataset). One way of iterating over the rows of a PySpark DataFrame is to use the map(~) method, which is available only to RDDs - we therefore need to convert the PySpark DataFrame into an RDD first.
As an example, consider the following PySpark DataFrame:
+-----+---+
| name|age|
+-----+---+
| Alex| 15|
|  Bob| 20|
|Cathy| 25|
+-----+---+
We can iterate over each row of this PySpark DataFrame like so:
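Below is a minimal sketch of this approach - the way the DataFrame is constructed and the specific update applied inside my_func (adding 5 to each age) are illustrative assumptions:

```python
from pyspark.sql import SparkSession, Row

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([["Alex", 15], ["Bob", 20], ["Cathy", 25]], ["name", "age"])

def my_func(row):
    # Row objects are immutable, so convert the Row into a dictionary first
    d = row.asDict()
    # update the dictionary (adding 5 to age is only an illustrative choice)
    d["age"] = d["age"] + 5
    # convert the updated dictionary back into a Row object
    return Row(**d)

# convert the DataFrame into an RDD, apply my_func to every Row, then convert back
df_new = df.rdd.map(my_func).toDF()
df_new.show()
```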
Here, note the following:
- the conversion from PySpark DataFrame to RDD is simple - df.rdd.
- we then use the map(~) method of the RDD, which takes in as argument a function. This function takes as input a single Row object and is invoked for each row of the PySpark DataFrame.
- in the first line of our custom function my_func(~), we convert the Row into a dictionary using asDict(). The reason for this is that we cannot mutate the Row object directly - and so we must convert the Row object into a dictionary, then perform an update on the dictionary, and then finally convert the updated dictionary back to a Row object.
- the ** in Row(**d) converts the dictionary d into keyword arguments for the Row(~) constructor.
Unlike the other solutions that will be discussed below, this solution allows us to update the values of each row while we iterate over the rows.
Using the collect method and then iterating in the driver node
Another solution is to use the collect(~) method to push all the data from the worker nodes to the driver program, and then iterate over the rows.
As an example, consider the following PySpark DataFrame:
+-----+---+
| name|age|
+-----+---+
| Alex| 20|
|  Bob| 30|
|Cathy| 40|
+-----+---+
We can use the collect(~) method to first send all the data from the worker nodes to the driver program, and then perform a simple for-loop:
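A minimal sketch of this approach is shown below; the DataFrame construction and the choice of what to print for each row are illustrative assumptions:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([["Alex", 20], ["Bob", 30], ["Cathy", 40]], ["name", "age"])

# collect() pulls every Row from the worker nodes into a Python list on the driver
for row in df.collect():
    # each element is a Row object whose fields can be accessed as attributes
    print(row.name, row.age)
```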
Watch out for the following limitations:
- since the collect(~) method will send all the data to the driver node, make sure that your driver node has enough memory to avoid an out-of-memory error.
- we cannot update the values of the rows while we iterate.
Using foreach to iterate over the rows in the worker nodes
The foreach(~) method instructs the worker nodes in the cluster to iterate over each row (as a Row object) of a PySpark DataFrame, applying a function to each row on the worker node hosting it:
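A minimal sketch, assuming the same DataFrame as above and using print as the applied function (the function name my_func is an illustrative choice):

```python
def my_func(row):
    # this print runs on the worker node that holds the row, so the output
    # shows up in that worker's logs rather than in the driver program
    print(row)

# apply my_func to every Row of the DataFrame on the worker nodes
df.foreach(my_func)
```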
Here, the printed results will only be displayed in the standard output of the worker node instead of the driver program.
The following are some hard limitations of foreach(~) imposed by Spark:
- the row is read-only. This means that you cannot update the row values while iterating.
- since the worker nodes are performing the iteration and not the driver program, standard output/error will not be shown in our session/notebook. For instance, performing a print(~) as we have done in our function will not display the printed results in our session/notebook - instead we would need to check the logs of the worker nodes.
Given such limitations, one of the main use cases of foreach(~) is to log - either to a file or an external database - the rows of the PySpark DataFrame.
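For instance, a hypothetical sketch of logging each row to a file on the worker nodes might look like the following (the file path and the choice to log locally rather than to an external database are assumptions for illustration):

```python
def log_row(row):
    # hypothetical example: append each row to a local file on the worker node
    # that processes it (a real pipeline might write to an external database instead)
    with open("/tmp/row_log.txt", "a") as f:
        f.write(str(row.asDict()) + "\n")

df.foreach(log_row)
```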