Jan 17, 2018

Apache Spark lets you write generic UDFs. For an idiomatic implementation, however, there are a couple of things to keep in mind:
  1. Return an Option. Spark automatically treats None as null and extracts the value from Some, so an Option return type expresses a nullable result directly.
  2. Accept either an Option or a plain value as input. To do this, pattern match on Option and recursively extract the wrapped value. This situation occurs when your UDF function is in turn called from another UDF, which may hand it an Option instead of a plain value.
If these considerations are handled correctly, the UDF has two important benefits:
  • It avoids code duplication, because one function serves every supported input type, and
  • It handles nulls more idiomatically.
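As a minimal, Spark-free sketch of the two rules above, the hypothetical helper below converts its input to an Int. It returns an Option (rule 1) and unwraps Option inputs recursively (rule 2), so it gives the same answer whether it receives a plain value, a null, or the Option result of another function. The name `asInt` and the set of accepted input types are illustrative assumptions, not part of any Spark API.

```scala
import scala.util.Try

object GenericOptionSketch {
  // Hypothetical generic helper: accepts a plain value, null, or an Option wrapping either.
  def asInt[T](value: T): Option[Int] = {
    if (value == null) return None // a SQL NULL typically reaches Scala code as a plain null
    value match {
      case i: Int       => Some(i)
      case s: String    => Try(s.trim.toInt).toOption
      // Rule 2: input produced by another Option-returning function; unwrap and recurse.
      case o: Option[_] => o.flatMap(v => asInt(v))
      case _            => None
    }
  }

  def main(args: Array[String]): Unit = {
    println(asInt("42"))            // Some(42)
    println(asInt(Some(Some("7")))) // Some(7)
    println(asInt(None))            // None
    println(asInt(null: String))    // None
  }
}
```

Because the Option case simply delegates back to `asInt`, supporting a new input type only needs one new case, however deeply the value happens to be wrapped.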
Here is an example of a UDF that calculates the interval between two dates in a given unit.
import java.time.{LocalDate, ZoneId}
import java.time.format.DateTimeFormatter
import java.time.temporal.ChronoUnit

import scala.util.Try

// Converts any supported date representation (or an Option wrapping one) to Option[LocalDate].
def convertToDate[T](date: T): Option[LocalDate] = {
  if (date == null) return None
  date match {
    case dt: LocalDate => Some(dt)
    case dt: String =>
      if (dt.isEmpty) return None
      // try an ISO date first, then fall back to an ISO local date-time
      val retValue = Try {
        LocalDate.parse(dt, DateTimeFormatter.ISO_DATE)
      }.getOrElse(LocalDate.parse(dt, DateTimeFormatter.ISO_LOCAL_DATE_TIME))
      Some(retValue)
    // java.sql.Date must precede java.util.Date: it is a subclass whose toInstant throws
    case dt: java.sql.Date => Some(dt.toLocalDate)
    case dt: java.util.Date => Some(dt.toInstant.atZone(ZoneId.systemDefault()).toLocalDate)
    // value produced by another UDF: unwrap the Option and recurse
    case dt: Option[_] => if (dt.isDefined) convertToDate(dt.get) else None
    case other => throw new IllegalArgumentException(s"${other.getClass.getName} is not a supported date type")
  }
}

def interval_between[V1, V2](fromDate: V1, toDate: V2, intType: String): Option[Long] = {
  def calculateInterval(fromDate: LocalDate, toDate: LocalDate, intType: String = "months"): Option[Long] = {
    // LocalDate has no time of day, so time-based units are measured from midnight
    val fromStart = fromDate.atStartOfDay()
    val toStart = toDate.atStartOfDay()
    val returnVal = intType match {
      case "decades" => ChronoUnit.DECADES.between(fromDate, toDate)
      case "years" => ChronoUnit.YEARS.between(fromDate, toDate)
      case "months" => ChronoUnit.MONTHS.between(fromDate, toDate)
      case "days" => ChronoUnit.DAYS.between(fromDate, toDate)
      case "hours" => ChronoUnit.HOURS.between(fromStart, toStart)
      case "minutes" => ChronoUnit.MINUTES.between(fromStart, toStart)
      case "seconds" => ChronoUnit.SECONDS.between(fromStart, toStart)
      case _ => throw new IllegalArgumentException(s"$intType is not supported")
    }
    Some(returnVal)
  }

  val fromDt = convertToDate(fromDate)
  val toDt = convertToDate(toDate)
  if (fromDt.isEmpty || toDt.isEmpty) {
    return None
  }
  calculateInterval(fromDt.get, toDt.get, intType.toLowerCase)
}
The UDF above addresses both concerns raised earlier in the post. To use it, simply register it as a UDF with your SparkSession, here named ss:
 ss.udf.register("interval_between", interval_between _)
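One java.time detail matters for the finer-grained units in `interval_between`: a `LocalDate` has no time of day, so `ChronoUnit.HOURS.between` (likewise MINUTES and SECONDS) on two `LocalDate` values throws `UnsupportedTemporalTypeException` instead of returning a count. The sketch below converts both dates with `atStartOfDay()` first, assuming midnight as the time of day; the object and method names are hypothetical.

```scala
import java.time.LocalDate
import java.time.temporal.{ChronoUnit, UnsupportedTemporalTypeException}
import scala.util.Try

object IntervalUnitsDemo {
  // Date-based units work directly on LocalDate and count complete units only.
  def months(from: LocalDate, to: LocalDate): Long = ChronoUnit.MONTHS.between(from, to)

  // Time-based units need a time of day; this sketch assumes midnight for both dates.
  def hours(from: LocalDate, to: LocalDate): Long =
    ChronoUnit.HOURS.between(from.atStartOfDay(), to.atStartOfDay())

  // True when java.time rejects an hour count taken directly on LocalDate values.
  def hoursRejectedOnLocalDate(from: LocalDate, to: LocalDate): Boolean =
    Try(ChronoUnit.HOURS.between(from, to)).failed.toOption
      .exists(_.isInstanceOf[UnsupportedTemporalTypeException])

  def main(args: Array[String]): Unit = {
    val from = LocalDate.of(2018, 1, 17)
    val to = LocalDate.of(2018, 3, 16)
    println(months(from, to))                   // 1
    println(hours(from, to))                    // 1392
    println(hoursRejectedOnLocalDate(from, to)) // true
  }
}
```

Note that `between` counts only complete units, so 17 January to 16 March 2018 is one month rather than two.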

Posted on Wednesday, January 17, 2018 by Ramandeep Singh Nanda

Testing Spark dataframe transforms is essential, and it can be done in a reusable manner. This is how I generally do it:
  • Read the test and expected dataframes,
  • Invoke the desired transform, and
  • Calculate the difference between the result and the expected dataframe. The only caveat is that the built-in except function is not sufficient for columns holding decimal values: it requires exact equality, so values that differ only by rounding are reported as mismatches. Handling them takes a bit of work.
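To see why an exact difference falls short for floating-point columns, the Spark-free sketch below mimics `except` (which, like SQL `EXCEPT`, returns the distinct rows of the left side that do not appear on the right) on plain Scala sequences. The helper name `exceptLike` is hypothetical, and the sample values are made up.

```scala
object ExceptSketch {
  // Hypothetical stand-in for Dataset.except: distinct left values missing from the right.
  def exceptLike[A](left: Seq[A], right: Seq[A]): Seq[A] = {
    val rightSet = right.toSet
    left.distinct.filterNot(rightSet.contains)
  }

  def main(args: Array[String]): Unit = {
    val result = Seq(0.1 + 0.2, 1.5) // values computed by a transform
    val expected = Seq(0.3, 1.5)     // values read from an expected-output file
    // 0.1 + 0.2 is 0.30000000000000004 in IEEE 754 doubles, so exact matching flags it
    println(exceptLike(result, expected)) // List(0.30000000000000004)
    // one-directional: a result that lacks the expected 0.3 still yields no difference
    println(exceptLike(Seq(1.5), expected)) // List()
  }
}
```

The second call also shows that a single `except` only detects extra rows on the left side, so a thorough comparison checks both directions.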
To compare dataframes generically:
  • Look at the type of each column, and when it is numeric,
  • Convert it to the corresponding Java type and compare the values numerically, allowing for a configurable precision tolerance. Otherwise,
  • Just use except to compare the other columns.
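The numeric branch of this approach can be sketched without Spark: sort both sides, check the lengths, then compare pairwise, exactly for integral values and within a tolerance otherwise. The function `numbersMatch` is hypothetical, and its default tolerance of 0.01 simply mirrors the `precision` default in `compareNumericTypes`. It works on `java.lang.Number`, which the boxed values collected from numeric columns extend; nulls are not handled in this sketch.

```scala
object NumericCompareSketch {
  // Order values so pairwise comparison does not depend on row order (nulls not handled).
  private def sorted(xs: Seq[Number]): Seq[Number] = xs.sortBy(_.doubleValue())

  private def isIntegral(n: Number): Boolean =
    n.isInstanceOf[java.lang.Integer] || n.isInstanceOf[java.lang.Long]

  // Hypothetical helper: tolerance-aware comparison of the values of two numeric columns.
  def numbersMatch(result: Seq[Number], expected: Seq[Number], precision: Double = 0.01): Boolean = {
    if (result.length != expected.length) return false
    sorted(result).zip(sorted(expected)).forall {
      // integral values must match exactly
      case (r, e) if isIntegral(r) && isIntegral(e) => r.longValue() == e.longValue()
      // decimals are compared without going through double
      case (r: java.math.BigDecimal, e: java.math.BigDecimal) =>
        r.subtract(e).abs().compareTo(java.math.BigDecimal.valueOf(precision)) < 0
      // everything else: absolute difference below the tolerance
      case (r, e) => math.abs(r.doubleValue() - e.doubleValue()) < precision
    }
  }

  def main(args: Array[String]): Unit = {
    val result: Seq[Number] = Seq(java.lang.Double.valueOf(0.1 + 0.2), java.lang.Double.valueOf(2.0))
    val expected: Seq[Number] = Seq(java.lang.Double.valueOf(2.0), java.lang.Double.valueOf(0.3))
    println(numbersMatch(result, expected)) // true
    println(numbersMatch(Seq(java.lang.Long.valueOf(3L)), Seq(java.lang.Integer.valueOf(4)))) // false
  }
}
```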

Comparison Code

import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

def compareDF(result: Dataset[Row], expected: Dataset[Row]): Unit = {
  val expectedSchemaMap = expected.schema.map(sf => (sf.name, sf.dataType)).toMap[String, DataType]
  val resSchemaMap = result.schema.map(sf => (sf.name, sf.dataType)).toMap[String, DataType]
  expectedSchemaMap.foreach {
    // numeric columns: compare values while allowing for precision mismatches
    case (name: String, dType: NumericType) =>
      assert(compareNumericTypes(result, expected, resSchemaMap(name), dType, name), s"$name column was not equal")
    // all other columns: except must find no difference in either direction
    case (name: String, _) =>
      val res = result.select(name)
      val exp = expected.select(name)
      assert(res.except(exp).count() == 0 && exp.except(res).count() == 0, s"$name column was not equal")
  }
}

def compareNumericTypes(result: Dataset[Row], expected: Dataset[Row], resType: DataType, expType: DataType,
                        colName: String, precision: Double = 0.01): Boolean = {
  //collect Results
  val res = extractAndSortNumericRow(result, colName, resType)
  val exp = extractAndSortNumericRow(expected, colName, expType)
  //compare lengths first
  if (res.length != exp.length) return false
  res.zip(exp).forall {
    case (null, null) => true
    case (null, _) | (_, null) => false
    case (r, e) => (safelyGet(r), safelyGet(e)) match {
      case (a: java.lang.Long, b: java.lang.Long) => a.longValue() == b.longValue()
      case (a: java.math.BigDecimal, b: java.math.BigDecimal) =>
        a.subtract(b).abs().compareTo(java.math.BigDecimal.valueOf(precision)) < 0
      case (a, b) => (a.doubleValue() - b.doubleValue()).abs < precision
    }
  }
}

//upcast types: integral values to java.lang.Long, floating-point values to java.lang.Double
def safelyGet(v: Number): Number = v match {
  case _: java.lang.Long | _: java.lang.Integer => java.lang.Long.valueOf(v.longValue())
  case _: java.lang.Float | _: java.lang.Double => java.lang.Double.valueOf(v.toString)
  case _ => v
}

//map internal spark types to java types, sorted so both sides line up
def extractAndSortNumericRow(df: Dataset[Row], colName: String, dt: DataType): Seq[Number] = {
  val rows = df.select(colName).orderBy(col(colName)).collect().toSeq
  dt match {
    case _: LongType => rows.map(row => row.getAs[java.lang.Long](0))
    case _: IntegerType => rows.map(row => row.getAs[java.lang.Integer](0))
    case _: DoubleType => rows.map(row => row.getAs[java.lang.Double](0))
    case _: FloatType => rows.map(row => row.getAs[java.lang.Float](0))
    case _: DecimalType => rows.map(row => row.getAs[java.math.BigDecimal](0))
    case other => throw new IllegalArgumentException(s"Unsupported numeric type: $other")
  }
}
The code above does the heavy lifting of comparing dataframes. Now all we need is a simple function that invokes the transform, plus some ScalaTest code showing all of this in action.
Function that invokes the transform and does the comparison:
def invokeAndCompare(testFileName: String, expectedFileName: String, func: Dataset[Row] => Dataset[Row]): Unit = {
  val df = readJsonDF(testFileName)
  val expected = readJsonDF(expectedFileName)
  val transformResult = func(df)
  compareDF(transformResult, expected)
}

def readJsonDF(fileName: String): Dataset[Row] = {
  // assumes a SparkSession named ss is in scope; reads JSON Lines (one object per line)
  ss.read.json(fileName)
}
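Because `invokeAndCompare` only needs a way to load data, a transform, and a comparison, the same shape can be exercised without Spark. The sketch below is a hypothetical generalization over any data type `A`: an in-memory `Map` stands in for `readJsonDF`, the file names and data are made up, and it returns a Boolean instead of asserting so the result is easy to inspect.

```scala
object InvokeAndCompareSketch {
  // Hypothetical generalization: load input and expected data, run the transform, compare.
  def invokeAndCompare[A](load: String => A, compare: (A, A) => Boolean)
                         (testName: String, expectedName: String, func: A => A): Boolean =
    compare(func(load(testName)), load(expectedName))

  def main(args: Array[String]): Unit = {
    // In-memory stand-in for files on disk (made-up names and data); a Map is a String => A.
    val files = Map("input.json" -> Seq(1, 2, 3), "expected.json" -> Seq(2, 4, 6))
    val doubleAll: Seq[Int] => Seq[Int] = _.map(_ * 2)
    val ok = invokeAndCompare[Seq[Int]](files, _.sorted == _.sorted)("input.json", "expected.json", doubleAll)
    println(ok) // true
  }
}
```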

Testing Code

Just use ScalaTest. Here is what a test for your transforms looks like:
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

class RandomTransformsTest extends FlatSpec with Matchers with BeforeAndAfter {
  var ss: SparkSession = _

  before {
    ss = SparkSession.builder().master("local[*]").getOrCreate()
  }

  after {
    //close spark session
    ss.stop()
  }

  "testRandomTransform" should "give correct output for input dataframe" in {
    val testFileLoc = ""
    val expectedFileLoc = ""
    //just get the function definition, it will be invoked by invokeAndCompare with the dataframe later on.
    val func = RandomTransforms.someRandomFunc() _
    SomeObject.invokeAndCompare(testFileLoc, expectedFileLoc, func)
  }
}
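The line `val func = RandomTransforms.someRandomFunc() _` relies on eta-expansion: if the transform is written as a curried method whose last parameter list takes the dataframe, applying only the first list and adding `_` yields a function value that can be passed along and applied later. The original does not show `someRandomFunc`, so the sketch below assumes that shape; `RandomTransformsSketch` is a hypothetical name, and `Seq[Int]` stands in for `Dataset[Row]`.

```scala
object RandomTransformsSketch {
  // Hypothetical curried transform: an (empty) configuration list first, the data last.
  def someRandomFunc()(df: Seq[Int]): Seq[Int] = df.map(_ * 2)
}

object EtaExpansionSketch {
  def main(args: Array[String]): Unit = {
    // Apply the first parameter list only; `_` turns the remaining list into a function value.
    val func: Seq[Int] => Seq[Int] = RandomTransformsSketch.someRandomFunc() _
    println(func(Seq(1, 2, 3))) // List(2, 4, 6)
  }
}
```

Keeping configuration in the first parameter list and the data in the last is what lets a test hand the transform to `invokeAndCompare` without running it.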


Wrap Up:

So there we go: testing made easy for Spark dataframes. It requires some tedious type mapping for decimal numbers, but once that is in place, tests are easy to write for all your dataframe transforms.

Posted on Wednesday, January 17, 2018 by Ramandeep Singh Nanda