[Min/Max] Apply filtered row behavior at the row level evaluation #543

Merged
merged 4 commits, Mar 8, 2024
Changes from 1 commit
17 changes: 1 addition & 16 deletions src/main/scala/com/amazon/deequ/analyzers/Maximum.scala
@@ -23,8 +23,6 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
import Analyzers._
import com.amazon.deequ.metrics.FullColumn
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.not

case class MaxState(maxValue: Double, override val fullColumn: Option[Column] = None)
extends DoubleValuedState[MaxState] with FullColumn {
@@ -49,7 +47,7 @@ case class Maximum(column: String, where: Option[String] = None, analyzerOptions
override def fromAggregationResult(result: Row, offset: Int): Option[MaxState] = {

ifNoNullsIn(result, offset) { _ =>
MaxState(result.getDouble(offset), Some(rowLevelResults))
MaxState(result.getDouble(offset), Some(criterion))
}
}

@@ -61,18 +59,5 @@ case class Maximum(column: String, where: Option[String] = None, analyzerOptions

@VisibleForTesting
private def criterion: Column = conditionalSelection(column, where).cast(DoubleType)

private[deequ] def rowLevelResults: Column = {
val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions)
val whereNotCondition = where.map { expression => not(expr(expression)) }

filteredRowOutcome match {
case FilteredRowOutcome.TRUE =>
conditionSelectionGivenColumn(col(column), whereNotCondition, replaceWith = Double.MinValue).cast(DoubleType)
case _ =>
criterion
}
}

}

20 changes: 2 additions & 18 deletions src/main/scala/com/amazon/deequ/analyzers/Minimum.scala
@@ -23,8 +23,6 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
import Analyzers._
import com.amazon.deequ.metrics.FullColumn
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.not

case class MinState(minValue: Double, override val fullColumn: Option[Column] = None)
extends DoubleValuedState[MinState] with FullColumn {
@@ -48,7 +46,7 @@ case class Minimum(column: String, where: Option[String] = None, analyzerOptions

override def fromAggregationResult(result: Row, offset: Int): Option[MinState] = {
ifNoNullsIn(result, offset) { _ =>
MinState(result.getDouble(offset), Some(rowLevelResults))
MinState(result.getDouble(offset), Some(criterion))
}
}

@@ -59,19 +57,5 @@ case class Minimum(column: String, where: Option[String] = None, analyzerOptions
override def filterCondition: Option[String] = where

@VisibleForTesting
private def criterion: Column = {
conditionalSelection(column, where).cast(DoubleType)
}

private[deequ] def rowLevelResults: Column = {
val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions)
val whereNotCondition = where.map { expression => not(expr(expression)) }

filteredRowOutcome match {
case FilteredRowOutcome.TRUE =>
conditionSelectionGivenColumn(col(column), whereNotCondition, replaceWith = Double.MaxValue).cast(DoubleType)
case _ =>
criterion
}
}
private def criterion: Column = conditionalSelection(column, where).cast(DoubleType)
}
22 changes: 20 additions & 2 deletions src/main/scala/com/amazon/deequ/constraints/Constraint.scala
@@ -629,7 +629,9 @@ object Constraint {
val constraint = AnalysisBasedConstraint[MinState, Double, Double](minimum, assertion,
hint = hint)

val sparkAssertion = org.apache.spark.sql.functions.udf(assertion)
val updatedAssertion = getUpdatedRowLevelAssertion(assertion, minimum.analyzerOptions)
val sparkAssertion = org.apache.spark.sql.functions.udf(updatedAssertion)

new RowLevelAssertedConstraint(
constraint,
s"MinimumConstraint($minimum)",
@@ -663,7 +665,9 @@
val constraint = AnalysisBasedConstraint[MaxState, Double, Double](maximum, assertion,
hint = hint)

val sparkAssertion = org.apache.spark.sql.functions.udf(assertion)
val updatedAssertion = getUpdatedRowLevelAssertion(assertion, maximum.analyzerOptions)
val sparkAssertion = org.apache.spark.sql.functions.udf(updatedAssertion)

new RowLevelAssertedConstraint(
constraint,
s"MaximumConstraint($maximum)",
@@ -916,6 +920,20 @@ object Constraint {
.getOrElse(0.0)
}

private[this] def getUpdatedRowLevelAssertion(assertion: Double => Boolean,
analyzerOpts: Option[AnalyzerOptions])
: java.lang.Double => java.lang.Boolean = {
(d: java.lang.Double) => {
if (Option(d).isDefined) assertion(d)
else analyzerOpts match {
case Some(analyzerOptions) => analyzerOptions.filteredRow match {
case FilteredRowOutcome.TRUE => true
case FilteredRowOutcome.NULL => null
}
case None => null
Contributor:
We had discussed that the default analyzerOptions behavior should be FilteredRowOutcome.TRUE. Should we modify 933 to return true? (By default, filtered rows are true instead of null.)
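
For reference, a minimal sketch (not part of the PR; the method name here is hypothetical) of what that default could look like, mirroring the wrapper added in the diff above but treating missing AnalyzerOptions the same as FilteredRowOutcome.TRUE:

def rowLevelAssertionDefaultingToTrue(assertion: Double => Boolean,
                                      analyzerOpts: Option[AnalyzerOptions])
  : java.lang.Double => java.lang.Boolean = {
  (d: java.lang.Double) => {
    if (Option(d).isDefined) assertion(d)
    else analyzerOpts.map(_.filteredRow) match {
      case Some(FilteredRowOutcome.NULL) => null
      // Both an explicit FilteredRowOutcome.TRUE and the absence of options yield true
      case _ => true
    }
  }
}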

Contributor:

I was testing out another scenario that may be problematic.
Given the following dataframe:

+------+----+----+----+----+-----+-----+
|item  |att1|att2|val1|val2|rule4|rule5|
+------+----+----+----+----+-----+-----+
|1     |a   |f   |1   |1   |true |true |
|22    |b   |d   |2   |NULL|true |true |
|333   |a   |NULL|3   |3   |true |true |
|4444  |a   |f   |4   |4   |true |true |
|55555 |b   |NULL|5   |NULL|true |true |
|666666|a   |f   |6   |6   |true |true |
+------+----+----+----+----+-----+-----+

where

val analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE))

val min = new Check(CheckLevel.Error, "rule4")
  .hasMin("val2", _ > 0, None, analyzerOptions)
  .where("val1 < 4")
val max = new Check(CheckLevel.Error, "rule5")
  .hasMax("val2", _ < 4, None, analyzerOptions)
  .where("val1 < 4")

You'll see that rows 1, 2, 3 should be skipped -> true.
Row 5 should be null, as val2 is a null value there.
However, with the above method we convert all nulls to true/null - this doesn't distinguish between null values caused by filtering and null values caused by null column values.

Contributor Author:

Thanks @eycho-am for the valuable feedback. The latest PR revision contains a new structure for the column that helps maintain the "source" of a row, i.e. whether it is in scope or filtered out. That will help in evaluating the correct outcome for each row.
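
As a rough illustration only (assumed shape, not the PR's actual implementation; column and tag names are hypothetical), such a column could carry both the value and the reason a row might be excluded, using the dataframe from the earlier comment:

import org.apache.spark.sql.functions.{col, expr, lit, struct, when}

val inScope = expr("val1 < 4")                          // the check's where clause
val rowLevelColumn = struct(
  col("val2").cast("double").as("value"),
  when(!inScope, lit("FilteredOut"))                    // excluded by the where clause
    .when(col("val2").isNull, lit("NullValue"))         // in scope, but the column value is null
    .otherwise(lit("InScope"))                          // in scope with a usable value
    .as("source")
)

With a source tag like this available, the constraint could map FilteredOut rows to true or null per FilteredRowOutcome while NullValue rows keep their own outcome.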

}
}
}
}

/**
122 changes: 117 additions & 5 deletions src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala
@@ -41,13 +41,10 @@ import org.scalamock.scalatest.MockFactory
import org.scalatest.Matchers
import org.scalatest.WordSpec



class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
with FixtureSupport with MockFactory {

"Verification Suite" should {

"return the correct verification status regardless of the order of checks" in
withSparkSession { sparkSession =>

@@ -374,11 +371,11 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec

// filtered rows 1, 2, 3 (where item > 3)
val minRowLevel = resultData.select(expectedColumn4).collect().map(r => r.getAs[Any](0))
assert(Seq(true, true, true, true, true, true).sameElements(minRowLevel))
assert(Seq(null, null, null, true, true, true).sameElements(minRowLevel))
Contributor:
These tests were written with the intention that, without specifying analyzer options, the default behavior would be that filtered rows are true - related to the above comment.

Contributor Author:

Reverted the change.


// filtered rows 4, 5, 6 (where item < 4)
val maxRowLevel = resultData.select(expectedColumn5).collect().map(r => r.getAs[Any](0))
assert(Seq(true, true, true, true, true, true).sameElements(maxRowLevel))
assert(Seq(true, true, true, null, null, null).sameElements(maxRowLevel))

// filtered rows 4, 5, 6 (where item < 4)
val rowLevel6 = resultData.select(expectedColumn6).collect().map(r => r.getAs[Any](0))
@@ -1609,6 +1606,121 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
}
}

"Verification Suite with == based Min/Max checks and filtered row behavior" should {
val col1 = "att1"
val col2 = "att2"
val col3 = "att3"

val check1Description = "equality-check-1"
val check2Description = "equality-check-2"
val check3Description = "equality-check-3"

val check1WhereClause = "att1 > 3"
val check2WhereClause = "att2 > 4"
val check3WhereClause = "att3 = 0"

def mkEqualityCheck1(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, check1Description)
.hasMin(col1, _ == 4, analyzerOptions = Some(analyzerOptions)).where(check1WhereClause)
.hasMax(col1, _ == 4, analyzerOptions = Some(analyzerOptions)).where(check1WhereClause)

def mkEqualityCheck2(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, check2Description)
.hasMin(col2, _ == 7, analyzerOptions = Some(analyzerOptions)).where(check2WhereClause)
.hasMax(col2, _ == 7, analyzerOptions = Some(analyzerOptions)).where(check2WhereClause)

def mkEqualityCheck3(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, check3Description)
.hasMin(col3, _ == 0, analyzerOptions = Some(analyzerOptions)).where(check3WhereClause)
.hasMax(col3, _ == 0, analyzerOptions = Some(analyzerOptions)).where(check3WhereClause)

def getRowLevelResults(df: DataFrame): Seq[java.lang.Boolean] =
df.collect().map { r => r.getAs[java.lang.Boolean](0) }.toSeq

def assertCheckResults(verificationResult: VerificationResult): Unit = {
val passResult = verificationResult.checkResults
val equalityCheck1Result = passResult.values.find(_.check.description == check1Description)
val equalityCheck2Result = passResult.values.find(_.check.description == check2Description)
val equalityCheck3Result = passResult.values.find(_.check.description == check3Description)

assert(equalityCheck1Result.isDefined && equalityCheck1Result.get.status == CheckStatus.Error)
assert(equalityCheck2Result.isDefined && equalityCheck2Result.get.status == CheckStatus.Error)
assert(equalityCheck3Result.isDefined && equalityCheck3Result.get.status == CheckStatus.Success)
}

def assertRowLevelResults(rowLevelResults: DataFrame,
analyzerOptions: AnalyzerOptions): Unit = {
val equalityCheck1Results = getRowLevelResults(rowLevelResults.select(check1Description))
val equalityCheck2Results = getRowLevelResults(rowLevelResults.select(check2Description))
val equalityCheck3Results = getRowLevelResults(rowLevelResults.select(check3Description))

val filteredOutcome: java.lang.Boolean = analyzerOptions.filteredRow match {
case FilteredRowOutcome.TRUE => true
case FilteredRowOutcome.NULL => null
}
assert(equalityCheck1Results == Seq(filteredOutcome, filteredOutcome, filteredOutcome, true, false, false))
assert(equalityCheck2Results == Seq(filteredOutcome, filteredOutcome, filteredOutcome, false, false, true))
assert(equalityCheck3Results == Seq(true, true, true, filteredOutcome, filteredOutcome, filteredOutcome))
}

def assertMetrics(metricsDF: DataFrame): Unit = {
val metricsMap: Map[String, Double] = metricsDF.collect().map { r =>
val colName = r.getAs[String]("instance")
val metricName = r.getAs[String]("name")
val metricValue = r.getAs[Double]("value")
s"$colName|$metricName" -> metricValue
}.toMap

assert(metricsMap(s"$col1|Minimum (where: $check1WhereClause)") == 4.0)
assert(metricsMap(s"$col1|Maximum (where: $check1WhereClause)") == 6.0)
assert(metricsMap(s"$col2|Minimum (where: $check2WhereClause)") == 5.0)
assert(metricsMap(s"$col2|Maximum (where: $check2WhereClause)") == 7.0)
assert(metricsMap(s"$col3|Minimum (where: $check3WhereClause)") == 0.0)
assert(metricsMap(s"$col3|Maximum (where: $check3WhereClause)") == 0.0)
}

"mark filtered rows to null" in withSparkSession {
sparkSession =>
val df = getDfWithNumericValues(sparkSession)
val analyzerOptions = AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)

val equalityCheck1 = mkEqualityCheck1(analyzerOptions)
val equalityCheck2 = mkEqualityCheck2(analyzerOptions)
val equalityCheck3 = mkEqualityCheck3(analyzerOptions)

val verificationResult = VerificationSuite()
.onData(df)
.addChecks(Seq(equalityCheck1, equalityCheck2, equalityCheck3))
.run()

val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df)
val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult)

assertCheckResults(verificationResult)
assertRowLevelResults(rowLevelResultsDF, analyzerOptions)
assertMetrics(metricsDF)
}

"mark filtered rows to true" in withSparkSession {
sparkSession =>
val df = getDfWithNumericValues(sparkSession)
val analyzerOptions = AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE)

val equalityCheck1 = mkEqualityCheck1(analyzerOptions)
val equalityCheck2 = mkEqualityCheck2(analyzerOptions)
val equalityCheck3 = mkEqualityCheck3(analyzerOptions)

val verificationResult = VerificationSuite()
.onData(df)
.addChecks(Seq(equalityCheck1, equalityCheck2, equalityCheck3))
.run()

val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df)
val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult)

assertCheckResults(verificationResult)
assertRowLevelResults(rowLevelResultsDF, analyzerOptions)
assertMetrics(metricsDF)
}
}

/** Run anomaly detection using a repository with some previous analysis results for testing */
private[this] def evaluateWithRepositoryWithHistory(test: MetricsRepository => Unit): Unit = {

44 changes: 21 additions & 23 deletions src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala
@@ -25,7 +25,6 @@ import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec

class MaximumTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport {

"Max" should {
"return row-level results for columns" in withSparkSession { session =>

@@ -52,35 +51,34 @@
Seq(null, null, null, 5.0, 6.0, 7.0)
}

"return row-level results for columns with where clause filtered as true" in withSparkSession { session =>

"return row-level results for columns with filtered rows" in withSparkSession { session =>
val data = getDfWithNumericValues(session)
val col = "att1"
val whereClause = "item < 4"
val tempColName = "new"

val att1Maximum = Maximum("att1", Option("item < 4"))
val state: Option[MaxState] = att1Maximum.computeStateFrom(data, Option("item < 4"))
val metric: DoubleMetric with FullColumn = att1Maximum.computeMetricFrom(state)
val analyzerOptionsFilteredRowsNull = AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)
val analyzerOptionsFilteredRowsTrue = AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE)

val result = data.withColumn("new", metric.fullColumn.get)
result.show(false)
result.collect().map(r =>
if (r == null) null else r.getAs[Double]("new")) shouldBe
Seq(1.0, 2.0, 3.0, Double.MinValue, Double.MinValue, Double.MinValue)
}
val att1MaximumFilteredRowsNull = Maximum(col, Option(whereClause), Some(analyzerOptionsFilteredRowsNull))
val att1MaximumFilteredRowsTrue = Maximum(col, Option(whereClause), Some(analyzerOptionsFilteredRowsTrue))

"return row-level results for columns with where clause filtered as null" in withSparkSession { session =>
val filteredRowNullState = att1MaximumFilteredRowsNull.computeStateFrom(data, Option(whereClause))
val filteredRowTrueState = att1MaximumFilteredRowsTrue.computeStateFrom(data, Option(whereClause))

val data = getDfWithNumericValues(session)
val filteredRowNullMetric: DoubleMetric with FullColumn =
att1MaximumFilteredRowsNull.computeMetricFrom(filteredRowNullState)
val filteredRowTrueMetric: DoubleMetric with FullColumn =
att1MaximumFilteredRowsTrue.computeMetricFrom(filteredRowTrueState)

val att1Maximum = Maximum("att1", Option("item < 4"),
Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)))
val state: Option[MaxState] = att1Maximum.computeStateFrom(data, Option("item < 4"))
val metric: DoubleMetric with FullColumn = att1Maximum.computeMetricFrom(state)
val filteredRowNullResult = data.withColumn(tempColName, filteredRowNullMetric.fullColumn.get)
val filteredRowTrueResult = data.withColumn(tempColName, filteredRowTrueMetric.fullColumn.get)

val result = data.withColumn("new", metric.fullColumn.get)
result.show(false)
result.collect().map(r =>
if (r == null) null else r.getAs[Double]("new")) shouldBe
Seq(1.0, 2.0, 3.0, null, null, null)
Seq(filteredRowNullResult, filteredRowTrueResult).foreach { result =>
result.collect().map(r =>
if (r == null) null else r.getAs[Double](tempColName)) shouldBe
Seq(1.0, 2.0, 3.0, null, null, null)
}
}
}
}