Relative tolerance #215

Open
wants to merge 9 commits into base: main
17 changes: 12 additions & 5 deletions build.sbt
@@ -1,5 +1,5 @@
lazy val root = (project in file("."))
.aggregate(core, kafka_0_8)
.aggregate(core)
.settings(noPublishSettings, commonSettings)

val sparkVersion = settingKey[String]("Spark version")
@@ -14,7 +14,7 @@ lazy val core = (project in file("core"))
coreTestSources,
crossScalaVersions := {
if (sparkVersion.value >= "2.4.0") {
Seq("2.12.8")
Seq("2.12.10", "2.11.12")
} else if (sparkVersion.value >= "2.3.0") {
Seq("2.11.11")
} else {
@@ -32,6 +32,7 @@ lazy val core = (project in file("core"))
) ++ commonDependencies ++ miniClusterDependencies
)

/*
lazy val kafka_0_8 = {
Project("kafka_0_8", file("kafka-0.8"))
.dependsOn(core)
@@ -64,6 +65,7 @@ lazy val kafka_0_8 = {
)
)
}
*/

val commonSettings = Seq(
organization := "com.holdenkarau",
@@ -211,9 +213,10 @@ val coreTestSources = unmanagedSourceDirectories in Test := {

// additional libraries
lazy val commonDependencies = Seq(
"org.scalatest" %% "scalatest" % "3.0.5",
"org.scalatest" %% "scalatest" % "3.1.0",
"org.scalatestplus" %% "scalatestplus-scalacheck" % "3.1.0.0-RC2",
"io.github.nicolasstucki" %% "multisets" % "0.4",
"org.scalacheck" %% "scalacheck" % "1.14.0",
"org.scalacheck" %% "scalacheck" % "1.14.3",
"junit" % "junit" % "4.12",
"org.eclipse.jetty" % "jetty-util" % "9.3.11.v20160721",
"com.novocode" % "junit-interface" % "0.11" % "test->default")
@@ -231,7 +234,7 @@ def excludeJpountz(items: Seq[ModuleID]) =

libraryDependencies ++= excludeJpountz(
// For Spark 2.4 w/ Scala 2.12 we're going to need some special logic
if (sparkVersion.value >= "2.3.0") {
if (sparkVersion.value >= "2.3.0" && scalaVersion.value < "2.12.0") {
Seq(
"org.apache.spark" %% "spark-streaming-kafka-0-8" % sparkVersion.value
)
@@ -290,3 +293,7 @@ lazy val publishSettings = Seq(

lazy val noPublishSettings =
skip in publish := true

scalafixDependencies in ThisBuild += "org.scalatest" %% "autofix" % "3.1.0.0"

addCompilerPlugin(scalafixSemanticdb)
@@ -128,7 +128,11 @@ trait DataFrameSuiteBaseLike extends SparkContextProvider
* @param tol max acceptable absolute tolerance; should be less than 1.
* @param relTol max acceptable relative tolerance; only used when tol is 0.0.
*/
def assertDataFrameApproximateEquals(
expected: DataFrame, result: DataFrame, tol: Double) {
expected: DataFrame,
result: DataFrame,
tol: Double = 0.0,
relTol: Double = 0.0
) {

assert(expected.schema, result.schema)

@@ -142,7 +146,7 @@ trait DataFrameSuiteBaseLike extends SparkContextProvider

val unequalRDD = expectedIndexValue.join(resultIndexValue).
filter{case (idx, (r1, r2)) =>
!(r1.equals(r2) || DataFrameSuiteBase.approxEquals(r1, r2, tol))}
!(r1.equals(r2) || DataFrameSuiteBase.approxEquals(r1, r2, tol, relTol))}

assertEmpty(unequalRDD.take(maxUnequalRowsToShow))
} finally {
@@ -160,15 +164,67 @@ trait DataFrameSuiteBaseLike extends SparkContextProvider
rdd.zipWithIndex().map{ case (row, idx) => (idx, row) }
}

def approxEquals(r1: Row, r2: Row, tol: Double): Boolean = {
DataFrameSuiteBase.approxEquals(r1, r2, tol)
def approxEquals(
r1: Row,
r2: Row,
tol: Double = 0.0,
relTol: Double = 0.0
): Boolean = {
DataFrameSuiteBase.approxEquals(r1, r2, tol, relTol)
}
}

object DataFrameSuiteBase {
trait WithinToleranceChecker {
def apply(a: Double, b: Double): Boolean
def apply(a: BigDecimal, b: BigDecimal): Boolean
}
object WithinToleranceChecker {
def apply(tol: Double = 0.0, relTol: Double = 0.0) =
if(tol != 0.0 || relTol == 0.0) {
new WithinAbsoluteToleranceChecker(tol)
} else {
new WithinRelativeToleranceChecker(relTol)
}
}

class WithinAbsoluteToleranceChecker(tolerance: Double)
extends WithinToleranceChecker {
def apply(a: Double, b: Double): Boolean =
(a - b).abs <= tolerance
def apply(a: BigDecimal, b: BigDecimal): Boolean =
(a - b).abs <= tolerance
}

class WithinRelativeToleranceChecker(relativeTolerance: Double)
extends WithinToleranceChecker {
def apply(a: Double, b: Double): Boolean = {
val max = (a.abs max b.abs)
if (max == 0.0) {
true
} else {
(a - b).abs / max <= relativeTolerance
}
}
def apply(a: BigDecimal, b: BigDecimal): Boolean = {
val max = (a.abs max b.abs)
if (max == 0.0) {
true
} else {
(a - b).abs / max <= relativeTolerance
}
}
}

/** Approximate equality, based on equals from [[Row]] */
def approxEquals(r1: Row, r2: Row, tol: Double): Boolean = {
def approxEquals(
r1: Row,
r2: Row,
tol: Double = 0.0,
relTol: Double = 0.0
): Boolean = {
val withinTolerance = WithinToleranceChecker(tol, relTol)

if (r1.length != r2.length) {
return false
} else {
@@ -192,7 +248,7 @@
{
return false
}
if (abs(f1 - o2.asInstanceOf[Float]) > tol) {
if (!withinTolerance(f1, o2.asInstanceOf[Float])) {
return false
}

@@ -202,12 +258,20 @@
{
return false
}
if (abs(d1 - o2.asInstanceOf[Double]) > tol) {
if (!withinTolerance(d1, o2.asInstanceOf[Double])) {
return false
}

case d1: java.math.BigDecimal =>
if (d1.compareTo(o2.asInstanceOf[java.math.BigDecimal]) != 0) {
if (!withinTolerance(
BigDecimal(d1),
BigDecimal(o2.asInstanceOf[java.math.BigDecimal])
)) {
return false
}

case d1: scala.math.BigDecimal =>
if (!withinTolerance(d1, o2.asInstanceOf[scala.math.BigDecimal])) {
return false
}

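Taken together, the changes above let a test pick either an absolute or a relative tolerance. A minimal sketch of how the new parameters might be used; the class name, DataFrames, and numbers here are illustrative, not part of the PR:

```scala
import org.apache.spark.sql.Row
import org.scalatest.funsuite.AnyFunSuite
import com.holdenkarau.spark.testing.DataFrameSuiteBase

class RelativeToleranceExampleTest extends AnyFunSuite with DataFrameSuiteBase {
  test("absolute vs relative tolerance") {
    import sqlContext.implicits._
    val expected = Seq(100.0, 2000.0).toDF("value")
    val result   = Seq(100.001, 2000.02).toDF("value")

    // Absolute tolerance: the second row differs by 0.02, so tol must cover that.
    assertDataFrameApproximateEquals(expected, result, tol = 0.05)

    // Relative tolerance: both rows differ by roughly 1e-5 of their magnitude.
    assertDataFrameApproximateEquals(expected, result, relTol = 1e-4)

    // Row-level variant exposed on the companion object.
    assert(DataFrameSuiteBase.approxEquals(Row(100.0), Row(100.001), relTol = 1e-4))
    assert(!DataFrameSuiteBase.approxEquals(Row(100.0), Row(100.001), relTol = 1e-6))
  }
}
```

Because `WithinToleranceChecker.apply` falls back to the absolute checker whenever `tol` is non-zero, passing both parameters means `relTol` is ignored.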
@@ -171,7 +171,11 @@ object DataFrameSuiteBase {
if (abs(d1 - o2.asInstanceOf[Double]) > tol) return false

case d1: java.math.BigDecimal =>
if (d1.compareTo(o2.asInstanceOf[java.math.BigDecimal]) != 0) return false
if (d1.subtract(o2.asInstanceOf[java.math.BigDecimal]).abs
.compareTo(new java.math.BigDecimal(tol)) > 0) return false

case d1: scala.math.BigDecimal =>
if ((d1 - o2.asInstanceOf[scala.math.BigDecimal]).abs > tol) return false

case t1: Timestamp =>
if (abs(t1.getTime - o2.asInstanceOf[Timestamp].getTime) > tol) {
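Note the behavioral change in this second copy of `approxEquals`: `java.math.BigDecimal` values were previously compared exactly with `compareTo`, but now honor the absolute tolerance. A quick standalone check of the arithmetic, with made-up values for illustration:

```scala
import java.math.{BigDecimal => JBigDecimal}

val d1  = new JBigDecimal("1.000000")
val d2  = new JBigDecimal("1.000001")
val tol = 1.0e-5

// Old behaviour: exact comparison, so these two values are unequal.
val exactlyEqual = d1.compareTo(d2) == 0                                  // false

// New behaviour: |d1 - d2| <= tol, mirroring the subtract/abs/compareTo in the diff.
val withinTol = d1.subtract(d2).abs.compareTo(new JBigDecimal(tol)) <= 0  // true
```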
@@ -25,8 +25,8 @@ import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._

import org.scalatest.FunSuite
import org.scalatest.exceptions.TestFailedException
import org.scalatest.funsuite.AnyFunSuite

/**
* ArtisinalStreamingTest illustrates how to write a streaming test
@@ -36,7 +36,7 @@ import org.scalatest.exceptions.TestFailedException
* This does not use a manual clock and instead uses the kind of sketchy
* sleep approach. Instead please look at [[SampleStreamingTest]].
*/
class ArtisinalStreamingTest extends FunSuite with SharedSparkContext {
class ArtisinalStreamingTest extends AnyFunSuite with SharedSparkContext {
// tag::createQueueStream[]
def makeSimpleQueueStream(ssc: StreamingContext) = {
val input = List(List("hi"), List("happy pandas", "sad pandas"))
@@ -7,9 +7,9 @@ import com.holdenkarau.spark.testing.{RDDComparisons, SharedSparkContext}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.rdd.RDD
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite

class HDFSClusterTest extends FunSuite with SharedSparkContext with RDDComparisons {
class HDFSClusterTest extends AnyFunSuite with SharedSparkContext with RDDComparisons {

var hdfsCluster: HDFSCluster = null

@@ -1,8 +1,8 @@
package com.holdenkarau.spark.testing

import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite

class MultipleDataFrameSuites extends FunSuite with DataFrameSuiteBase {
class MultipleDataFrameSuites extends AnyFunSuite with DataFrameSuiteBase {
test("test nothing") {
assert(1 === 1)
}
@@ -16,14 +16,14 @@
*/
package com.holdenkarau.spark.testing

import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite

/**
* Illustrate using per-test sample test. This is the one to use
* when your tests may be destructive to the Spark context
* (e.g. stopping it)
*/
class PerTestSampleTest extends FunSuite with PerTestSparkContext {
class PerTestSampleTest extends AnyFunSuite with PerTestSparkContext {

test("sample test stops a context") {
sc.stop()
@@ -19,14 +19,14 @@ package com.holdenkarau.spark.testing
import java.nio.file.Files

import org.apache.spark._
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite

/**
* Illustrate using per-test sample test. This is the one to use
* when your tests may be destructive to the Spark context
* (e.g. stopping it)
*/
class PerfSampleTest extends FunSuite with PerTestSparkContext {
class PerfSampleTest extends AnyFunSuite with PerTestSparkContext {
val tempPath = Files.createTempDirectory(null).toString()

//tag::samplePerfTest[]
@@ -20,9 +20,10 @@ import java.sql.Timestamp

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.scalatest.FunSuite
import java.math.{ BigDecimal => JBigDecimal }
import org.scalatest.funsuite.AnyFunSuite

class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase {
class SampleDataFrameTest extends AnyFunSuite with DataFrameSuiteBase {
val byteArray = new Array[Byte](1)
val diffByteArray = Array[Byte](192.toByte)
val inputList = List(
@@ -70,6 +71,10 @@ class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase {
val row8 = Row(Timestamp.valueOf("2018-01-12 20:22:13"))
val row9 = Row(Timestamp.valueOf("2018-01-12 20:22:18"))
val row10 = Row(Timestamp.valueOf("2018-01-12 20:23:13"))
val row11 = Row(new JBigDecimal(1.0))
val row11a = Row(new JBigDecimal(1.0 + 1.0E-6))
val row12 = Row(BigDecimal(1.0))
val row12a = Row(BigDecimal(1.0 + 1.0E-6))
assert(false === approxEquals(row, row2, 1E-7))
assert(true === approxEquals(row, row2, 1E-5))
assert(true === approxEquals(row3, row3, 1E-5))
@@ -84,6 +89,44 @@ class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase {
assert(false === approxEquals(row9, row8, 3000))
assert(true === approxEquals(row9, row10, 60000))
assert(false === approxEquals(row9, row10, 53000))
assert(true === approxEquals(row11, row11a, 1.0E-6))
assert(true === approxEquals(row12, row12a, 1.0E-6))
}

test("dataframe approxEquals on rows with relative tolerance") {
import sqlContext.implicits._
// Use 1 / 2^n as example numbers to avoid numeric errors
val relTol = scala.math.pow(2, -6)
val orig = 0.25
val within = orig - relTol * orig
val outside = within - 1.0E-4
def assertRelativeApproxEqualsWorksFor[T](constructor: Double => T) = {
assertResult(true) {
approxEquals(
Row(constructor(orig)),
Row(constructor(within)),
relTol = relTol
)
}
assertResult(false) {
approxEquals(
Row(constructor(orig)),
Row(constructor(outside)),
relTol = relTol
)
}
assertResult(true) {
approxEquals(
Row(constructor(0.0)),
Row(constructor(0.0)),
relTol = relTol
)
}
}
assertRelativeApproxEqualsWorksFor[Double](identity)
assertRelativeApproxEqualsWorksFor[Float](_.toFloat)
assertRelativeApproxEqualsWorksFor[BigDecimal](BigDecimal.apply)
assertRelativeApproxEqualsWorksFor[JBigDecimal](new JBigDecimal(_))
}

test("verify hive function support") {
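The boundary arithmetic in the relative-tolerance test above is worth spelling out: `within` sits exactly on the `<=` boundary and `outside` is just past it. A quick check of those numbers; `relErr` is a hypothetical helper that mirrors `WithinRelativeToleranceChecker`:

```scala
val relTol  = math.pow(2, -6)       // 0.015625, exactly representable in binary
val orig    = 0.25
val within  = orig - relTol * orig  // 0.24609375, also exact
val outside = within - 1.0e-4

// Relative error against the larger magnitude, as the checker computes it.
def relErr(a: Double, b: Double): Double = (a - b).abs / (a.abs max b.abs)

relErr(orig, within)  <= relTol     // true: equals relTol exactly, so <= passes
relErr(orig, outside) <= relTol     // false: just outside the tolerance
```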
@@ -19,9 +19,9 @@ package com.holdenkarau.spark.testing
import scala.util.Random

import org.apache.spark.rdd.RDD
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite

class SampleRDDTest extends FunSuite with SharedSparkContext with RDDComparisons {
class SampleRDDTest extends AnyFunSuite with SharedSparkContext with RDDComparisons {
test("really simple transformation") {
val input = List("hi", "hi holden", "bye")
val expected = List(List("hi"), List("hi", "holden"), List("bye"))
@@ -21,10 +21,10 @@ import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.types._
import org.scalacheck.{Arbitrary, Gen}
import org.scalacheck.Prop.forAll
import org.scalatest.FunSuite
import org.scalatest.prop.Checkers
import org.scalatestplus.scalacheck.Checkers
import org.scalatest.funsuite.AnyFunSuite

class SampleScalaCheckTest extends FunSuite
class SampleScalaCheckTest extends AnyFunSuite
with SharedSparkContext with RDDComparisons with Checkers {
// tag::propertySample[]
// A trivial property that the map doesn't change the number of elements
@@ -257,7 +257,7 @@

test("generate rdd of specific size") {
implicit val generatorDrivenConfig =
PropertyCheckConfig(minSize = 10, maxSize = 20)
PropertyCheckConfiguration(minSize = 10, sizeRange = 20)
val prop = forAll(RDDGenerator.genRDD[String](sc)(Arbitrary.arbitrary[String])){
rdd => rdd.count() <= 20
}
@@ -333,7 +333,7 @@
StructType(StructField("timestampType", TimestampType) :: Nil)) :: Nil
test("second dataframe's evaluation has the same values as first") {
implicit val generatorDrivenConfig =
PropertyCheckConfig(minSize = 1, maxSize = 1)
PropertyCheckConfiguration(minSize = 1, sizeRange = 1)

val sqlContext = new SQLContext(sc)
val dataframeGen =
@@ -354,7 +354,7 @@
}
test("nullable fields contain null values as well") {
implicit val generatorDrivenConfig =
PropertyCheckConfig(minSize = 1, maxSize = 1)
PropertyCheckConfiguration(minSize = 1, sizeRange = 1)
val nullableFields = fields.map(f => f.copy(nullable = true, name = s"${f.name}Nullable"))
val sqlContext = new SQLContext(sc)
val allFields = fields ::: nullableFields
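The remaining test-file changes are the mechanical ScalaTest 3.0 → 3.1 migration: `FunSuite` becomes `AnyFunSuite`, `Checkers` moves to `org.scalatestplus.scalacheck`, and `PropertyCheckConfig(minSize, maxSize)` becomes `PropertyCheckConfiguration(minSize, sizeRange)`. A condensed sketch of the pattern; the class name and test body are illustrative:

```scala
import org.scalatest.funsuite.AnyFunSuite    // was: org.scalatest.FunSuite
import org.scalatestplus.scalacheck.Checkers // was: org.scalatest.prop.Checkers
import org.scalacheck.Prop.forAll

class MigrationSketchTest extends AnyFunSuite with Checkers {
  // sizeRange is relative to minSize, unlike the old absolute maxSize.
  implicit val generatorDrivenConfig =
    PropertyCheckConfiguration(minSize = 10, sizeRange = 20)

  test("map preserves length") {
    check(forAll { (xs: List[Int]) => xs.map(_ + 1).length == xs.length })
  }
}
```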