Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Regions index #37

Merged
merged 10 commits into from
Jul 29, 2024
4 changes: 4 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import pytest

# rewrite asserts in assert_vcfs_close to give better failure messages
pytest.register_assert_rewrite("tests.utils")
8 changes: 5 additions & 3 deletions tests/test_regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from vcztools.regions import parse_region
from vcztools.regions import parse_region_string


@pytest.mark.parametrize(
Expand All @@ -14,5 +14,7 @@
("chr1:12-103", ("chr1", 12, 103)),
],
)
def test_parse_region(targets: str, expected: tuple[str, Optional[int], Optional[int]]):
assert parse_region(targets) == expected
def test_parse_region_string(
targets: str, expected: tuple[str, Optional[int], Optional[int]]
):
assert parse_region_string(targets) == expected
59 changes: 33 additions & 26 deletions tests/test_vcf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from cyvcf2 import VCF
from numpy.testing import assert_array_equal

from vcztools.regions import create_index
from vcztools.vcf_writer import write_vcf

from .utils import assert_vcfs_close
Expand Down Expand Up @@ -56,32 +57,39 @@ def test_write_vcf(shared_datadir, tmp_path, output_is_path):

# fmt: off
@pytest.mark.parametrize(
("switch", "regions", "expected_chrom_pos"),
("regions", "targets", "expected_chrom_pos"),
[
("-t", "19", [("19", 111), ("19", 112)]),
("-t", "19:112", [("19", 112)]),
("-t", "20:1230236-", [("20", 1230237), ("20", 1234567), ("20", 1235237)]),
("-t", "20:1230237-", [("20", 1230237), ("20", 1234567), ("20", 1235237)]),
("-t", "20:1230238-", [("20", 1234567), ("20", 1235237)]),
("-t", "20:1230237-1235236", [("20", 1230237), ("20", 1234567)]),
("-t", "20:1230237-1235237", [("20", 1230237), ("20", 1234567), ("20", 1235237)]), # noqa: E501
("-t", "20:1230237-1235238", [("20", 1230237), ("20", 1234567), ("20", 1235237)]), # noqa: E501
("-t", "19,X", [("19", 111), ("19", 112), ("X", 10)]),
("-t", "X:11", []),
("-r", "19", [("19", 111), ("19", 112)]),
("-r", "19:112", [("19", 112)]),
("-r", "20:1230236-", [("20", 1230237), ("20", 1234567), ("20", 1235237)]),
("-r", "20:1230237-", [("20", 1230237), ("20", 1234567), ("20", 1235237)]),
("-r", "20:1230238-", [("20", 1234567), ("20", 1235237)]),
("-r", "20:1230237-1235236", [("20", 1230237), ("20", 1234567)]),
("-r", "20:1230237-1235237", [("20", 1230237), ("20", 1234567), ("20", 1235237)]), # noqa: E501
("-r", "20:1230237-1235238", [("20", 1230237), ("20", 1234567), ("20", 1235237)]), # noqa: E501
("-r", "19,X", [("19", 111), ("19", 112), ("X", 10)]),
("-r", "X:11", [("X", 10)]), # note differs from -t
# regions only
("19", None, [("19", 111), ("19", 112)]),
("19:112", None, [("19", 112)]),
("20:1230236-", None, [("20", 1230237), ("20", 1234567), ("20", 1235237)]),
("20:1230237-", None, [("20", 1230237), ("20", 1234567), ("20", 1235237)]),
("20:1230238-", None, [("20", 1234567), ("20", 1235237)]),
("20:1230237-1235236", None, [("20", 1230237), ("20", 1234567)]),
("20:1230237-1235237", None, [("20", 1230237), ("20", 1234567), ("20", 1235237)]), # noqa: E501
("20:1230237-1235238", None, [("20", 1230237), ("20", 1234567), ("20", 1235237)]), # noqa: E501
("19,X", None, [("19", 111), ("19", 112), ("X", 10)]),
("X:11", None, [("X", 10)]), # note differs from targets

# targets only
(None, "19", [("19", 111), ("19", 112)]),
(None, "19:112", [("19", 112)]),
(None, "20:1230236-", [("20", 1230237), ("20", 1234567), ("20", 1235237)]),
(None, "20:1230237-", [("20", 1230237), ("20", 1234567), ("20", 1235237)]),
(None, "20:1230238-", [("20", 1234567), ("20", 1235237)]),
(None, "20:1230237-1235236", [("20", 1230237), ("20", 1234567)]),
(None, "20:1230237-1235237", [("20", 1230237), ("20", 1234567), ("20", 1235237)]), # noqa: E501
(None, "20:1230237-1235238", [("20", 1230237), ("20", 1234567), ("20", 1235237)]), # noqa: E501
(None, "19,X", [("19", 111), ("19", 112), ("X", 10)]),
(None, "X:11", []),
(None, "^19,20:1-1234567", [("20", 1235237), ("X", 10)]), # complement

# regions and targets
("20", "^20:1110696-", [("20", 14370), ("20", 17330)])
]
)
# fmt: on
def test_write_vcf__regions(shared_datadir, tmp_path, switch, regions,
def test_write_vcf__regions(shared_datadir, tmp_path, regions, targets,
expected_chrom_pos):
path = shared_datadir / "vcf" / "sample.vcf.gz"
intermediate_icf = tmp_path.joinpath("intermediate.icf")
Expand All @@ -91,11 +99,10 @@ def test_write_vcf__regions(shared_datadir, tmp_path, switch, regions,
vcf2zarr.convert(
[path], intermediate_vcz, icf_path=intermediate_icf, worker_processes=0
)
create_index(intermediate_vcz)

if switch == "-t":
write_vcf(intermediate_vcz, output, variant_targets=regions)
elif switch == "-r":
write_vcf(intermediate_vcz, output, variant_regions=regions)
write_vcf(intermediate_vcz, output, variant_regions=regions,
variant_targets=targets)

v = VCF(output)
variants = list(v)
Expand Down
9 changes: 8 additions & 1 deletion vcztools/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import click

from . import vcf_writer
from . import regions, vcf_writer


class NaturalOrderGroup(click.Group):
Expand All @@ -14,6 +14,12 @@ def list_commands(self, ctx):
return self.commands.keys()


@click.command
@click.argument("path", type=click.Path())
def index(path):
    """Create the region index for the VCF Zarr dataset at PATH.

    The index is built by ``regions.create_index`` and supports efficient
    region queries.
    """
    # Thin CLI wrapper: all of the work happens in the regions module.
    regions.create_index(path)


@click.command
@click.argument("path", type=click.Path())
@click.option(
Expand Down Expand Up @@ -41,4 +47,5 @@ def vcztools_main():
pass


vcztools_main.add_command(index)
vcztools_main.add_command(view)
1 change: 1 addition & 0 deletions vcztools/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
RESERVED_VARIABLE_NAMES = [
"variant_contig",
"variant_position",
"variant_position_end",
"variant_id",
"variant_id_mask",
"variant_allele",
Expand Down
214 changes: 184 additions & 30 deletions vcztools/regions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,58 @@
import re
from typing import Any, Optional

import numcodecs
import numpy as np
import pandas as pd
import pyranges
import zarr
from pyranges import PyRanges


def parse_region(region: str) -> tuple[str, Optional[int], Optional[int]]:
def create_index(vcz) -> None:
    """Create an index to support efficient region queries.

    Opens the Zarr group at ``vcz`` in read/write mode, scans the
    ``variant_contig``, ``variant_position`` and ``variant_position_end``
    arrays one variant chunk at a time, and stores the result in a
    ``region_index`` array in the same group (overwriting any existing one).

    The index has one row per run of consecutive records that share a contig
    within a chunk, with six int32 columns:

    0. chunk index
    1. contig ID
    2. position of the first record in the run
    3. position of the last record in the run
    4. maximum ``variant_position_end`` over the run
    5. number of records in the run

    Raises:
        ValueError: if the three input arrays do not share the same chunk grid.
    """

    root = zarr.open(vcz, mode="r+")

    contig = root["variant_contig"]
    pos = root["variant_position"]
    # NOTE(review): a reviewer suggested that storing ``variant_length`` might
    # be an easier name to understand, with the end computed from it; that
    # discussion is out of scope here.
    end = root["variant_position_end"]

    # The three arrays are indexed block-by-block in lockstep below, so their
    # chunk grids must agree. Raise rather than assert so that the check is
    # not stripped when Python runs with -O.
    if not (contig.cdata_shape == pos.cdata_shape == end.cdata_shape):
        raise ValueError(
            "variant_contig, variant_position and variant_position_end "
            "must have the same chunk grid"
        )

    index = []

    for v_chunk in range(pos.cdata_shape[0]):
        c = contig.blocks[v_chunk]
        p = pos.blocks[v_chunk]
        e = end.blocks[v_chunk]

        # create a row for each contig in the chunk
        # Appending -1 makes the final record differ from its (virtual)
        # successor, so the last run in the chunk is always closed.
        d = np.diff(c, append=-1)
        c_start_idx = 0
        for c_end_idx in np.nonzero(d)[0]:
            # internal sanity check: a run never spans two contigs
            assert c[c_start_idx] == c[c_end_idx]
            index.append(
                (
                    v_chunk,  # chunk index
                    c[c_start_idx],  # contig ID
                    p[c_start_idx],  # start
                    p[c_end_idx],  # end
                    np.max(e[c_start_idx : c_end_idx + 1]),  # max end
                    c_end_idx - c_start_idx + 1,  # num records
                )
            )
            c_start_idx = c_end_idx + 1

    # reshape keeps the table two-dimensional (0 rows x 6 columns) even when
    # there are no records, so column slicing of the index always works
    index = np.array(index, dtype=np.int32).reshape(-1, 6)
    root.array(
        "region_index",
        data=index,
        compressor=numcodecs.Blosc("zstd", clevel=9, shuffle=0),
        overwrite=True,
    )


def parse_region_string(region: str) -> tuple[str, Optional[int], Optional[int]]:
"""Return the contig, start position and end position from a region string."""
if re.search(r":\d+-\d*$", region):
contig, start_end = region.rsplit(":", 1)
Expand All @@ -20,32 +66,10 @@ def parse_region(region: str) -> tuple[str, Optional[int], Optional[int]]:
return contig, None, None


def parse_regions(targets: str) -> list[tuple[str, Optional[int], Optional[int]]]:
return [parse_region(region) for region in targets.split(",")]


def parse_targets(targets: str) -> list[tuple[str, Optional[int], Optional[int]]]:
return [parse_region(region) for region in targets.split(",")]


def regions_to_selection(
all_contigs: list[str],
variant_contig: Any,
variant_position: Any,
variant_length: Any,
regions: list[tuple[str, Optional[int], Optional[int]]],
):
# subtract 1 from start coordinate to convert intervals
# from VCF (1-based, fully-closed) to Python (0-based, half-open)
variant_start = variant_position - 1
variant_end = variant_start + variant_length
df = pd.DataFrame(
{"Chromosome": variant_contig, "Start": variant_start, "End": variant_end}
)

# save original index as column so we can retrieve it after finding overlap
df["index"] = df.index
variants = pyranges.PyRanges(df)
def regions_to_pyranges(
regions: list[tuple[str, Optional[int], Optional[int]]], all_contigs: list[str]
) -> PyRanges:
"""Convert region tuples to a PyRanges object."""

chromosomes = []
starts = []
Expand All @@ -63,9 +87,139 @@ def regions_to_selection(
starts.append(start)
ends.append(end)

query = pyranges.PyRanges(chromosomes=chromosomes, starts=starts, ends=ends)
return PyRanges(chromosomes=chromosomes, starts=starts, ends=ends)


def parse_regions(regions: Optional[str], all_contigs: list[str]) -> Optional[PyRanges]:
    """Build a PyRanges object from a comma-separated list of region strings.

    Each comma-separated item is parsed with ``parse_region_string`` and the
    resulting tuples are converted with ``regions_to_pyranges``. ``None`` is
    passed straight through, meaning no regions were requested.
    """
    if regions is not None:
        region_tuples = list(map(parse_region_string, regions.split(",")))
        return regions_to_pyranges(region_tuples, all_contigs)
    return None


def parse_targets(
    targets: Optional[str], all_contigs: list[str]
) -> tuple[Optional[PyRanges], bool]:
    """Parse a targets string into a PyRanges object and a complement flag.

    The string is a comma-separated set of region strings, optionally
    preceded by a ``^`` character to indicate complement. The ``^`` is
    stripped before the remainder is handed to ``parse_regions``.

    Returns ``(None, False)`` when ``targets`` is ``None``.
    """
    if targets is None:
        return None, False
    if targets.startswith("^"):
        # leading caret: the listed targets are to be excluded
        return parse_regions(targets[1:], all_contigs), True
    return parse_regions(targets, all_contigs), False


def regions_to_chunk_indexes(
    regions: Optional[PyRanges],
    targets: Optional[PyRanges],
    complement: bool,
    regions_index: Any,
):
    """Return the chunk indexes that overlap the given regions or targets.

    If both regions and targets are specified then only regions are used
    to find overlapping chunks (since targets are used later to refine).

    If only targets are specified then they are used to find overlapping chunks,
    taking into account the complement flag.

    ``regions_index`` is the two-dimensional table written by ``create_index``;
    columns 0-4 are read here as chunk index, contig ID, start position,
    end position and max end position.

    Returns a sorted 1-D array of unique chunk indexes, or an empty int64
    array when no chunk overlaps.
    """

    # Create pyranges for chunks using the region index.
    # For regions use max end position, for targets just end position
    chunk_index = regions_index[:, 0]
    contig_id = regions_index[:, 1]
    start_position = regions_index[:, 2]
    end_position = regions_index[:, 3]
    max_end_position = regions_index[:, 4]
    df = pd.DataFrame(
        {
            "chunk_index": chunk_index,
            "Chromosome": contig_id,
            # subtract 1 from start coordinate to convert intervals
            # from VCF (1-based, fully-closed) to Python (0-based, half-open);
            # without this the first record of each run falls outside the
            # chunk interval and a query hitting only that record is missed
            "Start": start_position - 1,
            "End": max_end_position if regions is not None else end_position,
        }
    )
    chunk_regions = PyRanges(df)

    if regions is not None:
        overlap = chunk_regions.overlap(regions)
    elif complement:
        # keep chunks that still have some extent left once the excluded
        # target intervals are removed
        overlap = chunk_regions.subtract(targets)
    else:
        overlap = chunk_regions.overlap(targets)
    if overlap.empty:
        return np.empty((0,), dtype=np.int64)
    chunk_indexes = overlap.df["chunk_index"].to_numpy()
    # a chunk holding several contigs has several rows; report it only once
    chunk_indexes = np.unique(chunk_indexes)
    return chunk_indexes


def regions_to_selection(
    regions: Optional[PyRanges],
    targets: Optional[PyRanges],
    complement: bool,
    variant_contig: Any,
    variant_position: Any,
    variant_end: Any,
):
    """Return a variant selection that corresponds to the given regions and targets.

    Regions select every variant whose interval
    ``[variant_position - 1, variant_end)`` overlaps a region. Targets treat
    each variant as having length 1, so only its position is tested; when
    ``complement`` is true the variants matching the targets are removed
    instead of kept.

    If both regions and targets are specified then a variant is selected only
    if it satisfies both.

    Returns a 1-D array of indexes into the variant arrays, or an empty int64
    array when nothing is selected.

    Raises:
        ValueError: if neither regions nor targets is given.
    """
    if regions is None and targets is None:
        raise ValueError("At least one of regions or targets must be specified")

    # subtract 1 from start coordinate to convert intervals
    # from VCF (1-based, fully-closed) to Python (0-based, half-open)
    variant_start = variant_position - 1

    def matching_indexes(end, query, subtract=False):
        """Return indexes of variants kept after overlapping/subtracting query."""
        df = pd.DataFrame(
            {"Chromosome": variant_contig, "Start": variant_start, "End": end}
        )
        # save original index as column so we can retrieve it after finding overlap
        df["index"] = df.index
        variants = PyRanges(df)
        hits = variants.subtract(query) if subtract else variants.overlap(query)
        if hits.empty:
            return np.empty((0,), dtype=np.int64)
        return hits.df["index"].to_numpy()

    regions_selection = (
        None if regions is None else matching_indexes(variant_end, regions)
    )
    targets_selection = (
        None
        if targets is None
        # length 1: the exclusive end of a target interval is the position
        else matching_indexes(variant_position, targets, subtract=complement)
    )

    if regions_selection is None:
        return targets_selection
    if targets_selection is None:
        return regions_selection
    # Combine by variant identity rather than by coordinates: an interval
    # overlap would also keep a long variant that merely spans a different
    # variant matched by the targets.
    return np.intersect1d(regions_selection, targets_selection)
Loading
Loading