From 88c71afe1e4e7e26838128b5bc3f109af1f04f8a Mon Sep 17 00:00:00 2001 From: Martin Mauch Date: Thu, 9 Nov 2017 13:35:51 +0100 Subject: [PATCH] Add relative tolerance checks to approxEquals Fixes #214 --- .../spark/testing/DataFrameSuiteBase.scala | 78 ++++++++++++++++--- .../spark/testing/SampleDataFrameTest.scala | 45 ++++++++++- 2 files changed, 109 insertions(+), 14 deletions(-) diff --git a/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala b/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala index 3c3f0a99..2f97df8e 100644 --- a/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala +++ b/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala @@ -129,7 +129,11 @@ trait DataFrameSuiteBaseLike extends SparkContextProvider * @param tol max acceptable tolerance, should be less than 1. */ def assertDataFrameApproximateEquals( - expected: DataFrame, result: DataFrame, tol: Double) { + expected: DataFrame, + result: DataFrame, + tol: Double = 0.0, + relTol: Double = 0.0 + ) { assert(expected.schema, result.schema) @@ -143,7 +147,7 @@ trait DataFrameSuiteBaseLike extends SparkContextProvider val unequalRDD = expectedIndexValue.join(resultIndexValue). filter{case (idx, (r1, r2)) => - !(r1.equals(r2) || DataFrameSuiteBase.approxEquals(r1, r2, tol))} + !(r1.equals(r2) || DataFrameSuiteBase.approxEquals(r1, r2, tol, relTol))} assertEmpty(unequalRDD.take(maxUnequalRowsToShow)) } finally { @@ -161,15 +165,67 @@ trait DataFrameSuiteBaseLike extends SparkContextProvider rdd.zipWithIndex().map{ case (row, idx) => (idx, row) } } - def approxEquals(r1: Row, r2: Row, tol: Double): Boolean = { - DataFrameSuiteBase.approxEquals(r1, r2, tol) + def approxEquals( + r1: Row, + r2: Row, + tol: Double = 0.0, + relTol: Double = 0.0 + ): Boolean = { + DataFrameSuiteBase.approxEquals(r1, r2, tol, relTol) } } object DataFrameSuiteBase { + trait WithinToleranceChecker { + def apply(a: Double, b: Double): Boolean + def apply(a: BigDecimal, b: BigDecimal): Boolean + } + object WithinToleranceChecker { + def apply(tol: Double = 0.0, relTol: Double = 0.0) = + if(tol != 0.0 || relTol == 0.0) { + new WithinAbsoluteToleranceChecker(tol) + } else { + new WithinRelativeToleranceChecker(relTol) + } + } + + class WithinAbsoluteToleranceChecker(tolerance: Double) + extends WithinToleranceChecker { + def apply(a: Double, b: Double): Boolean = + (a - b).abs <= tolerance + def apply(a: BigDecimal, b: BigDecimal): Boolean = + (a - b).abs <= tolerance + } + + class WithinRelativeToleranceChecker(relativeTolerance: Double) + extends WithinToleranceChecker { + def apply(a: Double, b: Double): Boolean = { + val max = (a.abs max b.abs) + if (max == 0.0) { + true + } else { + (a - b).abs / max <= relativeTolerance + } + } + def apply(a: BigDecimal, b: BigDecimal): Boolean = { + val max = (a.abs max b.abs) + if (max == 0.0) { + true + } else { + (a - b).abs / max <= relativeTolerance + } + } + } /** Approximate equality, based on equals from [[Row]] */ - def approxEquals(r1: Row, r2: Row, tol: Double): Boolean = { + def approxEquals( + r1: Row, + r2: Row, + tol: Double = 0.0, + relTol: Double = 0.0 + ): Boolean = { + val withinTolerance = WithinToleranceChecker(tol, relTol) + if (r1.length != r2.length) { return false } else { @@ -193,7 +249,7 @@ object DataFrameSuiteBase { { return false } - if (abs(f1 - o2.asInstanceOf[Float]) > tol) { + if (!withinTolerance(f1, o2.asInstanceOf[Float])) { return false } @@ -203,18 +259,20 @@ object DataFrameSuiteBase { { return false } - if (abs(d1 - o2.asInstanceOf[Double]) > tol) { + if (!withinTolerance(d1, o2.asInstanceOf[Double])) { return false } case d1: java.math.BigDecimal => - if (d1.subtract(o2.asInstanceOf[java.math.BigDecimal]).abs - .compareTo(new java.math.BigDecimal(tol)) > 0) { + if (!withinTolerance( + BigDecimal(d1), + BigDecimal(o2.asInstanceOf[java.math.BigDecimal] + ))) { return false } case d1: scala.math.BigDecimal => - if ((d1 - o2.asInstanceOf[scala.math.BigDecimal]).abs > tol) { + if (!withinTolerance(d1, o2.asInstanceOf[scala.math.BigDecimal])) { return false } diff --git a/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala b/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala index 1ca7faae..4e6ec4b4 100644 --- a/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala +++ b/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala @@ -21,6 +21,7 @@ import java.sql.Timestamp import org.apache.spark.sql.Row import org.apache.spark.sql.types._ import org.scalatest.FunSuite +import java.math.{ BigDecimal => JBigDecimal } class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase { val byteArray = new Array[Byte](1) @@ -70,10 +71,10 @@ class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase { val row8 = Row(Timestamp.valueOf("2018-01-12 20:22:13")) val row9 = Row(Timestamp.valueOf("2018-01-12 20:22:18")) val row10 = Row(Timestamp.valueOf("2018-01-12 20:23:13")) - val row11 = Row(new java.math.BigDecimal(1.0)) - val row11a = Row(new java.math.BigDecimal(1.0 + 1.0E-6)) - val row12 = Row(scala.math.BigDecimal(1.0)) - val row12a = Row(scala.math.BigDecimal(1.0 + 1.0E-6)) + val row11 = Row(new JBigDecimal(1.0)) + val row11a = Row(new JBigDecimal(1.0 + 1.0E-6)) + val row12 = Row(BigDecimal(1.0)) + val row12a = Row(BigDecimal(1.0 + 1.0E-6)) assert(false === approxEquals(row, row2, 1E-7)) assert(true === approxEquals(row, row2, 1E-5)) assert(true === approxEquals(row3, row3, 1E-5)) @@ -92,6 +93,42 @@ class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase { assert(true === approxEquals(row12, row12a, 1.0E-6)) } + test("dataframe approxEquals on rows with relative tolerance") { + import sqlContext.implicits._ + // Use 1 / 2^n as example numbers to avoid numeric errors + val relTol = scala.math.pow(2, -6) + val orig = 0.25 + val within = orig - relTol * orig + val outside = within - 1.0E-4 + def assertRelativeApproxEqualsWorksFor[T](constructor: Double => T) = { + assertResult(true) { + approxEquals( + Row(constructor(orig)), + Row(constructor(within)), + relTol = relTol + ) + } + assertResult(false) { + approxEquals( + Row(constructor(orig)), + Row(constructor(outside)), + relTol = relTol + ) + } + assertResult(true) { + approxEquals( + Row(constructor(0.0)), + Row(constructor(0.0)), + relTol = relTol + ) + } + } + assertRelativeApproxEqualsWorksFor[Double](identity) + assertRelativeApproxEqualsWorksFor[Float](_.toFloat) + assertRelativeApproxEqualsWorksFor[BigDecimal](BigDecimal.apply) + assertRelativeApproxEqualsWorksFor[JBigDecimal](new JBigDecimal(_)) + } + test("verify hive function support") { import sqlContext.implicits._ // Convert to int since old versions of hive only support percentile on