diff --git a/core/src/main/scala/io/projectglow/functions.scala b/core/src/main/scala/io/projectglow/functions.scala
index df318407f..5fd01be41 100644
--- a/core/src/main/scala/io/projectglow/functions.scala
+++ b/core/src/main/scala/io/projectglow/functions.scala
@@ -317,6 +317,24 @@ object functions {
     new io.projectglow.sql.expressions.SampleGqSummaryStatistics(genotypes.expr)
   }
 
+  /**
+   * Array quantile
+   * @group quality_control
+   * @since 2.1.0
+   *
+   * @param arr An array of numeric values
+   * @param quantile The desired quantile
+   * @param is_sorted If true, the input array is assumed to already be sorted
+   * @return The quantile of the array, linearly interpolated between the two nearest elements
+   */
+  def array_quantile(arr: Column, quantile: Double, is_sorted: Column): Column = withExpr {
+    new io.projectglow.sql.expressions.ArrayQuantile(arr.expr, Literal(quantile), is_sorted.expr)
+  }
+
+  def array_quantile(arr: Column, quantile: Double): Column = withExpr {
+    new io.projectglow.sql.expressions.ArrayQuantile(arr.expr, Literal(quantile))
+  }
+
   /**
    * Performs a linear regression association test optimized for performance in a GWAS setting. See :ref:`linear-regression` for details.
    * @group gwas_functions
diff --git a/core/src/main/scala/io/projectglow/sql/expressions/glueExpressions.scala b/core/src/main/scala/io/projectglow/sql/expressions/glueExpressions.scala
index 6a8551d24..55e82119c 100644
--- a/core/src/main/scala/io/projectglow/sql/expressions/glueExpressions.scala
+++ b/core/src/main/scala/io/projectglow/sql/expressions/glueExpressions.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.SQLUtils
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
-import org.apache.spark.sql.catalyst.expressions.{Add, Alias, BinaryExpression, CaseWhen, Cast, CreateNamedStruct, Divide, EqualTo, Exp, ExpectsInputTypes, Expression, Factorial, Generator, GenericInternalRow, GetStructField, If, ImplicitCastInputTypes, LessThan, Literal, Log, Multiply, NamedExpression, Pi, Round, Subtract, UnaryExpression, Unevaluable}
+import org.apache.spark.sql.catalyst.expressions.{Add, Alias, ArraySort, BinaryExpression, CaseWhen, Cast, Ceil, CreateNamedStruct, Divide, EqualTo, Exp, ExpectsInputTypes, Expression, Factorial, Floor, Generator, GenericInternalRow, GetArrayItem, GetStructField, Greatest, If, ImplicitCastInputTypes, Least, LessThan, Literal, Log, Multiply, NamedExpression, Pi, Round, Size, Subtract, UnaryExpression, Unevaluable}
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.types._
 import io.projectglow.SparkShim.newUnresolvedException
@@ -314,5 +314,42 @@ case class LogFactorial(n: Expression) extends RewriteAfterResolution {
   override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
     copy(n = newChildren(0))
   }
+}
+
+/**
+ * Computes a quantile of an array by linearly interpolating between the two elements that
+ * surround the fractional position `probability * (size - 1)`.
+ *
+ * @param arr The array to compute the quantile of
+ * @param probability The desired quantile
+ * @param isSorted If true, `arr` is used as is; otherwise it is sorted with `ArraySort` first
+ */
+case class ArrayQuantile(arr: Expression, probability: Expression, isSorted: Expression)
+    extends RewriteAfterResolution {
+
+  // Without an explicit flag the input is not assumed to be sorted
+  def this(arr: Expression, probability: Expression) = this(arr, probability, Literal(false))
+  override def children: Seq[Expression] = Seq(arr, probability, isSorted)
+
+  // Builds the interpolation expression over `arr`, which is used in its given order
+  private def getQuantile(arr: Expression): Expression = {
+    // 1-based fractional position of the quantile in the array
+    val trueIndex = Add(Multiply(probability, Subtract(Size(arr), Literal(1))), Literal(1))
+    val roundedIdx = Cast(trueIndex, IntegerType)
+    // Neighbouring elements: the lower index is floored at 0 and the upper index is capped at
+    // the last position
+    val below = GetArrayItem(arr, Greatest(Seq(Literal(0), Subtract(roundedIdx, Literal(1)))))
+    val above = GetArrayItem(arr, Least(Seq(Subtract(Size(arr), Literal(1)), roundedIdx)))
+    // Weight given to `above`; `below` gets the remainder
+    val frac = Subtract(trueIndex, roundedIdx)
+    Add(Multiply(frac, above), Multiply(Subtract(Literal(1), frac), below))
+  }
+  override def rewrite: Expression = {
+    If(isSorted, getQuantile(arr), getQuantile(new ArraySort(arr)))
+  }
+
+  override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
+    copy(arr = newChildren(0), probability = newChildren(1), isSorted = newChildren(2))
+  }
 }
 
