20
20
Tests for the data files.
21
21
"""
22
22
import json
23
+ import logging
23
24
import sys
24
25
import tempfile
26
+ import warnings
25
27
26
28
import msprime
27
29
import numcodecs
@@ -627,14 +629,12 @@ def test_missing_ancestral_allele(tmp_path):
627
629
628
630
629
631
@pytest .mark .skipif (sys .platform == "win32" , reason = "No cyvcf2 on Windows" )
630
- def test_ancestral_missingness (tmp_path ):
632
+ def test_deliberate_ancestral_missingness (tmp_path ):
631
633
ts , zarr_path = tsutil .make_ts_and_zarr (tmp_path )
632
634
ds = sgkit .load_dataset (zarr_path )
633
635
ancestral_allele = ds .variant_ancestral_allele .values
634
636
ancestral_allele [0 ] = "N"
635
- ancestral_allele [11 ] = "-"
636
- ancestral_allele [12 ] = "💩"
637
- ancestral_allele [15 ] = "💩"
637
+ ancestral_allele [1 ] = "n"
638
638
ds = ds .drop_vars (["variant_ancestral_allele" ])
639
639
sgkit .save_dataset (ds , str (zarr_path ) + ".tmp" )
640
640
tsutil .add_array_to_dataset (
@@ -644,15 +644,56 @@ def test_ancestral_missingness(tmp_path):
644
644
["variants" ],
645
645
)
646
646
ds = sgkit .load_dataset (str (zarr_path ) + ".tmp" )
647
+ with warnings .catch_warnings ():
648
+ warnings .simplefilter ("error" ) # No warning raised if AA deliberately missing
649
+ sd = tsinfer .VariantData (str (zarr_path ) + ".tmp" , "variant_ancestral_allele" )
650
+ inf_ts = tsinfer .infer (sd )
651
+ for i , (inf_var , var ) in enumerate (zip (inf_ts .variants (), ts .variants ())):
652
+ if i in [0 , 1 ]:
653
+ assert inf_var .site .metadata == {"inference_type" : "parsimony" }
654
+ else :
655
+ assert inf_var .site .ancestral_state == var .site .ancestral_state
656
+
657
+
658
# Skipped on Windows; the skip reason given is that cyvcf2 is unavailable there.
+ @pytest .mark .skipif (sys .platform == "win32" , reason = "No cyvcf2 on Windows" )
659
+ def test_ancestral_missing_warning (tmp_path ):
# Overwrites the ancestral allele at sites 0, 11, 12 and 15 with values
# ("N", "-", "💩") and passes the modified array directly to VariantData.
# NOTE(review): lines prefixed with "-" below are the previous revision shown
# by the diff; lines prefixed with "+" are the current one.
660
+ ts , zarr_path = tsutil .make_ts_and_zarr (tmp_path )
661
+ ds = sgkit .load_dataset (zarr_path )
662
+ anc_state = ds .variant_ancestral_allele .values
663
+ anc_state [0 ] = "N"
664
+ anc_state [11 ] = "-"
665
+ anc_state [12 ] = "💩"
666
+ anc_state [15 ] = "💩"
# A UserWarning matching this regex must be emitted; the pattern requires the
# text "for the 4 " and, later in the message, a count of 2 for '💩'.
647
667
with pytest .warns (
648
668
UserWarning ,
649
669
match = r"not found in the variant_allele array for the 4 [\s\S]*'💩': 2" ,
650
670
):
651
- sd = tsinfer .VariantData (str ( zarr_path ) + ".tmp" , "variant_ancestral_allele" )
652
- inf_ts = tsinfer .infer (sd )
671
+ vdata = tsinfer .VariantData (zarr_path , anc_state )
672
+ inf_ts = tsinfer .infer (vdata )
# Compare each inferred site with the corresponding site of the original ts.
653
673
for i , (inf_var , var ) in enumerate (zip (inf_ts .variants (), ts .variants ())):
# The four overwritten sites must carry the "parsimony" inference_type metadata
# and an ancestral state drawn from the original site's alleles.
654
674
if i in [0 , 11 , 12 , 15 ]:
655
675
assert inf_var .site .metadata == {"inference_type" : "parsimony" }
676
+ assert inf_var .site .ancestral_state in var .site .alleles
# Every other site must keep the ancestral state of the original tree sequence.
677
+ else :
678
+ assert inf_var .site .ancestral_state == var .site .ancestral_state
679
+
680
+
681
+ def test_ancestral_missing_info (tmp_path , caplog ):
# Sets the ancestral allele at sites 0, 11, 12 and 15 to "N" or "n" and checks
# that an INFO-level log message reporting 4 such sites is captured.
682
+ ts , zarr_path = tsutil .make_ts_and_zarr (tmp_path )
683
+ ds = sgkit .load_dataset (zarr_path )
684
+ anc_state = ds .variant_ancestral_allele .values
685
+ anc_state [0 ] = "N"
686
+ anc_state [11 ] = "N"
687
+ anc_state [12 ] = "n"
688
+ anc_state [15 ] = "n"
# Capture log records at INFO level while the VariantData object is built.
689
+ with caplog .at_level (logging .INFO ):
690
+ vdata = tsinfer .VariantData (zarr_path , anc_state )
# The expected text embeds 4 / num_sites as a percentage with two decimals.
# NOTE(review): the space between the format field and "%" may be an artifact
# of how this diff was rendered; confirm against the logged message.
691
+ assert f"4 sites ({ 4 / ts .num_sites * 100 :.2f} %) were deliberately " in caplog .text
692
+ inf_ts = tsinfer .infer (vdata )
# Compare each inferred site with the corresponding site of the original ts.
693
+ for i , (inf_var , var ) in enumerate (zip (inf_ts .variants (), ts .variants ())):
# The four overwritten sites must carry the "parsimony" inference_type metadata
# and an ancestral state drawn from the original site's alleles.
694
+ if i in [0 , 11 , 12 , 15 ]:
695
+ assert inf_var .site .metadata == {"inference_type" : "parsimony" }
696
+ assert inf_var .site .ancestral_state in var .site .alleles
# Every other site must keep the ancestral state of the original tree sequence.
656
697
else :
657
698
assert inf_var .site .ancestral_state == var .site .ancestral_state
658
699
@@ -670,6 +711,25 @@ def test_sgkit_ancestor(small_sd_fixture, tmp_path):
670
711
671
712
672
713
class TestVariantDataErrors :
714
@staticmethod
def simulate_genotype_call_dataset(*args, **kwargs):
    """
    Test-only stand-in for ``sgkit.simulate_genotype_call_dataset``.

    Forwards all arguments to sgkit (defaulting ``seed`` to 123 when the
    caller gives none), then overwrites every row of ``variant_allele`` that
    contains repeated values with the leading entries of A, T, C, G, N.
    This hacks around a bug in sgkit where duplicate alleles are created.
    It does not need to be efficient: it is only used for testing.
    """
    kwargs.setdefault("seed", 123)
    dataset = sgkit.simulate_genotype_call_dataset(*args, **kwargs)
    allele_table = dataset["variant_allele"].values
    unique_pool = np.array(["A", "T", "C", "G", "N"], dtype=allele_table.dtype)
    for site_index, site_alleles in enumerate(allele_table):
        num_alleles = len(site_alleles)
        if len(set(site_alleles)) < num_alleles:
            # Swap in a prefix of the pool, which is known to be unique
            allele_table[site_index] = unique_pool[:num_alleles]
    dataset["variant_allele"] = (dataset["variant_allele"].dims, allele_table)
    return dataset
732
+
673
733
def test_bad_zarr_spec (self ):
674
734
ds = zarr .group ()
675
735
ds ["call_genotype" ] = zarr .array (np .zeros (10 , dtype = np .int8 ))
@@ -680,7 +740,7 @@ def test_bad_zarr_spec(self):
680
740
681
741
def test_missing_phase (self , tmp_path ):
682
742
path = tmp_path / "data.zarr"
683
- ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 )
743
+ ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 )
684
744
sgkit .save_dataset (ds , path )
685
745
with pytest .raises (
686
746
ValueError , match = "The call_genotype_phased array is missing"
@@ -689,7 +749,7 @@ def test_missing_phase(self, tmp_path):
689
749
690
750
def test_phased (self , tmp_path ):
691
751
path = tmp_path / "data.zarr"
692
- ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 )
752
+ ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 )
693
753
ds ["call_genotype_phased" ] = (
694
754
ds ["call_genotype" ].dims ,
695
755
np .ones (ds ["call_genotype" ].shape , dtype = bool ),
@@ -700,13 +760,13 @@ def test_phased(self, tmp_path):
700
760
def test_ploidy1_missing_phase(self, tmp_path):
    """
    Build VariantData from a ploidy-1 dataset to which this test adds no
    phasing array; the test passes if construction does not raise.
    """
    # Ploidy==1 is always ok
    haploid_ds = self.simulate_genotype_call_dataset(
        n_variant=3, n_sample=3, n_ploidy=1
    )
    zarr_store = tmp_path / "data.zarr"
    sgkit.save_dataset(haploid_ds, zarr_store)
    # Use the first listed allele at each site as the ancestral state
    first_alleles = haploid_ds["variant_allele"][:, 0].values.astype(str)
    tsinfer.VariantData(zarr_store, first_alleles)
706
766
707
767
def test_ploidy1_unphased (self , tmp_path ):
708
768
path = tmp_path / "data.zarr"
709
- ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
769
+ ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
710
770
ds ["call_genotype_phased" ] = (
711
771
ds ["call_genotype" ].dims ,
712
772
np .zeros (ds ["call_genotype" ].shape , dtype = bool ),
@@ -716,31 +776,54 @@ def test_ploidy1_unphased(self, tmp_path):
716
776
717
777
def test_duplicate_positions(self, tmp_path):
    """
    Copy the position of variant 1 onto variant 2 and check that building
    VariantData raises a ValueError about duplicate or out-of-order values.
    """
    zarr_store = tmp_path / "data.zarr"
    sim_ds = self.simulate_genotype_call_dataset(
        n_variant=3, n_sample=3, phased=True
    )
    repeated_position = sim_ds["variant_position"][1]
    sim_ds["variant_position"][2] = repeated_position
    sgkit.save_dataset(sim_ds, zarr_store)
    with pytest.raises(ValueError, match="duplicate or out-of-order values"):
        tsinfer.VariantData(zarr_store, "variant_ancestral_allele")
724
784
725
785
def test_bad_order_positions(self, tmp_path):
    """
    Set the first variant's position from the last position minus 0.5 and
    check that building VariantData raises a ValueError about duplicate or
    out-of-order values.
    """
    zarr_store = tmp_path / "data.zarr"
    sim_ds = self.simulate_genotype_call_dataset(
        n_variant=3, n_sample=3, phased=True
    )
    misplaced_position = sim_ds["variant_position"][2] - 0.5
    sim_ds["variant_position"][0] = misplaced_position
    sgkit.save_dataset(sim_ds, zarr_store)
    with pytest.raises(ValueError, match="duplicate or out-of-order values"):
        tsinfer.VariantData(zarr_store, "variant_ancestral_allele")
732
792
793
def test_bad_ancestral_state(self, tmp_path):
    """
    Pass an ancestral state array whose entry for site 1 is the empty string
    and check that VariantData raises a ValueError mentioning empty strings.
    """
    zarr_store = tmp_path / "data.zarr"
    sim_ds = self.simulate_genotype_call_dataset(
        n_variant=3, n_sample=3, phased=True
    )
    sgkit.save_dataset(sim_ds, zarr_store)
    # astype() returns a fresh array, so blanking an entry leaves sim_ds alone
    states = sim_ds["variant_allele"][:, 0].values.astype(str)
    states[1] = ""
    with pytest.raises(ValueError, match="cannot contain empty strings"):
        tsinfer.VariantData(zarr_store, states)
801
+
733
802
def test_empty_alleles_not_at_end(self, tmp_path):
    """
    Store an allele table whose first site has the fill value "" between two
    real alleles and check that VariantData raises the "Bad alleles" error.
    """
    zarr_store = tmp_path / "data.zarr"
    sim_ds = self.simulate_genotype_call_dataset(
        n_variant=3, n_sample=3, n_ploidy=1
    )
    # Only the first row is malformed: "" sits in the middle of the list
    allele_rows = np.array(
        [["A", "", "C"], ["A", "C", ""], ["A", "C", ""]], dtype="S1"
    )
    sim_ds["variant_allele"] = (sim_ds["variant_allele"].dims, allele_rows)
    sgkit.save_dataset(sim_ds, zarr_store)
    expected_message = 'Bad alleles: fill value "" in middle of list'
    with pytest.raises(ValueError, match=expected_message):
        first_alleles = sim_ds["variant_allele"][:, 0].values.astype(str)
        tsinfer.VariantData(zarr_store, first_alleles)
814
+
815
def test_unique_alleles(self, tmp_path):
    """
    Store an allele table in which site 2 lists "A" twice and check that
    VariantData raises a ValueError naming that site.
    """
    path = tmp_path / "data.zarr"
    ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
    # Each literal is exactly one character, matching the "S1" dtype, so
    # nothing depends on numpy silently truncating longer strings.
    # The last row (site 2) is the one holding the duplicate allele.
    alleles = np.array(
        [["A", "C", "T"], ["A", "C", ""], ["A", "A", ""]], dtype="S1"
    )
    ds["variant_allele"] = (ds["variant_allele"].dims, alleles)
    sgkit.save_dataset(ds, path)
    with pytest.raises(
        ValueError, match="Duplicate allele values provided at site 2"
    ):
        tsinfer.VariantData(path, np.array(["A", "A", "A"], dtype="S1"))
744
827
745
828
def test_unimplemented_from_tree_sequence (self ):
746
829
# NB we should reimplement something like this functionality.
0 commit comments