diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index b657f909..e5eb28e5 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -218,6 +218,9 @@ jobs:
       - name: Build artifacts
         run: bin/build --scala --python
 
+      - name: Test assembly jar
+        run: java -cp core/target/**/glow*assembly*.jar io.projectglow.TestAssemblyJar
+
       - name: Upload artifacts
         uses: actions/upload-artifact@v4
         if: success() || failure()
diff --git a/build.sbt b/build.sbt
index a0354210..1a0b9714 100644
--- a/build.sbt
+++ b/build.sbt
@@ -111,7 +111,6 @@ lazy val commonSettings = Seq(
   Test / test := ((Test / test) dependsOn (Test / headerCheck)).value,
   assembly / test := {},
   assembly / assemblyMergeStrategy := {
-    // Assembly jar is not executable
     case p if p.toLowerCase.contains("manifest.mf") =>
       MergeStrategy.discard
     case _ =>
@@ -184,7 +183,7 @@ ThisBuild / coreDependencies := (providedSparkDependencies.value ++ testCoreDepe
     "com.github.broadinstitute" % "picard" % "2.27.5",
     "org.apache.commons" % "commons-lang3" % "3.14.0",
     // Fix versions of libraries that are depended on multiple times
-    "org.apache.hadoop" % "hadoop-client" % "3.4.0",
+    "org.apache.hadoop" % "hadoop-client" % "3.3.6",
     "io.netty" % "netty-all" % "4.1.96.Final",
     "io.netty" % "netty-handler" % "4.1.96.Final",
     "io.netty" % "netty-transport-native-epoll" % "4.1.96.Final",
diff --git a/core/src/main/scala/io/projectglow/TestAssemblyJar.scala b/core/src/main/scala/io/projectglow/TestAssemblyJar.scala
new file mode 100644
index 00000000..856bb40e
--- /dev/null
+++ b/core/src/main/scala/io/projectglow/TestAssemblyJar.scala
@@ -0,0 +1,23 @@
+/*
+ * Copyright 2019 The Glow Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.projectglow
+
+object TestAssemblyJar {
+  def main(args: Array[String]): Unit = {
+    println("Assembly jar works") // scalastyle:ignore
+  }
+}
diff --git a/core/src/main/scala/io/projectglow/sql/expressions/glueExpressions.scala b/core/src/main/scala/io/projectglow/sql/expressions/glueExpressions.scala
index 60185701..6a8551d2 100644
--- a/core/src/main/scala/io/projectglow/sql/expressions/glueExpressions.scala
+++ b/core/src/main/scala/io/projectglow/sql/expressions/glueExpressions.scala
@@ -17,16 +17,14 @@ package io.projectglow.sql.expressions
 
 import io.projectglow.sql.util.{Rewrite, RewriteAfterResolution}
-
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.sql.SQLUtils
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
-import org.apache.spark.sql.catalyst.expressions.{Alias, CreateNamedStruct, ExpectsInputTypes, Expression, Generator, GenericInternalRow, GetStructField, ImplicitCastInputTypes, Literal, NamedExpression, UnaryExpression, Unevaluable}
+import org.apache.spark.sql.catalyst.expressions.{Add, Alias, CaseWhen, Cast, CreateNamedStruct, EqualTo, Exp, ExpectsInputTypes, Expression, Generator, GenericInternalRow, GetStructField, ImplicitCastInputTypes, Literal, Log, Multiply, NamedExpression, Pi, Round, Subtract, UnaryExpression, Unevaluable}
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.types._
-
 import io.projectglow.SparkShim.newUnresolvedException
 
 /**
@@ -231,3 +229,89 @@ object VectorToArray {
     new GenericArrayData(vectorType.deserialize(input).toArray)
   }
 }
+
+case class Comb(n: Expression, k: Expression) extends RewriteAfterResolution {
+  override def children: Seq[Expression] = Seq(n, k)
+
+  override def rewrite: Expression = {
+    Cast(
+      Round(
+        Exp(Subtract(Subtract(LogFactorial(n), LogFactorial(k)), LogFactorial(Subtract(n, k)))),
+        Literal(0)),
+      LongType)
+  }
+
+  override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
+    copy(n = newChildren(0), k = newChildren(1))
+  }
+}
+
+/**
+ * Note: not user facing, approximate for n > 47
+ */
+case class LogFactorial(n: Expression) extends RewriteAfterResolution {
+  override def children: Seq[Expression] = Seq(n)
+
+  override def rewrite: Expression = {
+    CaseWhen(
+      Seq(
+        (EqualTo(n, Literal(0)), Literal(0.0)),
+        (EqualTo(n, Literal(1)), Literal(0.0)),
+        (EqualTo(n, Literal(2)), Literal(0.693147180559945)),
+        (EqualTo(n, Literal(3)), Literal(1.7917594692280554)),
+        (EqualTo(n, Literal(4)), Literal(3.178053830347945)),
+        (EqualTo(n, Literal(5)), Literal(4.787491742782047)),
+        (EqualTo(n, Literal(6)), Literal(6.579251212010102)),
+        (EqualTo(n, Literal(7)), Literal(8.525161361065415)),
+        (EqualTo(n, Literal(8)), Literal(10.604602902745249)),
+        (EqualTo(n, Literal(9)), Literal(12.801827480081467)),
+        (EqualTo(n, Literal(10)), Literal(15.104412573075514)),
+        (EqualTo(n, Literal(11)), Literal(17.502307845873887)),
+        (EqualTo(n, Literal(12)), Literal(19.987214495661885)),
+        (EqualTo(n, Literal(13)), Literal(22.55216385312342)),
+        (EqualTo(n, Literal(14)), Literal(25.191221182738683)),
+        (EqualTo(n, Literal(15)), Literal(27.89927138384089)),
+        (EqualTo(n, Literal(16)), Literal(30.671860106080672)),
+        (EqualTo(n, Literal(17)), Literal(33.50507345013689)),
+        (EqualTo(n, Literal(18)), Literal(36.39544520803305)),
+        (EqualTo(n, Literal(19)), Literal(39.339884187199495)),
+        (EqualTo(n, Literal(20)), Literal(42.335616460753485)),
+        (EqualTo(n, Literal(21)), Literal(45.38013889847691)),
+        (EqualTo(n, Literal(22)), Literal(48.47118135183522)),
+        (EqualTo(n, Literal(23)), Literal(51.60667556776438)),
+        (EqualTo(n, Literal(24)), Literal(54.78472939811232)),
+        (EqualTo(n, Literal(25)), Literal(58.00360522298052)),
+        (EqualTo(n, Literal(26)), Literal(61.26170176100201)),
+        (EqualTo(n, Literal(27)), Literal(64.55753862700634)),
+        (EqualTo(n, Literal(28)), Literal(67.88974313718153)),
+        (EqualTo(n, Literal(29)), Literal(71.257038967168)),
+        (EqualTo(n, Literal(30)), Literal(74.65823634883017)),
+        (EqualTo(n, Literal(31)), Literal(78.0922235533153)),
+        (EqualTo(n, Literal(32)), Literal(81.55795945611503)),
+        (EqualTo(n, Literal(33)), Literal(85.05446701758152)),
+        (EqualTo(n, Literal(34)), Literal(88.58082754219768)),
+        (EqualTo(n, Literal(35)), Literal(92.1361756036871)),
+        (EqualTo(n, Literal(36)), Literal(95.7196945421432)),
+        (EqualTo(n, Literal(37)), Literal(99.33061245478743)),
+        (EqualTo(n, Literal(38)), Literal(102.96819861451381)),
+        (EqualTo(n, Literal(39)), Literal(106.63176026064346)),
+        (EqualTo(n, Literal(40)), Literal(110.32063971475738)),
+        (EqualTo(n, Literal(41)), Literal(114.03421178146169)),
+        (EqualTo(n, Literal(42)), Literal(117.77188139974507)),
+        (EqualTo(n, Literal(43)), Literal(121.53308151543864)),
+        (EqualTo(n, Literal(44)), Literal(125.3172711493569)),
+        (EqualTo(n, Literal(45)), Literal(129.12393363912722)),
+        (EqualTo(n, Literal(46)), Literal(132.95257503561632)),
+        (EqualTo(n, Literal(47)), Literal(136.80272263732635))
+      ),
+      Some(
+        Add(
+          Subtract(Multiply(Add(n, Literal(0.5)), Log(n)), n),
+          Multiply(Literal(0.5), Log(Multiply(Literal(2), Pi())))))
+    )
+  }
+
+  override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
+    copy(n = newChildren(0))
+  }
+}
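For context: `Comb` rewrites to `round(exp(logFactorial(n) - logFactorial(k) - logFactorial(n - k)))`, where `LogFactorial` is exact for `n <= 47` (the literals above) and falls back to Stirling's approximation `(n + 0.5) * ln(n) - n + 0.5 * ln(2 * pi)`. A minimal standalone sketch of the same arithmetic (illustrative object name, not part of the patch):

```scala
// Sketch only: plain-Scala version of the arithmetic that Comb/LogFactorial
// encode as Catalyst expressions. The suite below cross-checks the SQL version
// against commons-math CombinatoricsUtils.binomialCoefficient for 0 <= k <= n < 45.
object CombArithmeticSketch {
  // ln(n!) table for n = 0..47, analogous to the hard-coded literals above
  private val exactLogFactorial: Array[Double] =
    (0 to 47).map(n => (1 to n).map(i => math.log(i)).sum).toArray

  private def logFactorial(n: Int): Double =
    if (n <= 47) exactLogFactorial(n)
    else (n + 0.5) * math.log(n) - n + 0.5 * math.log(2 * math.Pi) // Stirling

  def comb(n: Int, k: Int): Long =
    math.round(math.exp(logFactorial(n) - logFactorial(k) - logFactorial(n - k)))

  def main(args: Array[String]): Unit = {
    println(comb(5, 2)) // 10
    println(comb(10, 3)) // 120
  }
}
```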
diff --git a/core/src/main/scala/io/projectglow/transformers/splitmultiallelics/VariantSplitter.scala b/core/src/main/scala/io/projectglow/transformers/splitmultiallelics/VariantSplitter.scala
index 9109f9d3..8be3e64d 100644
--- a/core/src/main/scala/io/projectglow/transformers/splitmultiallelics/VariantSplitter.scala
+++ b/core/src/main/scala/io/projectglow/transformers/splitmultiallelics/VariantSplitter.scala
@@ -92,7 +92,9 @@ private[projectglow] object VariantSplitter extends GlowLogging {
             lit(":"),
             col(startField.name) + 1,
             lit(":"),
-            concat_ws("/", col(refAlleleField.name), col(alternateAllelesField.name))
+            array_join(
+              concat(array(col(refAlleleField.name)), col(alternateAllelesField.name)),
+              "/")
           )
         ).otherwise(lit(null))
       )
@@ -165,19 +167,11 @@
       .fields
       .map(field =>
         field.name ->
-        expr(s"transform(${genotypesFieldName}, g -> g.${field.name})"))
+        when(
+          col(splitFromMultiAllelicField.name),
+          expr(s"transform(${genotypesFieldName}, g -> g.${field.name})")).otherwise(array()))
 
     val withExtractedFields = variantDf.withColumns(extractedFields.toMap)
 
-    // register the udf that genotypes splitter uses
-    withExtractedFields
-      .sqlContext
-      .udf
-      .register(
-        "likelihoodSplitUdf",
-        (numAlleles: Int, ploidy: Int, alleleIdx: Int) =>
-          refAltColexOrderIdxArray(numAlleles, ploidy, alleleIdx)
-      )
-
     // update pulled-out genotypes columns, zip them back together as the new genotypes column,
     // and drop the pulled-out columns
@@ -193,18 +187,15 @@
           structFieldsEqualExceptNullability(phredLikelihoodsField, f) |
             structFieldsEqualExceptNullability(posteriorProbabilitiesField, f) =>
-        // update genotypes subfields that have colex order using the udf
-        f.name -> when(
-          col(splitFromMultiAllelicField.name),
-          expr(s"""transform(${f.name}, c ->
+        // update genotypes subfields that have colex order using the comb expression
+        f.name ->
+          expr(s"""transform(${f.name}, c ->
                   |  filter(
                   |    transform(
                   |      c, (x, idx) ->
                   |        if (
                   |          array_contains(
-                  |            likelihoodSplitUdf(
-                  |              size(${alternateAllelesField.name}) + 1,
-                  |              size(${callsField.name}[0]),
-                  |              $splitAlleleIdxFieldName + 1
+                  |            transform(array_repeat(0, size(${callsField.name}[0]) + 1), (el, i) ->
+                  |              comb(size(${callsField.name}[0]) + $splitAlleleIdxFieldName + 1, size(${callsField.name}[0])) - comb(size(${callsField.name}[0]) + $splitAlleleIdxFieldName + 1 - i, size(${callsField.name}[0]) - i)
                   |            ),
                   |            idx
                   |          ), x, null
@@ -213,90 +204,33 @@
                   |          x -> !isnull(x)
                   |        )
                   |      )""".stripMargin)
-        ).otherwise(col(f.name))
       case f if structFieldsEqualExceptNullability(callsField, f) =>
         // update GT calls subfield
-        f.name -> when(
-          col(splitFromMultiAllelicField.name),
-          expr(
-            s"transform(${f.name}, " +
-            s"c -> transform(c, x -> if(x == 0, x, if(x == $splitAlleleIdxFieldName + 1, 1, -1))))"
-          )
-        ).otherwise(col(f.name))
+        f.name ->
+          expr(
+            s"transform(${f.name}, " +
+            s"c -> transform(c, x -> if(x == 0, x, if(x == $splitAlleleIdxFieldName + 1, 1, -1))))"
+          )
       case f if f.dataType.isInstanceOf[ArrayType] =>
         // update any ArrayType field with number of elements equal to number of alt alleles
-        f.name -> when(
-          col(splitFromMultiAllelicField.name),
-          expr(
-            s"transform(${f.name}, c -> if(size(c) == size(${alternateAllelesField.name}) + 1," +
-            s" array(c[0], c[$splitAlleleIdxFieldName + 1]), null))"
-          )
-        ).otherwise(col(f.name))
+        f.name ->
+          expr(
+            s"transform(${f.name}, c -> if(size(c) == size(${alternateAllelesField.name}) + 1," +
+            s" array(c[0], c[$splitAlleleIdxFieldName + 1]), null))"
+          )
     }
 
     withExtractedFields
      .withColumns(updatedColumns.toMap)
-      .withColumn(genotypesFieldName, arrays_zip(gSchema.get.fieldNames.map(col(_)): _*))
+      .withColumn(
+        genotypesFieldName,
+        when(
+          col(splitFromMultiAllelicField.name),
+          arrays_zip(gSchema.get.fieldNames.map(col(_)): _*)).otherwise(col(genotypesFieldName)))
      .drop(gSchema.get.fieldNames: _*)
     }
-
-  }
-
-  /**
-   * Given the total number of (ref and alt) alleles (numAlleles), ploidy, and the index an alt allele of interest
-   * (altAlleleIdx), generates an array of indices of genotypes that only include the ref allele and/or that alt allele
-   * of interest in the colex ordering of all possible genotypes. The function is general and correctly calculates
-   * the index array for any given set of values for its arguments.
-   *
-   * Example:
-   * Assume numAlleles = 3 (say A,B,C), ploidy = 2, and altAlleleIdx = 2 (i.e., C)
-   * Therefore, colex ordering of all possible genotypes is: AA, AB, BB, AC, BC, CC
-   * and for example refAltColexOrderIdxArray(3, 2, 2) = Array(0, 3, 5)
-   *
-   * @param numAlleles : total number of alleles (ref and alt)
-   * @param ploidy : ploidy
-   * @param altAlleleIdx : index of alt allele of interest
-   * @return array of indices of genotypes that only include the ref allele and alt allele
-   *         of interest in the colex ordering of all possible genotypes.
-   */
-  @VisibleForTesting
-  private[splitmultiallelics] def refAltColexOrderIdxArray(
-      numAlleles: Int,
-      ploidy: Int,
-      altAlleleIdx: Int): Array[Int] = {
-
-    if (ploidy < 1) {
-      throw new IllegalArgumentException("Ploidy must be at least 1.")
-    }
-    if (numAlleles < 2) {
-      throw new IllegalArgumentException(
-        "Number of alleles must be at least 2 (one REF and at least one ALT).")
-    }
-    if (altAlleleIdx > numAlleles - 1 || altAlleleIdx < 1) {
-      throw new IllegalArgumentException(
-        "Alternate allele index must be at least 1 and at most one less than number of alleles.")
-    }
-
-    val idxArray = new Array[Int](ploidy + 1)
-
-    // generate vector of elements at positions p+1,p,...,2 on the altAlleleIdx'th diagonal of Pascal's triangle
-    idxArray(0) = 0
-    var i = 1
-    idxArray(ploidy) = altAlleleIdx
-    while (i < ploidy) {
-      idxArray(ploidy - i) = idxArray(ploidy - i + 1) * (i + altAlleleIdx) / (i + 1)
-      i += 1
-    }
-
-    // calculate the cumulative vector
-    i = 1
-    while (i <= ploidy) {
-      idxArray(i) = idxArray(i) + idxArray(i - 1)
-      i += 1
-    }
-    idxArray
   }
 
   @VisibleForTesting
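The comb-based SQL expression above is a closed form for what the deleted `likelihoodSplitUdf`/`refAltColexOrderIdxArray` computed iteratively: entry `i` of the index array is `comb(ploidy + altAlleleIdx, ploidy) - comb(ploidy + altAlleleIdx - i, ploidy - i)`. A standalone sketch (illustrative names, not part of the patch) reproducing the example from the deleted scaladoc:

```scala
// Sketch only: checks the closed form against the worked example in the deleted
// scaladoc (numAlleles = 3, ploidy = 2, altAlleleIdx = 2 gives Array(0, 3, 5)).
object ColexClosedFormSketch {
  // exact binomial coefficient via the multiplicative formula
  private def comb(n: Int, k: Int): Long =
    if (k < 0 || k > n) 0L
    else (0 until k).foldLeft(1L)((acc, j) => acc * (n - j) / (j + 1))

  def main(args: Array[String]): Unit = {
    val (ploidy, altAlleleIdx) = (2, 2)
    val indices = (0 to ploidy).map { i =>
      comb(ploidy + altAlleleIdx, ploidy) - comb(ploidy + altAlleleIdx - i, ploidy - i)
    }
    println(indices.mkString(", ")) // 0, 3, 5
  }
}
```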
diff --git a/core/src/test/scala/io/projectglow/transformers/splitmultiallelics/VariantSplitterSuite.scala b/core/src/test/scala/io/projectglow/transformers/splitmultiallelics/VariantSplitterSuite.scala
index 004c149e..d51479ec 100644
--- a/core/src/test/scala/io/projectglow/transformers/splitmultiallelics/VariantSplitterSuite.scala
+++ b/core/src/test/scala/io/projectglow/transformers/splitmultiallelics/VariantSplitterSuite.scala
@@ -20,6 +20,7 @@ import io.projectglow.common.GlowLogging
 import io.projectglow.common.VariantSchemas._
 import io.projectglow.sql.GlowBaseTest
 import io.projectglow.transformers.splitmultiallelics.VariantSplitter._
+import org.apache.commons.math3.util.CombinatoricsUtils
 import org.apache.spark.sql.functions._
 
 class VariantSplitterSuite extends GlowBaseTest with GlowLogging {
@@ -284,76 +285,61 @@ class VariantSplitterSuite extends GlowBaseTest with GlowLogging {
     }
   }
 
-  def testRefAltColexOrderIdxArray(
-      numAlleles: Int,
-      ploidy: Int,
-      altAlleleIdx: Int,
-      expected: Array[Int]): Unit = {
-    if (numAlleles < 2 || ploidy < 1 || altAlleleIdx < 1 || altAlleleIdx > numAlleles - 1) {
-      try {
-        refAltColexOrderIdxArray(numAlleles, ploidy, altAlleleIdx)
-      } catch {
-        case _: IllegalArgumentException => // Succeed
-        case _: Throwable => fail()
-      }
-    } else {
-      assert(refAltColexOrderIdxArray(numAlleles, ploidy, altAlleleIdx) === expected)
-    }
-  }
-
-  test("test refAltColexOrderIdxArray") {
-    // exception cases
-    testRefAltColexOrderIdxArray(1, 2, 1, Array())
-    testRefAltColexOrderIdxArray(2, 0, 1, Array())
-    testRefAltColexOrderIdxArray(2, 2, 0, Array())
-    testRefAltColexOrderIdxArray(2, 2, 2, Array())
-
-    // valid cases
-    testRefAltColexOrderIdxArray(2, 1, 1, Array(0, 1))
-    testRefAltColexOrderIdxArray(3, 1, 1, Array(0, 1))
-    testRefAltColexOrderIdxArray(4, 1, 1, Array(0, 1))
-    testRefAltColexOrderIdxArray(3, 1, 2, Array(0, 2))
-    testRefAltColexOrderIdxArray(4, 1, 2, Array(0, 2))
-    testRefAltColexOrderIdxArray(4, 1, 3, Array(0, 3))
-    testRefAltColexOrderIdxArray(5, 1, 4, Array(0, 4))
-
-    testRefAltColexOrderIdxArray(2, 2, 1, Array(0, 1, 2))
-    testRefAltColexOrderIdxArray(3, 2, 1, Array(0, 1, 2))
-    testRefAltColexOrderIdxArray(4, 2, 1, Array(0, 1, 2))
-    testRefAltColexOrderIdxArray(3, 2, 2, Array(0, 3, 5))
-    testRefAltColexOrderIdxArray(4, 2, 2, Array(0, 3, 5))
-    testRefAltColexOrderIdxArray(4, 2, 3, Array(0, 6, 9))
-    testRefAltColexOrderIdxArray(5, 2, 4, Array(0, 10, 14))
-    testRefAltColexOrderIdxArray(6, 2, 5, Array(0, 15, 20))
-    testRefAltColexOrderIdxArray(7, 2, 6, Array(0, 21, 27))
-    testRefAltColexOrderIdxArray(8, 2, 7, Array(0, 28, 35))
-    testRefAltColexOrderIdxArray(9, 2, 8, Array(0, 36, 44))
-    testRefAltColexOrderIdxArray(10, 2, 9, Array(0, 45, 54))
-    testRefAltColexOrderIdxArray(11, 2, 10, Array(0, 55, 65))
-
-    testRefAltColexOrderIdxArray(2, 3, 1, Array(0, 1, 2, 3))
-    testRefAltColexOrderIdxArray(3, 3, 1, Array(0, 1, 2, 3))
-    testRefAltColexOrderIdxArray(4, 3, 1, Array(0, 1, 2, 3))
-    testRefAltColexOrderIdxArray(3, 3, 2, Array(0, 4, 7, 9))
-    testRefAltColexOrderIdxArray(4, 3, 2, Array(0, 4, 7, 9))
-    testRefAltColexOrderIdxArray(4, 3, 3, Array(0, 10, 16, 19))
-    testRefAltColexOrderIdxArray(5, 3, 4, Array(0, 20, 30, 34))
-
-    testRefAltColexOrderIdxArray(2, 4, 1, Array(0, 1, 2, 3, 4))
-    testRefAltColexOrderIdxArray(3, 4, 1, Array(0, 1, 2, 3, 4))
-    testRefAltColexOrderIdxArray(4, 4, 1, Array(0, 1, 2, 3, 4))
-    testRefAltColexOrderIdxArray(3, 4, 2, Array(0, 5, 9, 12, 14))
-    testRefAltColexOrderIdxArray(4, 4, 2, Array(0, 5, 9, 12, 14))
-    testRefAltColexOrderIdxArray(4, 4, 3, Array(0, 15, 25, 31, 34))
-    testRefAltColexOrderIdxArray(5, 4, 4, Array(0, 35, 55, 65, 69))
-    testRefAltColexOrderIdxArray(6, 4, 4, Array(0, 35, 55, 65, 69))
-    testRefAltColexOrderIdxArray(6, 4, 5, Array(0, 70, 105, 120, 125))
-
-    testRefAltColexOrderIdxArray(6, 5, 1, Array(0, 1, 2, 3, 4, 5))
-    testRefAltColexOrderIdxArray(6, 5, 2, Array(0, 6, 11, 15, 18, 20))
-    testRefAltColexOrderIdxArray(6, 5, 3, Array(0, 21, 36, 46, 52, 55))
-    testRefAltColexOrderIdxArray(6, 5, 4, Array(0, 56, 91, 111, 121, 125))
-    testRefAltColexOrderIdxArray(6, 5, 5, Array(0, 126, 196, 231, 246, 251))
+  case class ColexOrderTestCase(ploidy: Int, altAlleleIdx: Int, truth: Array[Int])
+
+  test("comb expression matches refAltColexOrderIdxArray truth data") {
+    val cases = Seq(
+      ColexOrderTestCase(1, 1, Array(0, 1)),
+      ColexOrderTestCase(1, 2, Array(0, 2)),
+      ColexOrderTestCase(1, 3, Array(0, 3)),
+      ColexOrderTestCase(1, 4, Array(0, 4)),
+      ColexOrderTestCase(2, 1, Array(0, 1, 2)),
+      ColexOrderTestCase(2, 2, Array(0, 3, 5)),
+      ColexOrderTestCase(2, 3, Array(0, 6, 9)),
+      ColexOrderTestCase(2, 4, Array(0, 10, 14)),
+      ColexOrderTestCase(2, 5, Array(0, 15, 20)),
+      ColexOrderTestCase(2, 6, Array(0, 21, 27)),
+      ColexOrderTestCase(2, 7, Array(0, 28, 35)),
+      ColexOrderTestCase(2, 8, Array(0, 36, 44)),
+      ColexOrderTestCase(2, 9, Array(0, 45, 54)),
+      ColexOrderTestCase(2, 10, Array(0, 55, 65)),
+      ColexOrderTestCase(3, 1, Array(0, 1, 2, 3)),
+      ColexOrderTestCase(3, 2, Array(0, 4, 7, 9)),
+      ColexOrderTestCase(3, 3, Array(0, 10, 16, 19)),
+      ColexOrderTestCase(3, 4, Array(0, 20, 30, 34)),
+      ColexOrderTestCase(4, 1, Array(0, 1, 2, 3, 4)),
+      ColexOrderTestCase(4, 2, Array(0, 5, 9, 12, 14)),
+      ColexOrderTestCase(4, 3, Array(0, 15, 25, 31, 34)),
+      ColexOrderTestCase(4, 4, Array(0, 35, 55, 65, 69)),
+      ColexOrderTestCase(4, 5, Array(0, 70, 105, 120, 125)),
+      ColexOrderTestCase(5, 1, Array(0, 1, 2, 3, 4, 5)),
+      ColexOrderTestCase(5, 2, Array(0, 6, 11, 15, 18, 20)),
+      ColexOrderTestCase(5, 3, Array(0, 21, 36, 46, 52, 55)),
+      ColexOrderTestCase(5, 4, Array(0, 56, 91, 111, 121, 125)),
+      ColexOrderTestCase(5, 5, Array(0, 126, 196, 231, 246, 251))
+    )
+    val df = spark
+      .createDataFrame(cases)
+      .withColumn(
+        "glow",
+        expr(
+          "transform(array_repeat(0, ploidy + 1), (el, i) -> comb(ploidy + altAlleleIdx, ploidy) - comb(ploidy + altAlleleIdx - i, ploidy - i))")
+      )
+      .where("glow != truth")
+    assert(df.count() == 0)
   }
 
+  case class BinomialTestCase(n: Int, k: Int, coeff: Long)
+
+  test("binomial coefficient function") {
+    val cases = Range(0, 45).flatMap { n =>
+      Range.inclusive(0, n).map { k =>
+        BinomialTestCase(n, k, CombinatoricsUtils.binomialCoefficient(n, k))
+      }
+    }
+    val df = spark
+      .createDataFrame(cases)
+      .withColumn("glow", expr("comb(n, k)"))
+      .where("glow != coeff")
+    assert(df.count() == 0)
+  }
 }
diff --git a/functions.yml b/functions.yml
index abfb9bc1..5708f9ce 100644
--- a/functions.yml
+++ b/functions.yml
@@ -484,3 +484,17 @@ gwas_functions:
     - name: genotypes
       doc: An array of genotype structs with ``calls`` field
   returns: An array of integers containing the number of alternate alleles in each call array
+
+# These functions are intended to be called only from Glow code. They are available in SQL but not in the Scala
+# or Python APIs.
+private:
+  functions:
+    - name: comb
+      doc: Compute the binomial coefficient (n choose k)
+      since: 2.1.0
+      expr_class: io.projectglow.sql.expressions.Comb
+      args:
+        - name: n
+          doc: The number of items to choose from (n)
+        - name: k
+          doc: The number of items to choose (k)
diff --git a/project/plugins.sbt b/project/plugins.sbt
index 9dcea6ac..58b5c16d 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -1,4 +1,4 @@
-addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.2.0")
+addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.1.5")
 addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.9.7")
 addSbtPlugin("com.github.sbt" % "sbt-pgp" % "2.2.1")
 addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.10.0")
diff --git a/python/render_template.py b/python/render_template.py
index e6c618a0..857c88af 100755
--- a/python/render_template.py
+++ b/python/render_template.py
@@ -151,5 +151,7 @@ def render_template(template_path, output_path, **kwargs):
     args = parser.parse_args()
 
     function_groups = yaml.load(open(FUNCTIONS_YAML), Loader=yaml.SafeLoader)
+    # 'private' functions are SQL-only; keep them out of the rendered docs and APIs
+    del function_groups['private']
    groups_to_render = prepare_definitions(function_groups)
    render_template(args.template_path, args.output_path, groups=groups_to_render)
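As a usage note: because `comb` sits in the `private` group of functions.yml, it is callable from SQL once Glow is registered but is deliberately excluded from the generated Scala and Python wrappers. A hedged sketch of trying it in a local session (assumes the Glow assembly jar is on the classpath; `Glow.register` is Glow's documented registration entry point):

```scala
// Sketch only: calling the SQL-only comb function from a local SparkSession.
import io.projectglow.Glow
import org.apache.spark.sql.SparkSession

object CombSqlExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    // Glow.register returns a session with Glow's SQL extensions installed
    val sess = Glow.register(spark)
    sess.sql("SELECT comb(5, 2) AS five_choose_two").show() // prints 10
    spark.stop()
  }
}
```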