diff --git a/src/gentropy/method/colocalisation.py b/src/gentropy/method/colocalisation.py index c05b26187..b45b91920 100644 --- a/src/gentropy/method/colocalisation.py +++ b/src/gentropy/method/colocalisation.py @@ -208,11 +208,15 @@ class Coloc(ColocalisationMethodInterface): Attributes: PSEUDOCOUNT (float): Pseudocount to avoid log(0). Defaults to 1e-10. + OVERLAP_SIZE_CUTOFF (int): Minimum number of overlapping variants bfore filtering. Defaults to 5. + POSTERIOR_CUTOFF (float): Minimum overlapping Posterior probability cutoff for small overlaps. Defaults to 0.5. """ METHOD_NAME: str = "COLOC" METHOD_METRIC: str = "h4" PSEUDOCOUNT: float = 1e-10 + OVERLAP_SIZE_CUTOFF: int = 5 + POSTERIOR_CUTOFF: float = 0.5 @staticmethod def _get_posteriors(all_bfs: NDArray[np.float64]) -> DenseVector: @@ -277,7 +281,15 @@ def colocalise( ) .select("*", "statistics.*") # Before summing log_BF columns nulls need to be filled with 0: - .fillna(0, subset=["left_logBF", "right_logBF"]) + .fillna( + 0, + subset=[ + "left_logBF", + "right_logBF", + "left_posteriorProbability", + "right_posteriorProbability", + ], + ) # Sum of log_BFs for each pair of signals .withColumn( "sum_log_bf", @@ -305,9 +317,18 @@ def colocalise( fml.array_to_vector(f.collect_list(f.col("right_logBF"))).alias( "right_logBF" ), + fml.array_to_vector( + f.collect_list(f.col("left_posteriorProbability")) + ).alias("left_posteriorProbability"), + fml.array_to_vector( + f.collect_list(f.col("right_posteriorProbability")) + ).alias("right_posteriorProbability"), fml.array_to_vector(f.collect_list(f.col("sum_log_bf"))).alias( "sum_log_bf" ), + f.collect_list(f.col("tagVariantSource")).alias( + "tagVariantSourceList" + ), ) .withColumn("logsum1", logsum(f.col("left_logBF"))) .withColumn("logsum2", logsum(f.col("right_logBF"))) @@ -327,10 +348,39 @@ def colocalise( # h3 .withColumn("sumlogsum", f.col("logsum1") + f.col("logsum2")) .withColumn("max", f.greatest("sumlogsum", "logsum12")) + .withColumn( + "anySnpBothSidesHigh", + f.aggregate( + f.transform( + f.arrays_zip( + fml.vector_to_array(f.col("left_posteriorProbability")), + fml.vector_to_array( + f.col("right_posteriorProbability") + ), + f.col("tagVariantSourceList"), + ), + # row["0"] = left PP, row["1"] = right PP, row["tagVariantSourceList"] + lambda row: f.when( + (row["tagVariantSourceList"] == "both") + & (row["0"] > Coloc.POSTERIOR_CUTOFF) + & (row["1"] > Coloc.POSTERIOR_CUTOFF), + 1.0, + ).otherwise(0.0), + ), + f.lit(0.0), + lambda acc, x: acc + x, + ) + > 0, # True if sum of these 1.0's > 0 + ) + .filter( + (f.col("numberColocalisingVariants") > Coloc.OVERLAP_SIZE_CUTOFF) + | (f.col("anySnpBothSidesHigh")) + ) .withColumn( "logdiff", f.when( - f.col("sumlogsum") == f.col("logsum12"), Coloc.PSEUDOCOUNT + (f.col("sumlogsum") == f.col("logsum12")), + Coloc.PSEUDOCOUNT, ).otherwise( f.col("max") + f.log( @@ -382,6 +432,10 @@ def colocalise( "lH2bf", "lH3bf", "lH4bf", + "left_posteriorProbability", + "right_posteriorProbability", + "tagVariantSourceList", + "anySnpBothSidesHigh", ) .withColumn("colocalisationMethod", f.lit(cls.METHOD_NAME)) .join( diff --git a/tests/gentropy/method/test_colocalisation_method.py b/tests/gentropy/method/test_colocalisation_method.py index 5b05d724b..78a66f732 100644 --- a/tests/gentropy/method/test_colocalisation_method.py +++ b/tests/gentropy/method/test_colocalisation_method.py @@ -43,6 +43,8 @@ def test_coloc(mock_study_locus_overlap: StudyLocusOverlap) -> None: "right_logBF": 10.5, "left_beta": 0.1, "right_beta": 0.2, + "left_posteriorProbability": 0.91, + "right_posteriorProbability": 0.92, }, }, ], @@ -57,7 +59,7 @@ def test_coloc(mock_study_locus_overlap: StudyLocusOverlap) -> None: }, ], ), - # associations with multiple overlapping SNPs + # Case with mismatched posterior probabilities: ( # observed overlap [ @@ -68,10 +70,12 @@ def test_coloc(mock_study_locus_overlap: StudyLocusOverlap) -> None: "chromosome": "1", "tagVariantId": "snp1", "statistics": { - "left_logBF": 10.3, + "left_logBF": 1.2, "right_logBF": 10.5, - "left_beta": 0.1, + "left_beta": 0.001, "right_beta": 0.2, + "left_posteriorProbability": 0.001, + "right_posteriorProbability": 0.92, }, }, { @@ -82,23 +86,177 @@ def test_coloc(mock_study_locus_overlap: StudyLocusOverlap) -> None: "tagVariantId": "snp2", "statistics": { "left_logBF": 10.3, - "right_logBF": 10.5, + "right_logBF": 3.8, "left_beta": 0.3, - "right_beta": 0.5, + "right_beta": 0.005, + "left_posteriorProbability": 0.91, + "right_posteriorProbability": 0.01, + }, + }, + ], + # expected coloc + [], + ), + # Case of an overlap with significant PP overlap: + ( + # observed overlap + [ + { + "leftStudyLocusId": "1", + "rightStudyLocusId": "2", + "rightStudyType": "eqtl", + "chromosome": "1", + "tagVariantId": "snp1", + "statistics": { + "left_logBF": 10.2, + "right_logBF": 10.5, + "left_beta": 0.5, + "right_beta": 0.2, + "left_posteriorProbability": 0.91, + "right_posteriorProbability": 0.92, + }, + }, + { + "leftStudyLocusId": "1", + "rightStudyLocusId": "2", + "rightStudyType": "eqtl", + "chromosome": "1", + "tagVariantId": "snp2", + "statistics": { + "left_logBF": 1.2, + "right_logBF": 3.8, + "left_beta": 0.003, + "right_beta": 0.005, + "left_posteriorProbability": 0.001, + "right_posteriorProbability": 0.01, }, }, ], # expected coloc [ { - "h0": 4.6230151407950416e-5, - "h1": 2.749086942648107e-4, - "h2": 3.357742374172504e-4, - "h3": 9.983447421747411e-4, - "h4": 0.9983447421747356, + "h0": 1.02277006860577e-4, + "h1": 2.7519169183135977e-4, + "h2": 3.718812819512325e-4, + "h3": 1.3533048074295033e-6, + "h4": 0.9992492967145488, }, ], ), + # Case where the overlap source is ["left", "both", "both"]: + ( + # observed overlap + [ + { + "leftStudyLocusId": "1", + "rightStudyLocusId": "2", + "rightStudyType": "eqtl", + "chromosome": "1", + "tagVariantId": "snp1", + "statistics": { + "left_logBF": 1.2, + "right_logBF": None, + "left_beta": 0.003, + "right_beta": None, + "left_posteriorProbability": 0.001, + "right_posteriorProbability": 0.01, + }, + }, + { + "leftStudyLocusId": "1", + "rightStudyLocusId": "2", + "rightStudyType": "eqtl", + "chromosome": "1", + "tagVariantId": "snp2", + "statistics": { + "left_logBF": 1.2, + "right_logBF": 3.8, + "left_beta": 0.003, + "right_beta": 0.005, + "left_posteriorProbability": 0.001, + "right_posteriorProbability": 0.01, + }, + }, + { + "leftStudyLocusId": "1", + "rightStudyLocusId": "2", + "rightStudyType": "eqtl", + "chromosome": "1", + "tagVariantId": "snp3", + "statistics": { + "left_logBF": 10.2, + "right_logBF": 10.5, + "left_beta": 0.5, + "right_beta": 0.2, + "left_posteriorProbability": 0.91, + "right_posteriorProbability": 0.92, + }, + }, + ], + # expected coloc + [ + { + "h0": 1.02277006860577e-4, + "h1": 2.752255943423052e-4, + "h2": 3.718914358059273e-4, + "h3": 1.5042926116520848e-6, + "h4": 0.9992491016906891, + }, + ], + ), + # Case where PPs are high on the left, but low on the right: + ( + # observed overlap + [ + { + "leftStudyLocusId": "1", + "rightStudyLocusId": "2", + "rightStudyType": "eqtl", + "chromosome": "1", + "tagVariantId": "snp1", + "statistics": { + "left_logBF": 1.2, + "right_logBF": None, + "left_beta": 0.003, + "right_beta": None, + "left_posteriorProbability": 0.001, + "right_posteriorProbability": 0.01, + }, + }, + { + "leftStudyLocusId": "1", + "rightStudyLocusId": "2", + "rightStudyType": "eqtl", + "chromosome": "1", + "tagVariantId": "snp2", + "statistics": { + "left_logBF": 1.2, + "right_logBF": 3.8, + "left_beta": 0.003, + "right_beta": 0.005, + "left_posteriorProbability": 0.001, + "right_posteriorProbability": 0.01, + }, + }, + { + "leftStudyLocusId": "1", + "rightStudyLocusId": "2", + "rightStudyType": "eqtl", + "chromosome": "1", + "tagVariantId": "snp3", + "statistics": { + "left_logBF": 10.2, + "right_logBF": 10.5, + "left_beta": 0.5, + "right_beta": 0.2, + "left_posteriorProbability": 0.36, + "right_posteriorProbability": 0.92, + }, + }, + ], + # expected coloc + [], + ), ], ) def test_coloc_semantic( @@ -111,24 +269,45 @@ def test_coloc_semantic( _df=spark.createDataFrame(observed_data, schema=StudyLocusOverlap.get_schema()), _schema=StudyLocusOverlap.get_schema(), ) - observed_coloc_pdf = ( - Coloc.colocalise(observed_overlap) - .df.select("h0", "h1", "h2", "h3", "h4") - .toPandas() - ) - expected_coloc_pdf = ( - spark.createDataFrame(expected_data) - .select("h0", "h1", "h2", "h3", "h4") - .toPandas() - ) - assert_frame_equal( - observed_coloc_pdf, - expected_coloc_pdf, - check_exact=False, - check_dtype=True, + observed_coloc_df = Coloc.colocalise(observed_overlap).df + + # Define schema for the expected DataFrame + result_schema = StructType( + [ + StructField("h0", DoubleType(), True), + StructField("h1", DoubleType(), True), + StructField("h2", DoubleType(), True), + StructField("h3", DoubleType(), True), + StructField("h4", DoubleType(), True), + ] ) + if not expected_data: + expected_coloc_df = spark.createDataFrame([], schema=result_schema) + else: + expected_coloc_df = spark.createDataFrame(expected_data, schema=result_schema) + + if observed_coloc_df.rdd.isEmpty(): + observed_coloc_df = spark.createDataFrame([], schema=result_schema) + + observed_coloc_df = observed_coloc_df.select("h0", "h1", "h2", "h3", "h4") + + observed_coloc_pdf = observed_coloc_df.toPandas() + expected_coloc_pdf = expected_coloc_df.toPandas() + + if expected_coloc_pdf.empty: + assert ( + observed_coloc_pdf.empty + ), f"Expected an empty DataFrame, but got:\n{observed_coloc_pdf}" + else: + assert_frame_equal( + observed_coloc_pdf, + expected_coloc_pdf, + check_exact=False, + check_dtype=True, + ) + def test_coloc_no_logbf( spark: SparkSession, @@ -151,8 +330,8 @@ def test_coloc_no_logbf( "right_logBF": None, "left_beta": 0.1, "right_beta": 0.2, - "left_posteriorProbability": None, - "right_posteriorProbability": None, + "left_posteriorProbability": 0.91, + "right_posteriorProbability": 0.92, }, # irrelevant for COLOC } ], @@ -212,8 +391,8 @@ def test_coloc_no_betas(spark: SparkSession) -> None: "right_logBF": 10.3, "left_beta": None, "right_beta": None, - "left_posteriorProbability": None, - "right_posteriorProbability": None, + "left_posteriorProbability": 0.91, + "right_posteriorProbability": 0.92, }, # irrelevant for COLOC } ],