This is an extension of my post on Transforming Spark Datasets using Scala transformation functions. In the previous post we de-serialized a Spark Dataset to a Scala case class and learnt how to use Encoders to run transformations over the Dataset. In this post, we’ll explore how to transform a DataFrame using a User Defined Function - udf.
Expected Results
As with the previous post, here are the input DataFrame and the expected output DataFrame.
// Input DataFrame
+--------+----------+----------+
|flightNo| departure| arrival|
+--------+----------+----------+
| GOI-134|1569315038|1569319183|
| AI-123|1569290498|1569298178|
| TK-006|1567318178|1567351838|
+--------+----------+----------+
// Output DataFrame
+--------+--------+
|flightNo|duration|
+--------+--------+
| GOI-134| 1 hrs|
| AI-123| 2 hrs|
| TK-006| 9 hrs|
+--------+--------+
Let’s create our DataFrame
Spark defines a DataFrame as type DataFrame = Dataset[Row]; in essence, it’s a Dataset of a generic Row.
import spark.implicits._
val schedules = Seq(
  ("GOI-134", 1569315038, 1569319183),
  ("AI-123", 1569290498, 1569298178),
  ("TK-006", 1567318178, 1567351838)
).toDF("flightNo", "departure", "arrival")
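Because a DataFrame is a Dataset[Row], the values come back untyped: you read fields out of each Row by position or by column name, supplying the type yourself with no compile-time check. A minimal, self-contained sketch of this (the SparkSession setup and variable names here are mine, added for illustration):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("row-sketch").getOrCreate()
import spark.implicits._

val schedules = Seq(
  ("GOI-134", 1569315038, 1569319183),
  ("AI-123", 1569290498, 1569298178),
  ("TK-006", 1567318178, 1567351838)
).toDF("flightNo", "departure", "arrival")

// A DataFrame row is a generic Row; fields are fetched by
// position or by name, with the caller asserting the type.
val first = schedules.head()
val flightNo = first.getString(0)          // field by position
val arrival  = first.getAs[Int]("arrival") // field by column name
```

This is exactly the type-safety gap the previous post closed by mapping to a case class; here we stay untyped and lean on udfs instead.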
Define the UDF
We have to define our udf as a variable so that it too can be passed to functions. For this, we’ll need to import org.apache.spark.sql.functions.udf. Exactly like the previous post, our function will accept two Long parameters, the arrival time and the departure time, and return a String: the duration of the flight.
import org.apache.spark.sql.functions.udf
val getDurationInHours = udf((arrival: Long, departure: Long) => {
  val duration = (arrival - departure) / 60 / 60
  s"$duration hrs"
})
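Since the lambda inside udf is ordinary Scala, you can lift it out as a named function and unit-test the duration arithmetic without spinning up Spark at all. A quick sketch (the durationInHours name is mine, not from the post):

```scala
// Same logic as the udf body: epoch-second difference,
// truncated to whole hours by integer division.
def durationInHours(arrival: Long, departure: Long): String = {
  val duration = (arrival - departure) / 60 / 60
  s"$duration hrs"
}

// Checked against the expected output table above:
println(durationInHours(1569319183L, 1569315038L)) // GOI-134: "1 hrs"
println(durationInHours(1569298178L, 1569290498L)) // AI-123:  "2 hrs"
println(durationInHours(1567351838L, 1567318178L)) // TK-006:  "9 hrs"
```

The udf from the post could then wrap the named function directly, e.g. udf(durationInHours _), keeping the tested logic and the Spark plumbing separate.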
Transform the DataFrame
Now all that’s left is to transform the DataFrame. We’ll do this by calling the select function with the flightNo column and the udf with an alias of “duration”.
import org.apache.spark.sql.functions.col
val flightInfo = schedules
  .select(
    col("flightNo"),
    getDurationInHours(col("arrival"), col("departure")) as "duration")
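An equivalent way to reach the same two-column result, if you prefer not to enumerate the kept columns inside select, is to append the computed column with withColumn and then project. This is an alternative I’m sketching, not the post’s approach; shown self-contained so it runs on its own:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

val spark = SparkSession.builder().master("local[*]").appName("withColumn-sketch").getOrCreate()
import spark.implicits._

val schedules = Seq(
  ("GOI-134", 1569315038, 1569319183),
  ("AI-123", 1569290498, 1569298178),
  ("TK-006", 1567318178, 1567351838)
).toDF("flightNo", "departure", "arrival")

val getDurationInHours = udf((arrival: Long, departure: Long) => {
  s"${(arrival - departure) / 60 / 60} hrs"
})

// Append the derived column, then keep only the columns we want.
val flightInfo = schedules
  .withColumn("duration", getDurationInHours(col("arrival"), col("departure")))
  .select("flightNo", "duration")
```

Both forms produce the same plan shape; select is tidier when you are already listing columns, while withColumn reads better when you are adding one column to many existing ones.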
Source Code
Here is the entire source code.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

object DataFrameTransform {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("transform").getOrCreate()
    import spark.implicits._

    val schedules = Seq(
      ("GOI-134", 1569315038, 1569319183),
      ("AI-123", 1569290498, 1569298178),
      ("TK-006", 1567318178, 1567351838)
    ).toDF("flightNo", "departure", "arrival")

    // Print the input DataFrame
    schedules.show()

    // Define the User Defined Function (UDF)
    val getDurationInHours = udf((arrival: Long, departure: Long) => {
      val duration = (arrival - departure) / 60 / 60
      s"$duration hrs"
    })

    // Transform the DataFrame
    val flightInfo = schedules
      .select(
        col("flightNo"),
        getDurationInHours(col("arrival"), col("departure")) as "duration")

    // Print the output DataFrame
    flightInfo.show()

    spark.stop()
  }
}