From 58b0e6bfcd46f286dce89b5bb6df262a39d45034 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Wed, 14 Aug 2024 11:27:52 -0400 Subject: [PATCH] implement fast fisher exact test pvalue --- hail/python/hail/expr/functions.py | 11 +- .../expr/ir/functions/MathFunctions.scala | 28 ++++- .../main/scala/is/hail/stats/package.scala | 108 ++++++++++-------- .../is/hail/stats/FisherExactTestSuite.scala | 59 +++++++++- 4 files changed, 148 insertions(+), 58 deletions(-) diff --git a/hail/python/hail/expr/functions.py b/hail/python/hail/expr/functions.py index bce627436f4..dfc1b7ebbd2 100644 --- a/hail/python/hail/expr/functions.py +++ b/hail/python/hail/expr/functions.py @@ -1095,8 +1095,8 @@ def exp(x) -> Float64Expression: return _func("exp", tfloat64, x) -@typecheck(c1=expr_int32, c2=expr_int32, c3=expr_int32, c4=expr_int32) -def fisher_exact_test(c1, c2, c3, c4) -> StructExpression: +@typecheck(c1=expr_int32, c2=expr_int32, c3=expr_int32, c4=expr_int32, _pvalue_only=bool) +def fisher_exact_test(c1, c2, c3, c4, _pvalue_only=False) -> StructExpression: """Calculates the p-value, odds ratio, and 95% confidence interval using Fisher's exact test for a 2x2 table. @@ -1138,8 +1138,11 @@ def fisher_exact_test(c1, c2, c3, c4) -> StructExpression: `ci_95_lower (:py:data:`.tfloat64`), and `ci_95_upper` (:py:data:`.tfloat64`). """ - ret_type = tstruct(p_value=tfloat64, odds_ratio=tfloat64, ci_95_lower=tfloat64, ci_95_upper=tfloat64) - return _func("fisher_exact_test", ret_type, c1, c2, c3, c4) + if _pvalue_only: + return struct(p_value=_func("fisher_exact_test_pvalue_only", tfloat64, c1, c2, c3, c4)) + else: + ret_type = tstruct(p_value=tfloat64, odds_ratio=tfloat64, ci_95_lower=tfloat64, ci_95_upper=tfloat64) + return _func("fisher_exact_test", ret_type, c1, c2, c3, c4) @typecheck(x=expr_oneof(expr_float32, expr_float64, expr_ndarray(expr_float64))) diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/MathFunctions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/MathFunctions.scala index bacc2c50ac4..59c301bb610 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/MathFunctions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/MathFunctions.scala @@ -398,8 +398,7 @@ object MathFunctions extends RegistryFunctions { fetStruct.virtualType, (_, _, _, _, _) => fetStruct.sType, ) { case (r, cb, _, a: SInt32Value, b: SInt32Value, c: SInt32Value, d: SInt32Value, _) => - val res = cb.newLocal[Array[Double]]( - "fisher_exact_test_res", + val res = cb.memoize[Array[Double]]( Code.invokeScalaObject4[Int, Int, Int, Int, Array[Double]]( statsPackageClass, "fisherExactTest", @@ -407,7 +406,7 @@ object MathFunctions extends RegistryFunctions { b.value, c.value, d.value, - ), + ) ) fetStruct.constructFromFields( @@ -423,6 +422,29 @@ object MathFunctions extends RegistryFunctions { ) } + // FIXME: delete when PruneDeadField can optimize fisher_exact_test when only + // the pvalue is used from the result struct + registerSCode4( + "fisher_exact_test_pvalue_only", + TInt32, + TInt32, + TInt32, + TInt32, + TFloat64, + (_, _, _, _, _) => SFloat64, + ) { case (_, cb, _, a: SInt32Value, b: SInt32Value, c: SInt32Value, d: SInt32Value, _) => + primitive(cb.memoize[Double]( + Code.invokeScalaObject4[Int, Int, Int, Int, Double]( + statsPackageClass, + "fisherExactTestPValueOnly", + a.value, + b.value, + c.value, + d.value, + ) + )) + } + registerSCode4( "chi_squared_test", TInt32, diff --git a/hail/src/main/scala/is/hail/stats/package.scala b/hail/src/main/scala/is/hail/stats/package.scala index 2b87f1d320f..9afea6ef117 100644 --- a/hail/src/main/scala/is/hail/stats/package.scala +++ b/hail/src/main/scala/is/hail/stats/package.scala @@ -2,11 +2,12 @@ package is.hail import is.hail.types.physical.{PCanonicalStruct, PFloat64} import is.hail.utils._ - import net.sourceforge.jdistlib.{Beta, ChiSquare, NonCentralChiSquare, Normal, Poisson} import net.sourceforge.jdistlib.disttest.{DistributionTest, TestKind} import org.apache.commons.math3.distribution.HypergeometricDistribution +import scala.annotation.tailrec + package object stats { def uniroot(fn: Double => Double, min: Double, max: Double, tolerance: Double = 1.220703e-4) @@ -162,45 +163,28 @@ package object stats { ) def fisherExactTest(a: Int, b: Int, c: Int, d: Int): Array[Double] = - fisherExactTest(a, b, c, d, 1.0, 0.95, "two.sided") - - def fisherExactTest( - a: Int, - b: Int, - c: Int, - d: Int, - oddsRatio: Double = 1d, - confidenceLevel: Double = 0.95, - alternative: String = "two.sided", - ): Array[Double] = { + fisherExactTest(a, b, c, d, 0.95) + def fisherExactTest(a: Int, b: Int, c: Int, d: Int, confidenceLevel: Double): Array[Double] = { if (!(a >= 0 && b >= 0 && c >= 0 && d >= 0)) fatal(s"fisher_exact_test: all arguments must be non-negative, got $a, $b, $c, $d") if (confidenceLevel < 0d || confidenceLevel > 1d) fatal("Confidence level must be between 0 and 1") - if (oddsRatio < 0d) - fatal("Odds ratio must be non-negative") - - if (alternative != "greater" && alternative != "less" && alternative != "two.sided") - fatal("Did not recognize test type string. Use one of greater, less, two.sided") - val popSize = a + b + c + d - val numSuccessPopulation = a + c - val sampleSize = a + b + val nGood = a + c + val nSample = a + b val numSuccessSample = a - if ( - !(popSize > 0 && sampleSize > 0 && sampleSize < popSize && numSuccessPopulation > 0 && numSuccessPopulation < popSize) - ) + if (!(popSize > 0 && nSample > 0 && nSample < popSize && nGood > 0 && nGood < popSize)) return Array(Double.NaN, Double.NaN, Double.NaN, Double.NaN) val low = math.max(0, (a + b) - (b + d)) val high = math.min(a + b, a + c) val support = (low to high).toArray - val hgd = new HypergeometricDistribution(null, popSize, numSuccessPopulation, sampleSize) + val hgd = new HypergeometricDistribution(null, popSize, nGood, nSample) val epsilon = 2.220446e-16 def dhyper(k: Int, logProb: Boolean): Double = @@ -320,36 +304,68 @@ package object stats { } } - val pvalue: Double = (alternative: @unchecked) match { - case "less" => pnhyper(numSuccessSample, oddsRatio) - case "greater" => pnhyper(numSuccessSample, oddsRatio, upper_tail = true) - case "two.sided" => - if (oddsRatio == 0) - if (low == numSuccessSample) 1d else 0d - else if (oddsRatio == Double.PositiveInfinity) - if (high == numSuccessSample) 1d else 0d - else { - val relErr = 1d + 1e-7 - val d = dnhyper(oddsRatio) - d.filter(_ <= d(numSuccessSample - low) * relErr).sum - } - } - - assert(pvalue >= 0d && pvalue <= 1.000000000002) + val pvalue = fisherExactTestPValueOnly(a, b, c, d) val oddsRatioEstimate = mle(numSuccessSample) - val confInterval = alternative match { - case "less" => (0d, ncpUpper(numSuccessSample, 1 - confidenceLevel)) - case "greater" => (ncpLower(numSuccessSample, 1 - confidenceLevel), Double.PositiveInfinity) - case "two.sided" => - val alpha = (1 - confidenceLevel) / 2d - (ncpLower(numSuccessSample, alpha), ncpUpper(numSuccessSample, alpha)) + val confInterval = { + val alpha = (1 - confidenceLevel) / 2d + (ncpLower(numSuccessSample, alpha), ncpUpper(numSuccessSample, alpha)) } Array(pvalue, oddsRatioEstimate, confInterval._1, confInterval._2) } + def fisherExactTestPValueOnly(a: Int, b: Int, c: Int, d: Int): Double = { + val popSize = a + b + c + d + val nGood = a + c + val nSample = a + b + val numSuccessSample = a + + val hgd = new HypergeometricDistribution(null, popSize, nGood, nSample) + + // Returns i in [start, end] such that a([start, i)) is <= d, and a([i, end)) is > d + @tailrec def upperBoundIncreasing(a: Int => Double, d: Double, start: Int, end: Int): Int = { + if (start >= end) return start + val mid = (start + end) >>> 1 + val elt = a(mid) + if (elt <= d) upperBoundIncreasing(a, d, mid + 1, end) + else upperBoundIncreasing(a, d, start, mid) + } + + // Returns i in [start, end] such that a([start, i)) is > d, and a([i, end)) is <= d + @tailrec def lowerBoundDecreasing(a: Int => Double, d: Double, start: Int, end: Int): Int = { + if (start >= end) return start + val mid = (start + end) >>> 1 + val elt = a(mid) + if (elt > d) lowerBoundDecreasing(a, d, mid + 1, end) + else lowerBoundDecreasing(a, d, start, mid) + } + + val epsilon = 1e-14 + val gamma = 1 + epsilon + + val mode = ((nSample + 1.0) * (nGood + 1.0) / (popSize + 2.0)).toInt + val pexact = hgd.probability(numSuccessSample) + val pmode = hgd.probability(mode) + + val pvalue = if (math.abs(pexact - pmode) / math.max(pexact, pmode) <= epsilon) { + 1.0 + } else if (numSuccessSample < mode) { + val plower = hgd.cumulativeProbability(numSuccessSample) + val bound = lowerBoundDecreasing(hgd.probability, pexact * gamma, mode + 1, nSample + 1) + plower + hgd.upperCumulativeProbability(bound) + } else { + val pupper = hgd.upperCumulativeProbability(numSuccessSample) + val bound = upperBoundIncreasing(hgd.probability, pexact * gamma, 0, mode) + pupper + hgd.cumulativeProbability(bound - 1) + } + + assert(pvalue >= 0d && pvalue <= 1.000000000002) + + pvalue + } + def dnorm(x: Double, mu: Double, sigma: Double, logP: Boolean): Double = Normal.density(x, mu, sigma, logP) diff --git a/hail/src/test/scala/is/hail/stats/FisherExactTestSuite.scala b/hail/src/test/scala/is/hail/stats/FisherExactTestSuite.scala index bf5d68aa584..2fe599b766b 100644 --- a/hail/src/test/scala/is/hail/stats/FisherExactTestSuite.scala +++ b/hail/src/test/scala/is/hail/stats/FisherExactTestSuite.scala @@ -1,7 +1,7 @@ package is.hail.stats import is.hail.HailSuite - +import is.hail.utils.D_== import org.testng.annotations.Test class FisherExactTestSuite extends HailSuite { @@ -14,9 +14,58 @@ class FisherExactTestSuite extends HailSuite { val result = fisherExactTest(a, b, c, d) - assert(math.abs(result(0) - 0.2828) < 1e-4) - assert(math.abs(result(1) - 0.4754059) < 1e-4) - assert(math.abs(result(2) - 0.122593) < 1e-4) - assert(math.abs(result(3) - 1.597972) < 1e-4) + assert(D_==(result(0), 0.2828, 1e-4)) + assert(D_==(result(1), 0.4754059, 1e-4)) + assert(D_==(result(2), 0.122593, 1e-4)) + assert(D_==(result(3), 1.597972, 1e-4)) + } + + @Test def testPvalue2(): Unit = { + val a = 10 + val b = 5 + val c = 90 + val d = 95 + + val result = fisherExactTest(a, b, c, d) + + assert(D_==(result(0), 0.2828, 1e-4)) + } + + @Test def test_basic(): Unit = { + // test cases taken from scipy/stats/tests/test_stats.py + var res = fisherExactTestPValueOnly(14500, 20000, 30000, 40000) + assert(D_==(res, 0.01106, 1e-3)) + res = fisherExactTestPValueOnly(100, 2, 1000, 5) + assert(D_==(res, 0.1301, 1e-3)) + res = fisherExactTestPValueOnly(2, 7, 8, 2) + assert(D_==(res, 0.0230141, 1e-5)) + res = fisherExactTestPValueOnly(5, 1, 10, 10) + assert(D_==(res, 0.1973244, 1e-6)) + res = fisherExactTestPValueOnly(5, 15, 20, 20) + assert(D_==(res, 0.0958044, 1e-6)) + res = fisherExactTestPValueOnly(5, 16, 20, 25) + assert(D_==(res, 0.1725862, 1e-5)) + res = fisherExactTestPValueOnly(10, 5, 10, 1) + assert(D_==(res, 0.1973244, 1e-6)) + res = fisherExactTestPValueOnly(5, 0, 1, 4) + assert(D_==(res, 0.04761904, 1e-6)) + res = fisherExactTestPValueOnly(0, 1, 3, 2) + assert(res == 1.0) + res = fisherExactTestPValueOnly(0, 2, 6, 4) + assert(D_==(res, 0.4545454545)) + res = fisherExactTestPValueOnly(2, 7, 8, 2) + assert(D_==(res, 0.0230141, 1e-5)) + + res = fisherExactTestPValueOnly(6, 37, 108, 200) + assert(D_==(res, 0.005092697748126)) + res = fisherExactTestPValueOnly(22, 0, 0, 102) + assert(D_==(res, 7.175066786244549e-25)) + res = fisherExactTestPValueOnly(94, 48, 3577, 16988) + assert(D_==(res, 2.069356340993818e-37)) + res = fisherExactTestPValueOnly(5829225, 5692693, 5760959, 5760959) + assert(res <= 1e-170) + for ((a, b, c, d) <- Array((0, 0, 5, 10), (5, 10, 0, 0), (0, 5, 0, 10), (5, 0, 10, 0))) { + assert(fisherExactTestPValueOnly(a, b, c, d) == 1.0) + } } }