Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Min/Max] Apply filtered row behavior at the row level evaluation #543

Merged
merged 4 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 31 additions & 17 deletions src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,13 @@ case class NumMatchesAndCount(numMatches: Long, count: Long, override val fullCo
}
}

/**
 * Tag describing where a row-level value originates from: the rows the analyzer actually
 * evaluates, or the rows that were excluded by a filter. The `name` is the string form
 * embedded in row-level outcome columns.
 */
sealed trait RowLevelStatusSource {
  def name: String
}

/** Marks a row that is part of the data under evaluation. */
case object InScopeData extends RowLevelStatusSource {
  override val name: String = "InScopeData"
}

/** Marks a row that was removed from evaluation by a filter condition. */
case object FilteredData extends RowLevelStatusSource {
  override val name: String = "FilteredData"
}

/**
 * Options that tune how an analyzer treats special rows.
 *
 * @param nullBehavior how null values are handled; defaults to `NullBehavior.Ignore`
 * @param filteredRow  outcome assigned to rows excluded by a filter; defaults to
 *                     `FilteredRowOutcome.TRUE`
 */
case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore,
filteredRow: FilteredRowOutcome = FilteredRowOutcome.TRUE)

object NullBehavior extends Enumeration {
type NullBehavior = Value
val Ignore, EmptyString, Fail = Value
Expand Down Expand Up @@ -478,46 +483,55 @@ private[deequ] object Analyzers {
if (columns.size == 1) Entity.Column else Entity.Multicolumn
}

def conditionalSelection(selection: String, where: Option[String]): Column = {
conditionalSelection(col(selection), where)
def conditionalSelection(selection: String, condition: Option[String]): Column = {
conditionalSelection(col(selection), condition)
}

def conditionSelectionGivenColumn(selection: Column, where: Option[Column], replaceWith: Double): Column = {
where
def conditionSelectionGivenColumn(selection: Column, condition: Option[Column], replaceWith: Double): Column = {
condition
.map { condition => when(condition, replaceWith).otherwise(selection) }
.getOrElse(selection)
}

def conditionSelectionGivenColumn(selection: Column, where: Option[Column], replaceWith: String): Column = {
where
def conditionSelectionGivenColumn(selection: Column, condition: Option[Column], replaceWith: String): Column = {
condition
.map { condition => when(condition, replaceWith).otherwise(selection) }
.getOrElse(selection)
}

def conditionSelectionGivenColumn(selection: Column, where: Option[Column], replaceWith: Boolean): Column = {
where
def conditionSelectionGivenColumn(selection: Column, condition: Option[Column], replaceWith: Boolean): Column = {
condition
.map { condition => when(condition, replaceWith).otherwise(selection) }
.getOrElse(selection)
}

def conditionalSelection(selection: Column, where: Option[String], replaceWith: Double): Column = {
conditionSelectionGivenColumn(selection, where.map(expr), replaceWith)
def conditionalSelection(selection: Column, condition: Option[String], replaceWith: Double): Column = {
conditionSelectionGivenColumn(selection, condition.map(expr), replaceWith)
}

def conditionalSelection(selection: Column, where: Option[String], replaceWith: String): Column = {
conditionSelectionGivenColumn(selection, where.map(expr), replaceWith)
def conditionalSelection(selection: Column, condition: Option[String], replaceWith: String): Column = {
conditionSelectionGivenColumn(selection, condition.map(expr), replaceWith)
}

/**
 * Variant of `conditionalSelection` that receives the optional condition as a SQL expression
 * string. The string (when present) is turned into a Column via `expr` and the call is
 * forwarded to `conditionalSelectionFromColumns`.
 */
def conditionalSelection(selection: Column, condition: Option[String]): Column =
  conditionalSelectionFromColumns(selection, condition.map(expr))

def conditionalSelectionFilteredFromColumns(
selection: Column,
conditionColumn: Option[Column],
filterTreatment: FilteredRowOutcome)
: Column = {
/**
 * Builds a two-element array column of the form [source, value].
 *
 * - Rows for which `condition` does not hold yield [FilteredData.name, replaceWith].
 * - All other rows (including every row when no condition is given) yield
 *   [InScopeData.name, selection].
 *
 * @param selection   column holding the value to evaluate
 * @param condition   optional SQL filter expression
 * @param replaceWith value substituted for rows that the filter excludes
 */
def conditionalSelectionWithAugmentedOutcome(selection: Column,
                                             condition: Option[String],
                                             replaceWith: Double): Column = {
  val inScopeEntry = array(lit(InScopeData.name).as("source"), selection.as("selection"))
  val filteredEntry = array(lit(FilteredData.name).as("source"), lit(replaceWith).as("selection"))

  condition.fold(inScopeEntry) { filterExpression =>
    when(not(expr(filterExpression)), filteredEntry).otherwise(inScopeEntry)
  }
}

def conditionalSelectionFilteredFromColumns(selection: Column,
conditionColumn: Option[Column],
filterTreatment: FilteredRowOutcome): Column = {
conditionColumn
.map { condition =>
when(not(condition), filterTreatment.getExpression).when(condition, selection)
Expand Down
25 changes: 4 additions & 21 deletions src/main/scala/com/amazon/deequ/analyzers/Maximum.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@ package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNumeric}
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.functions.{col, max}
import org.apache.spark.sql.functions.{col, element_at, max}
import org.apache.spark.sql.types.{DoubleType, StructType}
import Analyzers._
import com.amazon.deequ.metrics.FullColumn
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.not

case class MaxState(maxValue: Double, override val fullColumn: Option[Column] = None)
extends DoubleValuedState[MaxState] with FullColumn {
Expand All @@ -43,13 +41,12 @@ case class Maximum(column: String, where: Option[String] = None, analyzerOptions
with FilterableAnalyzer {

override def aggregationFunctions(): Seq[Column] = {
max(criterion) :: Nil
max(element_at(criterion, 2).cast(DoubleType)) :: Nil
}

override def fromAggregationResult(result: Row, offset: Int): Option[MaxState] = {

ifNoNullsIn(result, offset) { _ =>
MaxState(result.getDouble(offset), Some(rowLevelResults))
MaxState(result.getDouble(offset), Some(criterion))
}
}

Expand All @@ -60,19 +57,5 @@ case class Maximum(column: String, where: Option[String] = None, analyzerOptions
override def filterCondition: Option[String] = where

@VisibleForTesting
private def criterion: Column = conditionalSelection(column, where).cast(DoubleType)

private[deequ] def rowLevelResults: Column = {
val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions)
val whereNotCondition = where.map { expression => not(expr(expression)) }

filteredRowOutcome match {
case FilteredRowOutcome.TRUE =>
conditionSelectionGivenColumn(col(column), whereNotCondition, replaceWith = Double.MinValue).cast(DoubleType)
case _ =>
criterion
}
}

private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where, Double.MinValue)
}

24 changes: 4 additions & 20 deletions src/main/scala/com/amazon/deequ/analyzers/Minimum.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@ package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNumeric}
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.functions.{col, min}
import org.apache.spark.sql.functions.{col, element_at, min}
import org.apache.spark.sql.types.{DoubleType, StructType}
import Analyzers._
import com.amazon.deequ.metrics.FullColumn
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.not

case class MinState(minValue: Double, override val fullColumn: Option[Column] = None)
extends DoubleValuedState[MinState] with FullColumn {
Expand All @@ -43,12 +41,12 @@ case class Minimum(column: String, where: Option[String] = None, analyzerOptions
with FilterableAnalyzer {

override def aggregationFunctions(): Seq[Column] = {
min(criterion) :: Nil
min(element_at(criterion, 2).cast(DoubleType)) :: Nil
}

override def fromAggregationResult(result: Row, offset: Int): Option[MinState] = {
ifNoNullsIn(result, offset) { _ =>
MinState(result.getDouble(offset), Some(rowLevelResults))
MinState(result.getDouble(offset), Some(criterion))
}
}

Expand All @@ -59,19 +57,5 @@ case class Minimum(column: String, where: Option[String] = None, analyzerOptions
override def filterCondition: Option[String] = where

@VisibleForTesting
private def criterion: Column = {
conditionalSelection(column, where).cast(DoubleType)
}

private[deequ] def rowLevelResults: Column = {
val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions)
val whereNotCondition = where.map { expression => not(expr(expression)) }

filteredRowOutcome match {
case FilteredRowOutcome.TRUE =>
conditionSelectionGivenColumn(col(column), whereNotCondition, replaceWith = Double.MaxValue).cast(DoubleType)
case _ =>
criterion
}
}
private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where, Double.MaxValue)
}
61 changes: 59 additions & 2 deletions src/main/scala/com/amazon/deequ/constraints/Constraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,9 @@ object Constraint {
val constraint = AnalysisBasedConstraint[MinState, Double, Double](minimum, assertion,
hint = hint)

val sparkAssertion = org.apache.spark.sql.functions.udf(assertion)
val updatedAssertion = getUpdatedRowLevelAssertion(assertion, minimum.analyzerOptions)
val sparkAssertion = org.apache.spark.sql.functions.udf(updatedAssertion)

new RowLevelAssertedConstraint(
constraint,
s"MinimumConstraint($minimum)",
Expand Down Expand Up @@ -663,7 +665,9 @@ object Constraint {
val constraint = AnalysisBasedConstraint[MaxState, Double, Double](maximum, assertion,
hint = hint)

val sparkAssertion = org.apache.spark.sql.functions.udf(assertion)
val updatedAssertion = getUpdatedRowLevelAssertion(assertion, maximum.analyzerOptions)
val sparkAssertion = org.apache.spark.sql.functions.udf(updatedAssertion)

new RowLevelAssertedConstraint(
constraint,
s"MaximumConstraint($maximum)",
Expand Down Expand Up @@ -916,6 +920,59 @@ object Constraint {
.getOrElse(0.0)
}


/*
 * This function is used by Min/Max constraints and it creates a new row-level assertion based on
 * the provided assertion.
 * Each value in the outcome column is an array of 2 string elements.
 * - The first element denotes whether the row belongs to the in-scope data or to the filtered data.
 * - The second element is the value of the constraint's target column (null if the value is null).
 * The result of the final assertion is one of 3 states: true, false or null.
 * - In-scope row with a value: the result of the provided assertion.
 * - In-scope row with a null value: false for NullBehavior.Fail, null otherwise.
 * - Filtered row: true by default, or null for FilteredRowOutcome.NULL.
 * Null outcome allows the consumer to decide how to treat filtered rows or rows that were originally null.
 */
private[this] def getUpdatedRowLevelAssertion(assertion: Double => Boolean,
                                              analyzerOptions: Option[AnalyzerOptions])
: Seq[String] => java.lang.Boolean = {
  (d: Seq[String]) => {
    // Option(...) turns a null second element into None before attempting the numeric parse.
    val (scope, value) = (d.head, Option(d.last).map(_.toDouble))

    def inScopeRowOutcome(value: Option[Double]): java.lang.Boolean = {
      value match {
        // If value is defined, run it through the assertion.
        case Some(v) => assertion(v)
        // If value is not defined (value is null), apply NullBehavior.
        case None =>
          analyzerOptions match {
            case Some(opts) =>
              opts.nullBehavior match {
                case NullBehavior.Fail => false
                // EmptyString has no numeric meaning here, so it is treated like Ignore.
                // Leaving it unhandled would raise a MatchError at evaluation time.
                case NullBehavior.Ignore | NullBehavior.EmptyString => null
              }
            case None => null
          }
      }
    }

    def filteredRowOutcome: java.lang.Boolean = {
      analyzerOptions match {
        case Some(opts) =>
          opts.filteredRow match {
            case FilteredRowOutcome.TRUE => true
            case FilteredRowOutcome.NULL => null
          }
        // https://github.com/awslabs/deequ/issues/530
        // Filtered rows should be marked as true by default.
        // They can be set to null using the FilteredRowOutcome option.
        case None => true
      }
    }

    scope match {
      case FilteredData.name => filteredRowOutcome
      case InScopeData.name => inScopeRowOutcome(value)
    }
  }
}
}

/**
Expand Down
Loading
Loading