diff --git a/core/src/test/scala/io/projectglow/tertiary/VariantUtilExprsSuite.scala b/core/src/test/scala/io/projectglow/tertiary/VariantUtilExprsSuite.scala
index eb0cc8f08..92e5fa2d3 100644
--- a/core/src/test/scala/io/projectglow/tertiary/VariantUtilExprsSuite.scala
+++ b/core/src/test/scala/io/projectglow/tertiary/VariantUtilExprsSuite.scala
@@ -16,6 +16,7 @@
 
 package io.projectglow.tertiary
 
+import com.google.common.math.Quantiles
 import org.apache.spark.ml.linalg.{DenseMatrix, DenseVector, SparseVector, Vector}
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
 import org.apache.spark.sql.functions._
@@ -24,6 +25,13 @@ import org.apache.spark.unsafe.types.UTF8String
 import io.projectglow.sql.GlowBaseTest
 import io.projectglow.sql.expressions.{VariantType, VariantUtilExprs}
 import io.projectglow.functions._
+import org.apache.commons.math3.stat.descriptive.rank.Percentile
+import org.apache.commons.math3.stat.descriptive.rank.Percentile.EstimationType
+import org.apache.commons.math3.stat.ranking.NaNStrategy
+import org.apache.commons.math3.util.{KthSelector, MedianOf3PivotingStrategy}
+import org.scalactic.TolerantNumerics
+
+import scala.util.Random
 
 class VariantUtilExprsSuite extends GlowBaseTest {
   case class SimpleGenotypeFields(calls: Seq[Int])
@@ -334,6 +342,82 @@ class VariantUtilExprsSuite extends GlowBaseTest {
     val df = spark.createDataFrame(Seq(Outer(Inner(1, "two"))))
     assert(df.select(expand_struct(col("inner"))).as[Inner].head == Inner(1, "two"))
   }
+
+  case class QuantileTest(
+      arr: Seq[Double],
+      p25: Double,
+      p50: Double,
+      p75: Double,
+      p90: Double,
+      p99: Double)
+  test("quantiles") {
+    def checkDf(df: DataFrame): Unit = {
+      val rows = df.collect()
+      rows.foreach { row =>
+        assert(row.getAs[Double]("p25") ~== row.getAs[Double]("glow_25") relTol 0.02)
+        assert(row.getAs[Double]("p50") ~== row.getAs[Double]("glow_50") relTol 0.02)
+        assert(row.getAs[Double]("p75") ~== row.getAs[Double]("glow_75") relTol 0.02)
+        assert(row.getAs[Double]("p90") ~== row.getAs[Double]("glow_90") relTol 0.02)
+        assert(row.getAs[Double]("p99") ~== row.getAs[Double]("glow_99") relTol 0.02)
+      }
+    }
+    // Use a fixed seed so that the generated cases are the same on every run
+    val rng = new Random(42)
+    val cases = Range(0, 50).map { _ =>
+      // Always generate at least one value: Percentile evaluates to NaN for an empty array,
+      // while array_quantile returns null
+      val numbers = Seq.fill(1 + rng.nextInt(1000))(rng.nextDouble())
+      val evaluator = new Percentile(1).withEstimationType(EstimationType.R_7)
+      val golden = Seq(25, 50, 75, 90, 99).map { d =>
+        evaluator.setQuantile(d)
+        d -> evaluator.evaluate(numbers.toArray)
+      }.toMap
+
+      QuantileTest(numbers, golden(25), golden(50), golden(75), golden(90), golden(99))
+
+    }
+
+    val df = spark.createDataFrame(cases)
+
+    // Unsorted
+    val glowUnsortedQuantiles = df.withColumns(
+      Map(
+        "glow_25" -> expr("array_quantile(arr, 0.25)"),
+        "glow_50" -> expr("array_quantile(arr, 0.50)"),
+        "glow_75" -> expr("array_quantile(arr, 0.75)"),
+        "glow_90" -> expr("array_quantile(arr, 0.90)"),
+        "glow_99" -> expr("array_quantile(arr, 0.99)")
+      ))
+    checkDf(glowUnsortedQuantiles)
+
+    val sortedDf = df
+      .withColumn("sorted_arr", expr("array_sort(arr)"))
+      .withColumns(Map(
+        "glow_25" -> expr("array_quantile(sorted_arr, 0.25, true)"),
+        "glow_50" -> expr("array_quantile(sorted_arr, 0.50, true)"),
+        "glow_75" -> expr("array_quantile(sorted_arr, 0.75, true)"),
+        "glow_90" -> expr("array_quantile(sorted_arr, 0.90, true)"),
+        "glow_99" -> expr("array_quantile(sorted_arr, 0.99, true)")
+      ))
+    checkDf(sortedDf)
+  }
+
+  test("quantiles respects sorted argument") {
+    val df = spark.createDataFrame(Seq(Tuple1(Seq(4, 3, 2, 1)))).withColumnRenamed("_1", "arr")
+    assert(df.selectExpr("array_quantile(arr, 1, true)").first().get(0) == 1)
+    assert(df.selectExpr("array_quantile(arr, 1, false)").first().get(0) == 4)
+  }
+
+  test("quantiles 0 length array") {
+    val df = spark.createDataFrame(Seq(Tuple1(Seq()))).withColumnRenamed("_1", "arr")
+    assert(df.selectExpr("array_quantile(arr, 1, true)").first().get(0) == null)
+  }
+
+  test("quantiles 1 length array") {
+    val df = spark.createDataFrame(Seq(Tuple1(Seq(5)))).withColumnRenamed("_1", "arr")
+    assert(df.selectExpr("cast(array_quantile(arr, 1, true) as int)").first().get(0) == 5)
+    assert(df.selectExpr("cast(array_quantile(arr, 0.001, true) as int)").first().get(0) == 5)
+  }
 }
 
 case class HCTestCase(
diff --git a/functions.yml b/functions.yml
index 5708f9ce2..f2222a8e4 100644
--- a/functions.yml
+++ b/functions.yml
@@ -406,6 +406,26 @@ quality_control:
           type: str
       returns: Null if true, or throws an exception if not true
 
+    - name: array_quantile
+      doc: Array quantile
+      since: 2.1.0
+      expr_class: io.projectglow.sql.expressions.ArrayQuantile
+      args:
+        - name: arr
+          doc: An array of numeric values
+        - name: quantile
+          doc: The desired quantile
+          type: double
+        - name: is_sorted
+          doc: If true, the input array is assumed to already be sorted
+          is_optional: true
+      examples:
+        python: |
+          >>> df = spark.createDataFrame([Row(arr=[1, 2, 3, 4, 5])])
+          >>> df.select(glow.array_quantile(df.arr, 0.7).alias('p70')).collect()
+          [Row(p70=3.8)]
+      returns: The quantile of the array, linearly interpolated between the two nearest elements
+
 gwas_functions:
   functions:
 
diff --git a/python/glow/functions.py b/python/glow/functions.py
index 1520b261b..84097b8a5 100644
--- a/python/glow/functions.py
+++ b/python/glow/functions.py
@@ -547,6 +547,34 @@ def sample_gq_summary_stats(genotypes: Union[Column, str]) -> Column:
     output = Column(sc()._jvm.io.projectglow.functions.sample_gq_summary_stats(_to_java_column(genotypes)))
     return output
 
+
+__all__.append('array_quantile')
+@typechecked
+def array_quantile(arr: Union[Column, str], quantile: float, is_sorted: Union[Column, str] | None = None) -> Column:
+    """
+    Array quantile
+
+    Added in version 2.1.0.
+
+    Examples:
+        >>> df = spark.createDataFrame([Row(arr=[1, 2, 3, 4, 5])])
+        >>> df.select(glow.array_quantile(df.arr, 0.7).alias('p70')).collect()
+        [Row(p70=3.8)]
+
+    Args:
+        arr : An array of numeric values
+        quantile : The desired quantile
+        is_sorted : If true, the input array is assumed to already be sorted
+
+    Returns:
+        The quantile of the array, linearly interpolated between the two nearest elements
+    """
+    if is_sorted is None:
+        output = Column(sc()._jvm.io.projectglow.functions.array_quantile(_to_java_column(arr), quantile))
+    else:
+        output = Column(sc()._jvm.io.projectglow.functions.array_quantile(_to_java_column(arr), quantile, _to_java_column(is_sorted)))
+    return output
+
 ########### gwas_functions
 
 __all__.append('linear_regression_gwas')