Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added RatioOfSums analyzer and tests #1

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/RatioOfSums.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNumeric}
import com.amazon.deequ.metrics.Entity
import org.apache.spark.sql.DeequFunctions.stateful_corr
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructType
import Analyzers._

import com.amazon.deequ.metrics.Entity
import com.amazon.deequ.repository.AnalysisResultSerde

case class RatioOfSumsState(
numerator: Double,
denominator: Double
) extends DoubleValuedState[RatioOfSumsState] {

override def sum(other: RatioOfSumsState): RatioOfSumsState = {
val n1 = numerator
val n2 = other.numerator
val newN = n1 + n2
val t1 = denominator
val t2 = other.denominator
val newD = t1 + t2

RatioOfSumsState(newN, newD)
akalotkin marked this conversation as resolved.
Show resolved Hide resolved
}

override def metricValue(): Double = {
numerator / denominator
}
}

/** Sums up 2 columns and then divides the final values
*
* @param numerator
* First input column for computation
* @param denominator
* Second input column for computation
*/
case class RatioOfSums(
numerator: String,
denominator: String,
where: Option[String] = None
) extends StandardScanShareableAnalyzer[RatioOfSumsState](
"RatioOfSums",
s"$numerator,$denominator",
Entity.Multicolumn
)
with FilterableAnalyzer {

override def aggregationFunctions(): Seq[Column] = {
val firstSelection = conditionalSelection(numerator, where)
val secondSelection = conditionalSelection(denominator, where)
sum(firstSelection).cast(DoubleType) :: sum(secondSelection).cast(DoubleType) :: Nil
}

override def fromAggregationResult(
result: Row,
offset: Int
): Option[RatioOfSumsState] = {
if (result.isNullAt(offset)) {
None
} else {
Some(
RatioOfSumsState(
result.getDouble(0),
result.getDouble(1)
)
)
}
}

override protected def additionalPreconditions(): Seq[StructType => Unit] = {
hasColumn(numerator) :: isNumeric(numerator) :: hasColumn(denominator) :: isNumeric(denominator) :: Nil
}

override def filterCondition: Option[String] = where
}
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,12 @@ private[deequ] object AnalyzerSerializer
result.addProperty(COLUMN_FIELD, sum.column)
result.addProperty(WHERE_FIELD, sum.where.orNull)

case ratioOfSums: RatioOfSums =>
result.addProperty(ANALYZER_NAME_FIELD, "RatioOfSums")
result.addProperty("numerator", ratioOfSums.numerator)
akalotkin marked this conversation as resolved.
Show resolved Hide resolved
result.addProperty("denominator", ratioOfSums.denominator)
result.addProperty(WHERE_FIELD, ratioOfSums.where.orNull)

case mean: Mean =>
result.addProperty(ANALYZER_NAME_FIELD, "Mean")
result.addProperty(COLUMN_FIELD, mean.column)
Expand Down Expand Up @@ -412,6 +418,12 @@ private[deequ] object AnalyzerDeserializer
json.get(COLUMN_FIELD).getAsString,
getOptionalWhereParam(json))

case "RatioOfSums" =>
RatioOfSums(
json.get("numerator").getAsString,
json.get("denominator").getAsString,
getOptionalWhereParam(json))

case "Mean" =>
Mean(
json.get(COLUMN_FIELD).getAsString,
Expand Down
11 changes: 11 additions & 0 deletions src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,17 @@ class AnalyzerTests extends AnyWordSpec with Matchers with SparkContextSpec with
analyzer.calculate(df).value shouldBe Success(2.0 / 8.0)
assert(analyzer.calculate(df).fullColumn.isDefined)
}

"compute ratio of sums correctly for numeric data" in withSparkSession { session =>
akalotkin marked this conversation as resolved.
Show resolved Hide resolved
val df = getDfWithNumericValues(session)
RatioOfSums("att1", "att2").calculate(df).value shouldBe Success(21.0 / 18.0)
}

"fail to compute ratio of sums for non numeric type" in withSparkSession { sparkSession =>
val df = getDfFull(sparkSession)
assert(RatioOfSums("att1", "att2").calculate(df).value.isFailure)
}

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class AnalysisResultSerdeTest extends FlatSpec with Matchers {
DoubleMetric(Entity.Column, "Completeness", "ColumnA", Success(5.0)),
Sum("ColumnA") ->
DoubleMetric(Entity.Column, "Completeness", "ColumnA", Success(5.0)),
RatioOfSums("ColumnA", "ColumnB") ->
DoubleMetric(Entity.Column, "RatioOfSums", "ColumnA", Success(5.0)),
StandardDeviation("ColumnA") ->
DoubleMetric(Entity.Column, "Completeness", "ColumnA", Success(5.0)),
DataType("ColumnA") ->
Expand Down