Testing Spark DataFrame transforms is essential, and it can be done in a fairly reusable manner. The approach I generally take is to
- read the test and expected DataFrames,
- invoke the transform under test, and
- calculate the difference between the resulting and expected DataFrames. The only caveat in calculating the difference is that the built-in except function is not sufficient for columns with decimal or floating-point types, as the short sketch below shows, and handling those takes a bit of work.
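A minimal illustration of the problem, assuming a SparkSession ss with its implicits imported and made-up column names:

import ss.implicits._
val actual = Seq(("a", 1.0000001)).toDF("id", "amount")
val expected = Seq(("a", 1.0)).toDF("id", "amount")
//except does exact comparisons, so values that differ only by rounding noise
//are reported as different rows
actual.except(expected).count() //returns 1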
To accomplish a generic DataFrame comparison:
- look at the type of each column, and when it is numeric,
- convert the values to the corresponding Java type and compare them numerically, allowing for a configurable precision mismatch; otherwise,
- just use except for the remaining columns.
Comparison Code
//the helpers below live in a plain object (referenced as SomeObject in the test) and assume a SparkSession ss is in scope
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.types._

def compareDF(result: Dataset[Row], expected: Dataset[Row]): Unit = {
  val expectedSchemaMap = expected.schema.map(sf => (sf.name, sf.dataType)).toMap[String, DataType]
  val resSchemaMap = result.schema.map(sf => (sf.name, sf.dataType)).toMap[String, DataType]
  expectedSchemaMap.foreach {
    //numeric columns need a tolerant comparison
    case (name: String, dType: NumericType) =>
      assert(compareNumericTypes(result, expected, resSchemaMap(name), dType, name), s"$name column was not equal")
    //everything else can rely on except
    case (name: String, _) =>
      assert(result.select(name).except(expected.select(name)).count() == 0, s"$name column was not equal")
  }
}
def compareNumericTypes(result: Dataset[Row], expected: Dataset[Row], resType: DataType, expType: DataType, colName: String, precision: Double = 0.01): Boolean = {
  //collect results
  val res = extractAndSortNumericRow(result, colName, resType)
  val exp = extractAndSortNumericRow(expected, colName, expType)
  //compare lengths first
  if (res.length != exp.length) return false
  if (res.isEmpty) return true
  res match {
    case Seq(_: java.lang.Integer, _*) | Seq(_: java.lang.Long, _*) =>
      !res.zip(exp).exists(zipped => (safelyGet(zipped._1).longValue() - safelyGet(zipped._2).longValue()) != 0L)
    case Seq(_: java.lang.Float, _*) | Seq(_: java.lang.Double, _*) | Seq(_: java.math.BigDecimal, _*) =>
      !res.zip(exp).exists(zipped => (safelyGet(zipped._1).doubleValue() - safelyGet(zipped._2).doubleValue()).abs >= precision)
  }
}
//upcast Integer/Long to Long and Float/Double to Double before comparing
def safelyGet(v: Number): Number = {
  v match {
    case _: java.lang.Long | _: java.lang.Integer => java.lang.Long.parseLong(v.toString)
    case _: java.lang.Float | _: java.lang.Double => java.lang.Double.parseDouble(v.toString)
    case _ => v
  }
}
//map internal Spark types to Java types, then sort and collect the column so both sides can be zipped together
def extractAndSortNumericRow[T <: NumericType](df: Dataset[Row], colName: String, dt: T): Seq[Number] = {
  import ss.implicits._
  dt match {
    case _: LongType => df.select(colName).map(row => row.getAs[java.lang.Long](0)).sort("value").collect()
    case _: IntegerType => df.select(colName).map(row => row.getAs[java.lang.Integer](0)).sort("value").collect()
    case _: DoubleType => df.select(colName).map(row => row.getAs[java.lang.Double](0)).sort("value").collect()
    case _: FloatType => df.select(colName).map(row => row.getAs[java.lang.Float](0)).sort("value").collect()
    case _: DecimalType => df.select(colName).map(row => row.getAs[java.math.BigDecimal](0)).sort("value").collect()
  }
}
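With the helpers in place, here is a quick, hypothetical sketch of compareDF on two in-memory DataFrames (it assumes the same SparkSession ss and its implicits; the column names and values are made up):

import ss.implicits._
val produced = Seq((1, 10.001), (2, 20.004)).toDF("id", "amount")
val expected = Seq((2, 20.0), (1, 10.0)).toDF("id", "amount")
//each numeric column is sorted, collected and compared with the default 0.01 tolerance,
//so the differing row order and the tiny double differences do not trip the assertions
compareDF(produced, expected)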
The code above does the heavy lifting for comparing DataFrames. Now all we need is a simple function that invokes the transform, plus some simple ScalaTest code showing all of this in action.
The function that invokes the transform and does the comparison:
def invokeAndCompare(testFileName: String, expectedFileName: String, func: Dataset[Row] => Dataset[Row]): Unit = {
  val df = readJsonDF(testFileName)
  val expected = readJsonDF(expectedFileName)
  val transformResult = func(df)
  compareDF(transformResult, expected)
}
//ss.read.json expects JSON Lines by default (one JSON object per line)
def readJsonDF(fileName: String): Dataset[Row] = {
  ss.read.json(fileName)
}
Testing Code
Just utilize ScalaTest. Here is what a test looks like for your transforms.
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

class RandomTransformsTest extends FlatSpec with Matchers with BeforeAndAfter {
  var ss: SparkSession = _

  before {
    ss = SparkSession.builder().master("local[*]").getOrCreate()
  }

  after {
    //close the spark session
    ss.close()
  }

  "testRandomTransform" should "give correct output for input dataframe" in {
    val testFileLoc = ""
    val expectedFileLoc = ""
    //just grab the function definition; invokeAndCompare will apply it to the dataframe later on
    val func = RandomTransforms.someRandomFunc() _
    SomeObject.invokeAndCompare(testFileLoc, expectedFileLoc, func)
  }
}
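For reference, the curried shape that makes the someRandomFunc() _ eta-expansion above compile could look something like the sketch below; RandomTransforms, someRandomFunc and the column names are hypothetical placeholders:

import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.col

object RandomTransforms {
  //the empty first parameter list lets the test partially apply the function with
  //someRandomFunc() _ and pass the resulting Dataset[Row] => Dataset[Row] to invokeAndCompare
  def someRandomFunc()(df: Dataset[Row]): Dataset[Row] =
    df.withColumn("amount_with_tax", col("amount") * 1.1)
}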
Wrap Up:
So there we go: testing made easy for Spark DataFrames. It requires some tedious mapping for numeric and decimal columns, but once that is in place, tests are easy to write for all your DataFrame transforms.