Testing Spark DataFrame transforms is essential, and with a little setup it can be done in a reusable way. The way I generally go about it is to:
- Read the test and expected DataFrames,
- Invoke the desired transform, and
- Calculate the difference between the two DataFrames. The only caveat is that the built-in except function is not sufficient for columns with decimal types, and handling those requires a bit of work (see the sketch just below).
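To see why that caveat matters, here is a minimal sketch (the column names and values are made up): except does an exact set difference, so two decimal values that are equal for any reasonable precision still show up as a difference.

import org.apache.spark.sql.SparkSession

val ss = SparkSession.builder().master("local[*]").getOrCreate()
import ss.implicits._

val result = Seq(("a", BigDecimal("10.0000001"))).toDF("id", "amount")
val expected = Seq(("a", BigDecimal("10.0000002"))).toDF("id", "amount")

//the values differ only past the precision we care about,
//yet except still reports a difference
result.except(expected).count() // 1, not 0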
To accomplish a generic DataFrame comparison:
- Look at the type of each column and, when it is numeric,
- Convert it to the corresponding Java type and compare the values numerically, allowing for a configurable precision tolerance. Otherwise,
- Just use the except clause for all other column comparisons.
Comparison Code
def compareDF(result: Dataset[Row], expected: Dataset[Row]): Unit = {
  val expectedSchemaMap = expected.schema.map(sf => (sf.name, sf.dataType)).toMap[String, DataType]
  val resSchemaMap = result.schema.map(sf => (sf.name, sf.dataType)).toMap[String, DataType]
  //compare column by column, driven by the expected schema
  expectedSchemaMap.foreach {
    case (name: String, dType: NumericType) =>
      //numeric columns need a tolerance-aware comparison
      assert(compareNumericTypes(result, expected, resSchemaMap(name), dType, name), s"$name column was not equal")
    case (name: String, _) =>
      //all other columns can rely on the built-in except
      assert(result.select(name).except(expected.select(name)).count() == 0, s"$name column was not equal")
  }
}
def compareNumericTypes(result: Dataset[Row], expected: Dataset[Row], resType: DataType, expType: DataType, colName: String, precision: Double = 0.01): Boolean = {
  //collect the sorted values of the column from both dataframes
  val res = extractAndSortNumericRow(result, colName, resType)
  val exp = extractAndSortNumericRow(expected, colName, expType)
  //compare lengths first
  if (res.length != exp.length) return false
  if (res.isEmpty) return true
  res match {
    case Seq(_: java.lang.Integer, _*) | Seq(_: java.lang.Long, _*) =>
      //integral types must match exactly
      !res.zip(exp).exists(zipped => (safelyGet(zipped._1).longValue() - safelyGet(zipped._2).longValue()) != 0L)
    case Seq(_: java.lang.Float, _*) | Seq(_: java.lang.Double, _*) | Seq(_: java.math.BigDecimal, _*) =>
      //fractional and decimal types are compared within the given precision
      !res.zip(exp).exists(zipped => (safelyGet(zipped._1).doubleValue() - safelyGet(zipped._2).doubleValue()).abs >= precision)
  }
}
//normalize boxed numeric values: integral types to Long, fractional types to Double
def safelyGet(v: Number): Number = {
  v match {
    case _: java.lang.Long | _: java.lang.Integer => java.lang.Long.parseLong(v.toString)
    case _: java.lang.Float | _: java.lang.Double => java.lang.Double.parseDouble(v.toString)
    //BigDecimal and anything else passes through unchanged
    case _ => v
  }
}
//map internal Spark column types to the corresponding Java types and return the column values in sorted order
def extractAndSortNumericRow[T <: NumericType](df: Dataset[Row], colName: String, dt: T): Seq[Number] = {
  //ss is the enclosing SparkSession
  import ss.implicits._
  dt match {
    case _: LongType => df.select(colName).sort(colName).map(row => row.getAs[java.lang.Long](0)).collect()
    case _: IntegerType => df.select(colName).sort(colName).map(row => row.getAs[java.lang.Integer](0)).collect()
    case _: DoubleType => df.select(colName).sort(colName).map(row => row.getAs[java.lang.Double](0)).collect()
    case _: FloatType => df.select(colName).sort(colName).map(row => row.getAs[java.lang.Float](0)).collect()
    case _: DecimalType => df.select(colName).sort(colName).map(row => row.getAs[java.math.BigDecimal](0)).collect()
  }
}
The code above does the heavy lifting of comparing DataFrames. Now all we need is a simple function that invokes the transform and does the comparison, plus some simple ScalaTest code showing all of this in action.
Function that invokes the transform and does the comparison:
def invokeAndCompare(testFileName: String, expectedFileName: String, func: Dataset[Row] => Dataset[Row]): Unit = {
  val df = readJsonDF(testFileName)
  val expected = readJsonDF(expectedFileName)
  val transformResult = func(df)
  compareDF(transformResult, expected)
}
def readJsonDF(fileName: String): Dataset[Row] = {
  //ss is the active SparkSession; expects JSON Lines input
  ss.read.json(fileName)
}
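For context, the func argument is just any Dataset[Row] => Dataset[Row]. Here is a minimal sketch of what the RandomTransforms object used in the test below might look like; the body of someRandomFunc, and the price/quantity columns it assumes, are purely illustrative.

import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions._

object RandomTransforms {
  //illustrative transform: derive a total column from two assumed input columns
  def someRandomFunc()(df: Dataset[Row]): Dataset[Row] = {
    df.withColumn("total", col("price") * col("quantity"))
  }
}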
Testing Code
Just use ScalaTest. Here is what a test looks like for your transforms.
class RandomTransformsTest extends FlatSpec with Matchers with BeforeAndAfter {

  var ss: SparkSession = _

  before {
    ss = SparkSession.builder().master("local[*]").getOrCreate()
  }

  after {
    //close the spark session after each test
    ss.close()
  }

  "testRandomTransform" should "give correct output for input dataframe" in {
    val testFileLoc = ""
    val expectedFileLoc = ""
    //just grab the function definition; invokeAndCompare will apply it to the dataframe later on
    val func = RandomTransforms.someRandomFunc() _
    SomeObject.invokeAndCompare(testFileLoc, expectedFileLoc, func)
  }
}
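One note on fixtures: ss.read.json expects JSON Lines, i.e. one JSON object per line. The test and expected files could look something like this (the file names and fields are just placeholders):

test_input.json
{"id": "a", "price": 2.5, "quantity": 4}
{"id": "b", "price": 1.0, "quantity": 3}

expected_output.json
{"id": "a", "price": 2.5, "quantity": 4, "total": 10.0}
{"id": "b", "price": 1.0, "quantity": 3, "total": 3.0}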
Wrap Up:
So, there we go: testing made easy for Spark DataFrames. It requires some tedious mapping for decimal and other numeric columns, but once that is developed, tests are easy to write for all of your DataFrame transforms.