Skip to content

Commit 15c500a

Browse files
committed
Properly treat blank ancestral allele, and set "N" as the default "unknown" state
Also document the class
1 parent 1d04fb8 commit 15c500a

File tree

6 files changed

+243
-79
lines changed

6 files changed

+243
-79
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# Changelog
22

3+
## [0.4.0a3] - ****-**-**
4+
5+
**Fixes**
6+
7+
- Properly account for "N" as an unknown ancestral state, and ban "" from being
8+
set as an ancestral state ({pr}`963`, {user}`hyanwong`)
9+
310
## [0.4.0a2] - 2024-09-06
411

512
2nd Alpha release of tsinfer 0.4.0

docs/usage.md

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ for sample in range(ds['call_genotype'].shape[1]):
6060

6161
We wish to infer a genealogy that could have given rise to this data set. To run _tsinfer_
6262
we wrap the .vcz file in a `tsinfer.VariantData` object. This requires an
63-
*ancestral allele* to be specified for each site; there are
63+
*ancestral state* to be specified for each site; there are
6464
many methods for calculating these: details are outside the scope of this manual, but we
6565
have started a [discussion topic](https://github.com/tskit-dev/tsinfer/discussions/523)
6666
on this issue to provide some recommendations.
@@ -76,11 +76,11 @@ and not used for inference (with a warning given).
7676
import tsinfer
7777
7878
# For this example take the REF allele (index 0) as ancestral
79-
ancestral_allele = ds['variant_allele'][:,0].astype(str)
79+
ancestral_state = ds['variant_allele'][:,0].astype(str)
8080
# This is just a numpy array, set the last site to an unknown value, for demo purposes
81-
ancestral_allele[-1] = "."
81+
ancestral_state[-1] = "."
8282
83-
vdata = tsinfer.VariantData("_static/example_data.vcz", ancestral_allele)
83+
vdata = tsinfer.VariantData("_static/example_data.vcz", ancestral_state)
8484
```
8585

8686
The `VariantData` object is a lightweight wrapper around the .vcz file.
@@ -127,7 +127,7 @@ site_mask[ds.variant_position[:] >= 6] = True
127127
128128
smaller_vdata = tsinfer.VariantData(
129129
"_static/example_data.vcz",
130-
ancestral_allele=ancestral_allele[site_mask == False],
130+
ancestral_state=ancestral_state[site_mask == False],
131131
site_mask=site_mask,
132132
)
133133
print(f"The `smaller_vdata` object returns data for only {smaller_vdata.num_sites} sites")
@@ -351,8 +351,8 @@ Once we have our `.vcz` file created, running the inference is straightforward.
351351

352352
```{code-cell} ipython3
353353
# Infer & save a ts from the notebook simulation.
354-
ancestral_alleles = np.load(f"{name}-AA.npy")
355-
vdata = tsinfer.VariantData(f"{name}.vcz", ancestral_alleles)
354+
ancestral_states = np.load(f"{name}-AA.npy")
355+
vdata = tsinfer.VariantData(f"{name}.vcz", ancestral_states)
356356
tsinfer.infer(vdata, progress_monitor=True, num_threads=4).dump(name + ".trees")
357357
```
358358

@@ -477,12 +477,12 @@ vcf_location = "_static/P_dom_chr24_phased.vcf.gz"
477477
```
478478

479479
This creates the `sparrows.vcz` datastore, which we open using `tsinfer.VariantData`.
480-
The original VCF had ancestral alleles specified in the `AA` INFO field, so we can
481-
simply provide the string `"variant_AA"` as the ancestral_allele parameter.
480+
The original VCF had the ancestral allelic state specified in the `AA` INFO field,
481+
so we can simply provide the string `"variant_AA"` as the `ancestral_state` parameter.
482482

483483
```{code-cell} ipython3
484-
# Do the inference: this VCF has ancestral alleles in the AA field
485-
vdata = tsinfer.VariantData("sparrows.vcz", ancestral_allele="variant_AA")
484+
# Do the inference: this VCF has ancestral states in the AA field
485+
vdata = tsinfer.VariantData("sparrows.vcz", ancestral_state="variant_AA")
486486
ts = tsinfer.infer(vdata)
487487
print(
488488
"Inferred tree sequence: {} trees over {} Mb ({} edges)".format(
@@ -534,7 +534,7 @@ Now when we carry out the inference, we get a tree sequence in which the nodes a
534534
correctly assigned to named populations
535535

536536
```{code-cell} ipython3
537-
vdata = tsinfer.VariantData("sparrows.vcz", ancestral_allele="variant_AA")
537+
vdata = tsinfer.VariantData("sparrows.vcz", ancestral_state="variant_AA")
538538
sparrow_ts = tsinfer.infer(vdata)
539539
540540
for sample_node_id in sparrow_ts.samples():

tests/test_inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1532,7 +1532,7 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
15321532
mat_wd = tsinfer.match_samples_batch_init(
15331533
work_dir=tmpdir / "working_mat",
15341534
sample_data_path=mat_sd.path,
1535-
ancestral_allele="variant_ancestral_allele",
1535+
ancestral_state="variant_ancestral_allele",
15361536
ancestor_ts_path=tmpdir / "mat_anc.trees",
15371537
min_work_per_job=1,
15381538
max_num_partitions=10,
@@ -1547,7 +1547,7 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
15471547
mask_wd = tsinfer.match_samples_batch_init(
15481548
work_dir=tmpdir / "working_mask",
15491549
sample_data_path=mask_sd.path,
1550-
ancestral_allele="variant_ancestral_allele",
1550+
ancestral_state="variant_ancestral_allele",
15511551
ancestor_ts_path=tmpdir / "mask_anc.trees",
15521552
min_work_per_job=1,
15531553
max_num_partitions=10,

tests/test_variantdata.py

Lines changed: 100 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
Tests for the data files.
2121
"""
2222
import json
23+
import logging
2324
import sys
2425
import tempfile
26+
import warnings
2527

2628
import msprime
2729
import numcodecs
@@ -627,14 +629,12 @@ def test_missing_ancestral_allele(tmp_path):
627629

628630

629631
@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows")
630-
def test_ancestral_missingness(tmp_path):
632+
def test_deliberate_ancestral_missingness(tmp_path):
631633
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
632634
ds = sgkit.load_dataset(zarr_path)
633635
ancestral_allele = ds.variant_ancestral_allele.values
634636
ancestral_allele[0] = "N"
635-
ancestral_allele[11] = "-"
636-
ancestral_allele[12] = "💩"
637-
ancestral_allele[15] = "💩"
637+
ancestral_allele[1] = "n"
638638
ds = ds.drop_vars(["variant_ancestral_allele"])
639639
sgkit.save_dataset(ds, str(zarr_path) + ".tmp")
640640
tsutil.add_array_to_dataset(
@@ -644,15 +644,56 @@ def test_ancestral_missingness(tmp_path):
644644
["variants"],
645645
)
646646
ds = sgkit.load_dataset(str(zarr_path) + ".tmp")
647+
with warnings.catch_warnings():
648+
warnings.simplefilter("error") # No warning raised if AA deliberately missing
649+
sd = tsinfer.VariantData(str(zarr_path) + ".tmp", "variant_ancestral_allele")
650+
inf_ts = tsinfer.infer(sd)
651+
for i, (inf_var, var) in enumerate(zip(inf_ts.variants(), ts.variants())):
652+
if i in [0, 1]:
653+
assert inf_var.site.metadata == {"inference_type": "parsimony"}
654+
else:
655+
assert inf_var.site.ancestral_state == var.site.ancestral_state
656+
657+
658+
@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows")
659+
def test_ancestral_missing_warning(tmp_path):
660+
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
661+
ds = sgkit.load_dataset(zarr_path)
662+
anc_state = ds.variant_ancestral_allele.values
663+
anc_state[0] = "N"
664+
anc_state[11] = "-"
665+
anc_state[12] = "💩"
666+
anc_state[15] = "💩"
647667
with pytest.warns(
648668
UserWarning,
649669
match=r"not found in the variant_allele array for the 4 [\s\S]*'💩': 2",
650670
):
651-
sd = tsinfer.VariantData(str(zarr_path) + ".tmp", "variant_ancestral_allele")
652-
inf_ts = tsinfer.infer(sd)
671+
vdata = tsinfer.VariantData(zarr_path, anc_state)
672+
inf_ts = tsinfer.infer(vdata)
653673
for i, (inf_var, var) in enumerate(zip(inf_ts.variants(), ts.variants())):
654674
if i in [0, 11, 12, 15]:
655675
assert inf_var.site.metadata == {"inference_type": "parsimony"}
676+
assert inf_var.site.ancestral_state in var.site.alleles
677+
else:
678+
assert inf_var.site.ancestral_state == var.site.ancestral_state
679+
680+
681+
def test_ancestral_missing_info(tmp_path, caplog):
682+
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
683+
ds = sgkit.load_dataset(zarr_path)
684+
anc_state = ds.variant_ancestral_allele.values
685+
anc_state[0] = "N"
686+
anc_state[11] = "N"
687+
anc_state[12] = "n"
688+
anc_state[15] = "n"
689+
with caplog.at_level(logging.INFO):
690+
vdata = tsinfer.VariantData(zarr_path, anc_state)
691+
assert f"4 sites ({4/ts.num_sites * 100 :.2f}%) were deliberately " in caplog.text
692+
inf_ts = tsinfer.infer(vdata)
693+
for i, (inf_var, var) in enumerate(zip(inf_ts.variants(), ts.variants())):
694+
if i in [0, 11, 12, 15]:
695+
assert inf_var.site.metadata == {"inference_type": "parsimony"}
696+
assert inf_var.site.ancestral_state in var.site.alleles
656697
else:
657698
assert inf_var.site.ancestral_state == var.site.ancestral_state
658699

@@ -670,6 +711,25 @@ def test_sgkit_ancestor(small_sd_fixture, tmp_path):
670711

671712

672713
class TestVariantDataErrors:
714+
@staticmethod
715+
def simulate_genotype_call_dataset(*args, **kwargs):
716+
# roll our own simulate_genotype_call_dataset to hack around bug in sgkit where
717+
# duplicate alleles are created. Doesn't need to be efficient: just for testing
718+
if "seed" not in kwargs:
719+
kwargs["seed"] = 123
720+
ds = sgkit.simulate_genotype_call_dataset(*args, **kwargs)
721+
variant_alleles = ds["variant_allele"].values
722+
allowed_alleles = np.array(
723+
["A", "T", "C", "G", "N"], dtype=variant_alleles.dtype
724+
)
725+
for row in range(len(variant_alleles)):
726+
alleles = variant_alleles[row]
727+
if len(set(alleles)) != len(alleles):
728+
# Just use a set that we know is unique
729+
variant_alleles[row] = allowed_alleles[0 : len(alleles)]
730+
ds["variant_allele"] = ds["variant_allele"].dims, variant_alleles
731+
return ds
732+
673733
def test_bad_zarr_spec(self):
674734
ds = zarr.group()
675735
ds["call_genotype"] = zarr.array(np.zeros(10, dtype=np.int8))
@@ -680,7 +740,7 @@ def test_bad_zarr_spec(self):
680740

681741
def test_missing_phase(self, tmp_path):
682742
path = tmp_path / "data.zarr"
683-
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
743+
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
684744
sgkit.save_dataset(ds, path)
685745
with pytest.raises(
686746
ValueError, match="The call_genotype_phased array is missing"
@@ -689,7 +749,7 @@ def test_missing_phase(self, tmp_path):
689749

690750
def test_phased(self, tmp_path):
691751
path = tmp_path / "data.zarr"
692-
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
752+
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
693753
ds["call_genotype_phased"] = (
694754
ds["call_genotype"].dims,
695755
np.ones(ds["call_genotype"].shape, dtype=bool),
@@ -700,13 +760,13 @@ def test_phased(self, tmp_path):
700760
def test_ploidy1_missing_phase(self, tmp_path):
701761
path = tmp_path / "data.zarr"
702762
# Ploidy==1 is always ok
703-
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
763+
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
704764
sgkit.save_dataset(ds, path)
705765
tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str))
706766

707767
def test_ploidy1_unphased(self, tmp_path):
708768
path = tmp_path / "data.zarr"
709-
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
769+
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
710770
ds["call_genotype_phased"] = (
711771
ds["call_genotype"].dims,
712772
np.zeros(ds["call_genotype"].shape, dtype=bool),
@@ -716,31 +776,54 @@ def test_ploidy1_unphased(self, tmp_path):
716776

717777
def test_duplicate_positions(self, tmp_path):
718778
path = tmp_path / "data.zarr"
719-
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
779+
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
720780
ds["variant_position"][2] = ds["variant_position"][1]
721781
sgkit.save_dataset(ds, path)
722782
with pytest.raises(ValueError, match="duplicate or out-of-order values"):
723783
tsinfer.VariantData(path, "variant_ancestral_allele")
724784

725785
def test_bad_order_positions(self, tmp_path):
726786
path = tmp_path / "data.zarr"
727-
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
787+
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
728788
ds["variant_position"][0] = ds["variant_position"][2] - 0.5
729789
sgkit.save_dataset(ds, path)
730790
with pytest.raises(ValueError, match="duplicate or out-of-order values"):
731791
tsinfer.VariantData(path, "variant_ancestral_allele")
732792

793+
def test_bad_ancestral_state(self, tmp_path):
794+
path = tmp_path / "data.zarr"
795+
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
796+
ancestral_state = ds["variant_allele"][:, 0].values.astype(str)
797+
ancestral_state[1] = ""
798+
sgkit.save_dataset(ds, path)
799+
with pytest.raises(ValueError, match="cannot contain empty strings"):
800+
tsinfer.VariantData(path, ancestral_state)
801+
733802
def test_empty_alleles_not_at_end(self, tmp_path):
734803
path = tmp_path / "data.zarr"
735-
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
804+
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
805+
ds["variant_allele"] = (
806+
ds["variant_allele"].dims,
807+
np.array([["A", "", "C"], ["A", "C", ""], ["A", "C", ""]], dtype="S1"),
808+
)
809+
sgkit.save_dataset(ds, path)
810+
with pytest.raises(
811+
ValueError, match='Bad alleles: fill value "" in middle of list'
812+
):
813+
tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str))
814+
815+
def test_unique_alleles(self, tmp_path):
816+
path = tmp_path / "data.zarr"
817+
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
736818
ds["variant_allele"] = (
737819
ds["variant_allele"].dims,
738-
np.array([["", "A", "C"], ["A", "C", ""], ["A", "C", ""]], dtype="S1"),
820+
np.array([["A", "C", "T"], ["A", "C", ""], ["A", "A", ""]], dtype="S1"),
739821
)
740822
sgkit.save_dataset(ds, path)
741-
vdata = tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str))
742-
with pytest.raises(ValueError, match="Empty alleles must be at the end"):
743-
tsinfer.infer(vdata)
823+
with pytest.raises(
824+
ValueError, match="Duplicate allele values provided at site 2"
825+
):
826+
tsinfer.VariantData(path, np.array(["A", "A", "A"], dtype="S1"))
744827

745828
def test_unimplemented_from_tree_sequence(self):
746829
# NB we should reimplement something like this functionality.

0 commit comments

Comments
 (0)