[Min/Max] Apply filtered row behavior at the row level evaluation
- This moves the filtered-row behavior from the analyzer level to the row-level evaluation, which prevents the use of Double.MinValue/Double.MaxValue as placeholder values for filtered rows.
rdsharma26 committed Mar 5, 2024
1 parent c89aad8 commit 42d2425
Showing 6 changed files with 182 additions and 94 deletions.
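At a high level: before this commit, the row-level column for Minimum/Maximum substituted a sentinel (Double.MaxValue/Double.MinValue) for rows excluded by the where clause; after it, excluded rows carry null and the row-level assertion decides the outcome from the configured FilteredRowOutcome. The sketch below is illustrative only, not deequ code; all names in it are hypothetical:

```scala
// Illustrative contrast of the two semantics; not deequ code.
object FilteredRowSemantics {
  sealed trait FilteredOutcome
  case object TrueOutcome extends FilteredOutcome
  case object NullOutcome extends FilteredOutcome

  // Before: the analyzer replaced a filtered row's value with a sentinel,
  // so the row-level assertion saw a real-looking number.
  def outcomeBefore(value: Double, kept: Boolean, assertion: Double => Boolean): Boolean = {
    val effective = if (kept) value else Double.MinValue // sentinel used for Maximum
    assertion(effective)
  }

  // After: a filtered row carries null (None here); the configured
  // FilteredOutcome decides the result instead of a placeholder value.
  def outcomeAfter(value: Option[Double], assertion: Double => Boolean,
                   filtered: FilteredOutcome): Option[Boolean] =
    value match {
      case Some(v) => Some(assertion(v))
      case None => filtered match {
        case TrueOutcome => Some(true)
        case NullOutcome => None
      }
    }

  def main(args: Array[String]): Unit = {
    // With sentinels, a legitimate Double.MinValue in the data is
    // indistinguishable from a filtered row; null propagation has no collision.
    println(outcomeBefore(Double.MinValue, kept = true, _ > 0)) // false, ambiguous
    println(outcomeAfter(None, _ > 0, NullOutcome))             // None
    println(outcomeAfter(None, _ > 0, TrueOutcome))             // Some(true)
  }
}
```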
17 changes: 1 addition & 16 deletions src/main/scala/com/amazon/deequ/analyzers/Maximum.scala
@@ -23,8 +23,6 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
 import Analyzers._
 import com.amazon.deequ.metrics.FullColumn
 import com.google.common.annotations.VisibleForTesting
-import org.apache.spark.sql.functions.expr
-import org.apache.spark.sql.functions.not

 case class MaxState(maxValue: Double, override val fullColumn: Option[Column] = None)
   extends DoubleValuedState[MaxState] with FullColumn {
@@ -49,7 +47,7 @@ case class Maximum(column: String, where: Option[String] = None, analyzerOptions
   override def fromAggregationResult(result: Row, offset: Int): Option[MaxState] = {

     ifNoNullsIn(result, offset) { _ =>
-      MaxState(result.getDouble(offset), Some(rowLevelResults))
+      MaxState(result.getDouble(offset), Some(criterion))
     }
   }

@@ -61,18 +59,5 @@ case class Maximum(column: String, where: Option[String] = None, analyzerOptions

   @VisibleForTesting
   private def criterion: Column = conditionalSelection(column, where).cast(DoubleType)
-
-  private[deequ] def rowLevelResults: Column = {
-    val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions)
-    val whereNotCondition = where.map { expression => not(expr(expression)) }
-
-    filteredRowOutcome match {
-      case FilteredRowOutcome.TRUE =>
-        conditionSelectionGivenColumn(col(column), whereNotCondition, replaceWith = Double.MinValue).cast(DoubleType)
-      case _ =>
-        criterion
-    }
-  }

 }
20 changes: 2 additions & 18 deletions src/main/scala/com/amazon/deequ/analyzers/Minimum.scala
@@ -23,8 +23,6 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
 import Analyzers._
 import com.amazon.deequ.metrics.FullColumn
 import com.google.common.annotations.VisibleForTesting
-import org.apache.spark.sql.functions.expr
-import org.apache.spark.sql.functions.not

 case class MinState(minValue: Double, override val fullColumn: Option[Column] = None)
   extends DoubleValuedState[MinState] with FullColumn {
@@ -48,7 +46,7 @@ case class Minimum(column: String, where: Option[String] = None, analyzerOptions

   override def fromAggregationResult(result: Row, offset: Int): Option[MinState] = {
     ifNoNullsIn(result, offset) { _ =>
-      MinState(result.getDouble(offset), Some(rowLevelResults))
+      MinState(result.getDouble(offset), Some(criterion))
     }
   }

@@ -59,19 +57,5 @@ case class Minimum(column: String, where: Option[String] = None, analyzerOptions
   override def filterCondition: Option[String] = where

   @VisibleForTesting
-  private def criterion: Column = {
-    conditionalSelection(column, where).cast(DoubleType)
-  }
-
-  private[deequ] def rowLevelResults: Column = {
-    val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions)
-    val whereNotCondition = where.map { expression => not(expr(expression)) }
-
-    filteredRowOutcome match {
-      case FilteredRowOutcome.TRUE =>
-        conditionSelectionGivenColumn(col(column), whereNotCondition, replaceWith = Double.MaxValue).cast(DoubleType)
-      case _ =>
-        criterion
-    }
-  }
+  private def criterion: Column = conditionalSelection(column, where).cast(DoubleType)
 }
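With rowLevelResults removed, both analyzers now always attach criterion as the row-level column. Assuming conditionalSelection(column, where) behaves like Spark's when(expr(where), col(column)) — a when without otherwise, which yields null for non-matching rows — filtered rows now surface as null instead of a sentinel. A minimal Spark sketch of that behavior, under that assumption:

```scala
// Minimal sketch, assuming conditionalSelection(column, where) is equivalent
// to when(expr(where), col(column)); illustrative, not deequ code.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, expr, when}
import org.apache.spark.sql.types.DoubleType

object CriterionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.master("local[1]").appName("criterion").getOrCreate()
    import spark.implicits._

    val df = Seq((1, 1.0), (2, 2.0), (3, 3.0), (4, 4.0)).toDF("item", "att1")

    // `when` without `otherwise` evaluates to null where the predicate fails,
    // so rows excluded by "item < 4" get null rather than Double.MinValue.
    val criterion = when(expr("item < 4"), col("att1")).cast(DoubleType)
    df.withColumn("rowLevel", criterion).show()
    // rowLevel column: 1.0, 2.0, 3.0, null

    spark.stop()
  }
}
```

The null then flows into the constraint's assertion UDF, which is where the FilteredRowOutcome handling now lives (see the Constraint.scala change below).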
22 changes: 20 additions & 2 deletions src/main/scala/com/amazon/deequ/constraints/Constraint.scala
@@ -629,7 +629,9 @@ object Constraint {
     val constraint = AnalysisBasedConstraint[MinState, Double, Double](minimum, assertion,
       hint = hint)

-    val sparkAssertion = org.apache.spark.sql.functions.udf(assertion)
+    val updatedAssertion = getUpdatedRowLevelAssertion(assertion, minimum.analyzerOptions)
+    val sparkAssertion = org.apache.spark.sql.functions.udf(updatedAssertion)

     new RowLevelAssertedConstraint(
       constraint,
       s"MinimumConstraint($minimum)",
@@ -663,7 +665,9 @@ object Constraint {
     val constraint = AnalysisBasedConstraint[MaxState, Double, Double](maximum, assertion,
       hint = hint)

-    val sparkAssertion = org.apache.spark.sql.functions.udf(assertion)
+    val updatedAssertion = getUpdatedRowLevelAssertion(assertion, maximum.analyzerOptions)
+    val sparkAssertion = org.apache.spark.sql.functions.udf(updatedAssertion)

     new RowLevelAssertedConstraint(
       constraint,
       s"MaximumConstraint($maximum)",
@@ -916,6 +920,20 @@ object Constraint {
       .getOrElse(0.0)
   }

+  private[this] def getUpdatedRowLevelAssertion(assertion: Double => Boolean,
+                                                analyzerOpts: Option[AnalyzerOptions])
+    : java.lang.Double => java.lang.Boolean = {
+    (d: java.lang.Double) => {
+      if (Option(d).isDefined) assertion(d)
+      else analyzerOpts match {
+        case Some(analyzerOptions) => analyzerOptions.filteredRow match {
+          case FilteredRowOutcome.TRUE => true
+          case FilteredRowOutcome.NULL => null
+        }
+        case None => null
+      }
+    }
+  }
 }
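The new getUpdatedRowLevelAssertion lifts the user's Double => Boolean assertion into a null-tolerant java.lang.Double => java.lang.Boolean: non-null values go through the original assertion, while null (a filtered row) maps to true or null depending on FilteredRowOutcome. A self-contained sketch of that contract, with stand-in types mirroring the shapes in the diff:

```scala
// Standalone sketch of the wrapped assertion's contract. FilteredRowOutcome
// and AnalyzerOptions here are stand-ins mirroring the diff, not deequ's types.
object UpdatedAssertionSketch {
  sealed trait FilteredRowOutcome
  object FilteredRowOutcome {
    case object TRUE extends FilteredRowOutcome
    case object NULL extends FilteredRowOutcome
  }
  final case class AnalyzerOptions(filteredRow: FilteredRowOutcome)

  def updatedAssertion(assertion: Double => Boolean,
                       analyzerOpts: Option[AnalyzerOptions]): java.lang.Double => java.lang.Boolean =
    (d: java.lang.Double) =>
      if (d != null) assertion(d) // unfiltered rows: original assertion applies
      else analyzerOpts.map(_.filteredRow) match {
        case Some(FilteredRowOutcome.TRUE) => true
        case _ => null // FilteredRowOutcome.NULL, and unset options, both yield null
      }

  def main(args: Array[String]): Unit = {
    val isFour = updatedAssertion(_ == 4.0, Some(AnalyzerOptions(FilteredRowOutcome.TRUE)))
    println(isFour(4.0))  // true
    println(isFour(5.0))  // false
    println(isFour(null)) // true: filtered rows take the configured outcome
  }
}
```

Because the UDF is registered over this wrapped function, the row-level results column ends up with true/false for evaluated rows and the configured outcome for filtered ones, which is what the updated tests below assert.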
122 changes: 117 additions & 5 deletions src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala
@@ -41,13 +41,10 @@ import org.scalamock.scalatest.MockFactory
 import org.scalatest.Matchers
 import org.scalatest.WordSpec

-
-
 class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
   with FixtureSupport with MockFactory {

   "Verification Suite" should {
-
     "return the correct verification status regardless of the order of checks" in
       withSparkSession { sparkSession =>

@@ -374,11 +371,11 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec

       // filtered rows 1, 2, 3 (where item > 3)
       val minRowLevel = resultData.select(expectedColumn4).collect().map(r => r.getAs[Any](0))
-      assert(Seq(true, true, true, true, true, true).sameElements(minRowLevel))
+      assert(Seq(null, null, null, true, true, true).sameElements(minRowLevel))

       // filtered rows 4, 5, 6 (where item < 4)
       val maxRowLevel = resultData.select(expectedColumn5).collect().map(r => r.getAs[Any](0))
-      assert(Seq(true, true, true, true, true, true).sameElements(maxRowLevel))
+      assert(Seq(true, true, true, null, null, null).sameElements(maxRowLevel))

       // filtered rows 4, 5, 6 (where item < 4)
       val rowLevel6 = resultData.select(expectedColumn6).collect().map(r => r.getAs[Any](0))
@@ -1609,6 +1606,121 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
     }
   }

"Verification Suite with == based Min/Max checks and filtered row behavior" should {
val col1 = "att1"
val col2 = "att2"
val col3 = "att3"

val check1Description = "equality-check-1"
val check2Description = "equality-check-2"
val check3Description = "equality-check-3"

val check1WhereClause = "att1 > 3"
val check2WhereClause = "att2 > 4"
val check3WhereClause = "att3 = 0"

def mkEqualityCheck1(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, check1Description)
.hasMin(col1, _ == 4, analyzerOptions = Some(analyzerOptions)).where(check1WhereClause)
.hasMax(col1, _ == 4, analyzerOptions = Some(analyzerOptions)).where(check1WhereClause)

def mkEqualityCheck2(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, check2Description)
.hasMin(col2, _ == 7, analyzerOptions = Some(analyzerOptions)).where(check2WhereClause)
.hasMax(col2, _ == 7, analyzerOptions = Some(analyzerOptions)).where(check2WhereClause)

def mkEqualityCheck3(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, check3Description)
.hasMin(col3, _ == 0, analyzerOptions = Some(analyzerOptions)).where(check3WhereClause)
.hasMax(col3, _ == 0, analyzerOptions = Some(analyzerOptions)).where(check3WhereClause)

def getRowLevelResults(df: DataFrame): Seq[java.lang.Boolean] =
df.collect().map { r => r.getAs[java.lang.Boolean](0) }.toSeq

def assertCheckResults(verificationResult: VerificationResult): Unit = {
val passResult = verificationResult.checkResults
val equalityCheck1Result = passResult.values.find(_.check.description == check1Description)
val equalityCheck2Result = passResult.values.find(_.check.description == check2Description)
val equalityCheck3Result = passResult.values.find(_.check.description == check3Description)

assert(equalityCheck1Result.isDefined && equalityCheck1Result.get.status == CheckStatus.Error)
assert(equalityCheck2Result.isDefined && equalityCheck2Result.get.status == CheckStatus.Error)
assert(equalityCheck3Result.isDefined && equalityCheck3Result.get.status == CheckStatus.Success)
}

def assertRowLevelResults(rowLevelResults: DataFrame,
analyzerOptions: AnalyzerOptions): Unit = {
val equalityCheck1Results = getRowLevelResults(rowLevelResults.select(check1Description))
val equalityCheck2Results = getRowLevelResults(rowLevelResults.select(check2Description))
val equalityCheck3Results = getRowLevelResults(rowLevelResults.select(check3Description))

val filteredOutcome: java.lang.Boolean = analyzerOptions.filteredRow match {
case FilteredRowOutcome.TRUE => true
case FilteredRowOutcome.NULL => null
}
assert(equalityCheck1Results == Seq(filteredOutcome, filteredOutcome, filteredOutcome, true, false, false))
assert(equalityCheck2Results == Seq(filteredOutcome, filteredOutcome, filteredOutcome, false, false, true))
assert(equalityCheck3Results == Seq(true, true, true, filteredOutcome, filteredOutcome, filteredOutcome))
}

def assertMetrics(metricsDF: DataFrame): Unit = {
val metricsMap: Map[String, Double] = metricsDF.collect().map { r =>
val colName = r.getAs[String]("instance")
val metricName = r.getAs[String]("name")
val metricValue = r.getAs[Double]("value")
s"$colName|$metricName" -> metricValue
}.toMap

assert(metricsMap(s"$col1|Minimum (where: $check1WhereClause)") == 4.0)
assert(metricsMap(s"$col1|Maximum (where: $check1WhereClause)") == 6.0)
assert(metricsMap(s"$col2|Minimum (where: $check2WhereClause)") == 5.0)
assert(metricsMap(s"$col2|Maximum (where: $check2WhereClause)") == 7.0)
assert(metricsMap(s"$col3|Minimum (where: $check3WhereClause)") == 0.0)
assert(metricsMap(s"$col3|Maximum (where: $check3WhereClause)") == 0.0)
}

"mark filtered rows to null" in withSparkSession {
sparkSession =>
val df = getDfWithNumericValues(sparkSession)
val analyzerOptions = AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)

val equalityCheck1 = mkEqualityCheck1(analyzerOptions)
val equalityCheck2 = mkEqualityCheck2(analyzerOptions)
val equalityCheck3 = mkEqualityCheck3(analyzerOptions)

val verificationResult = VerificationSuite()
.onData(df)
.addChecks(Seq(equalityCheck1, equalityCheck2, equalityCheck3))
.run()

val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df)
val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult)

assertCheckResults(verificationResult)
assertRowLevelResults(rowLevelResultsDF, analyzerOptions)
assertMetrics(metricsDF)
}

"mark filtered rows to true" in withSparkSession {
sparkSession =>
val df = getDfWithNumericValues(sparkSession)
val analyzerOptions = AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE)

val equalityCheck1 = mkEqualityCheck1(analyzerOptions)
val equalityCheck2 = mkEqualityCheck2(analyzerOptions)
val equalityCheck3 = mkEqualityCheck3(analyzerOptions)

val verificationResult = VerificationSuite()
.onData(df)
.addChecks(Seq(equalityCheck1, equalityCheck2, equalityCheck3))
.run()

val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df)
val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult)

assertCheckResults(verificationResult)
assertRowLevelResults(rowLevelResultsDF, analyzerOptions)
assertMetrics(metricsDF)
}
}

/** Run anomaly detection using a repository with some previous analysis results for testing */
private[this] def evaluateWithRepositoryWithHistory(test: MetricsRepository => Unit): Unit = {

44 changes: 21 additions & 23 deletions src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala
@@ -25,7 +25,6 @@ import org.scalatest.matchers.should.Matchers
 import org.scalatest.wordspec.AnyWordSpec

 class MaximumTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport {
-
   "Max" should {
     "return row-level results for columns" in withSparkSession { session =>

@@ -52,35 +51,34 @@ class MaximumTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport {
         Seq(null, null, null, 5.0, 6.0, 7.0)
     }

-    "return row-level results for columns with where clause filtered as true" in withSparkSession { session =>
-
+    "return row-level results for columns with filtered rows" in withSparkSession { session =>
       val data = getDfWithNumericValues(session)
+      val col = "att1"
+      val whereClause = "item < 4"
+      val tempColName = "new"

-      val att1Maximum = Maximum("att1", Option("item < 4"))
-      val state: Option[MaxState] = att1Maximum.computeStateFrom(data, Option("item < 4"))
-      val metric: DoubleMetric with FullColumn = att1Maximum.computeMetricFrom(state)
+      val analyzerOptionsFilteredRowsNull = AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)
+      val analyzerOptionsFilteredRowsTrue = AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE)

-      val result = data.withColumn("new", metric.fullColumn.get)
-      result.show(false)
-      result.collect().map(r =>
-        if (r == null) null else r.getAs[Double]("new")) shouldBe
-        Seq(1.0, 2.0, 3.0, Double.MinValue, Double.MinValue, Double.MinValue)
-    }
+      val att1MaximumFilteredRowsNull = Maximum(col, Option(whereClause), Some(analyzerOptionsFilteredRowsNull))
+      val att1MaximumFilteredRowsTrue = Maximum(col, Option(whereClause), Some(analyzerOptionsFilteredRowsTrue))

-    "return row-level results for columns with where clause filtered as null" in withSparkSession { session =>
+      val filteredRowNullState = att1MaximumFilteredRowsNull.computeStateFrom(data, Option(whereClause))
+      val filteredRowTrueState = att1MaximumFilteredRowsTrue.computeStateFrom(data, Option(whereClause))

-      val data = getDfWithNumericValues(session)
+      val filteredRowNullMetric: DoubleMetric with FullColumn =
+        att1MaximumFilteredRowsNull.computeMetricFrom(filteredRowNullState)
+      val filteredRowTrueMetric: DoubleMetric with FullColumn =
+        att1MaximumFilteredRowsTrue.computeMetricFrom(filteredRowTrueState)

-      val att1Maximum = Maximum("att1", Option("item < 4"),
-        Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)))
-      val state: Option[MaxState] = att1Maximum.computeStateFrom(data, Option("item < 4"))
-      val metric: DoubleMetric with FullColumn = att1Maximum.computeMetricFrom(state)
+      val filteredRowNullResult = data.withColumn(tempColName, filteredRowNullMetric.fullColumn.get)
+      val filteredRowTrueResult = data.withColumn(tempColName, filteredRowTrueMetric.fullColumn.get)

-      val result = data.withColumn("new", metric.fullColumn.get)
-      result.show(false)
-      result.collect().map(r =>
-        if (r == null) null else r.getAs[Double]("new")) shouldBe
-        Seq(1.0, 2.0, 3.0, null, null, null)
+      Seq(filteredRowNullResult, filteredRowTrueResult).foreach { result =>
+        result.collect().map(r =>
+          if (r == null) null else r.getAs[Double](tempColName)) shouldBe
+          Seq(1.0, 2.0, 3.0, null, null, null)
+      }
     }
   }
 }
