Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create a region index in 'vcf2zarr encode' #291

Merged
merged 3 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions bio2zarr/vcf2zarr/vcz.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,63 @@ def finalise(self, show_progress=False):
logger.info("Consolidating Zarr metadata")
zarr.consolidate_metadata(self.path)

#######################
# index
#######################

def create_index(self):
    """Create an index to support efficient region queries.

    Writes a ``region_index`` int32 array to the Zarr store with one row
    per (variant chunk, contig) run, holding the chunk index, contig ID,
    start/end positions, maximum end coordinate and record count.
    """

    store = zarr.DirectoryStore(self.path)
    root = zarr.open_group(store=store, mode="r+")

    contig = root["variant_contig"]
    pos = root["variant_position"]
    length = root["variant_length"]

    # The index is built chunk-by-chunk, so the two arrays must share
    # the same chunking along the variants dimension.
    assert contig.cdata_shape == pos.cdata_shape

    rows = []

    logger.info("Creating region index")
    for chunk in range(pos.cdata_shape[0]):
        chunk_contig = contig.blocks[chunk]
        chunk_pos = pos.blocks[chunk]
        # Inclusive end coordinate of each record.
        chunk_end = chunk_pos + length.blocks[chunk] - 1

        # Find the last record of each contig run within the chunk.
        # Appending -1 (an impossible contig ID) guarantees the final
        # run is always terminated.
        run_ends = np.nonzero(np.diff(chunk_contig, append=-1))[0]
        run_starts = np.concatenate(([0], run_ends[:-1] + 1))
        for start, end in zip(run_starts, run_ends):
            assert chunk_contig[start] == chunk_contig[end]
            rows.append(
                (
                    chunk,  # chunk index
                    chunk_contig[start],  # contig ID
                    chunk_pos[start],  # start
                    chunk_pos[end],  # end
                    chunk_end[start : end + 1].max(),  # max end
                    end - start + 1,  # num records
                )
            )

    index = np.array(rows, dtype=np.int32)
    array = root.array(
        "region_index",
        data=index,
        shape=index.shape,
        dtype=index.dtype,
        compressor=numcodecs.Blosc("zstd", clevel=9, shuffle=0),
    )
    array.attrs["_ARRAY_DIMENSIONS"] = [
        "region_index_values",
        "region_index_fields",
    ]

    logger.info("Consolidating Zarr metadata")
    zarr.consolidate_metadata(self.path)

######################
# encode_all_partitions
######################
Expand Down Expand Up @@ -1004,6 +1061,7 @@ def encode(
max_memory=max_memory,
)
vzw.finalise(show_progress)
vzw.create_index()


def encode_init(
Expand Down
22 changes: 16 additions & 6 deletions tests/test_vcf_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@
from bio2zarr import constants, provenance, vcf2zarr


def assert_dataset_equal(ds1, ds2, drop_vars=None):
    """Assert two xarray datasets are equal, optionally ignoring some variables.

    When *drop_vars* is given, the named variables are removed from both
    datasets before comparison.
    """
    if drop_vars is not None:
        ds1 = ds1.drop_vars(drop_vars)
        ds2 = ds2.drop_vars(drop_vars)
    xt.assert_equal(ds1, ds2)


class TestSmallExample:
data_path = "tests/data/vcf/sample.vcf.gz"

Expand Down Expand Up @@ -273,7 +280,7 @@ def test_chunk_size(
ds2 = sg.load_dataset(out)
# print(ds2.call_genotype.values)
# print(ds.call_genotype.values)
xt.assert_equal(ds, ds2)
assert_dataset_equal(ds, ds2, drop_vars=["region_index"])
assert ds2.call_DP.chunks == (y_chunks, x_chunks)
assert ds2.call_GQ.chunks == (y_chunks, x_chunks)
assert ds2.call_HQ.chunks == (y_chunks, x_chunks, (2,))
Expand Down Expand Up @@ -341,8 +348,10 @@ def test_max_variant_chunks(
max_variant_chunks=max_variant_chunks,
)
ds2 = sg.load_dataset(out)
xt.assert_equal(
ds.isel(variants=slice(None, variants_chunk_size * max_variant_chunks)), ds2
assert_dataset_equal(
ds.isel(variants=slice(None, variants_chunk_size * max_variant_chunks)),
ds2,
drop_vars=["region_index"],
)

@pytest.mark.parametrize("worker_processes", [0, 1, 2])
Expand All @@ -355,7 +364,7 @@ def test_worker_processes(self, ds, tmp_path, worker_processes):
worker_processes=worker_processes,
)
ds2 = sg.load_dataset(out)
xt.assert_equal(ds, ds2)
assert_dataset_equal(ds, ds2, drop_vars=["region_index"])

def test_inspect(self, tmp_path):
# TODO pretty weak test, we should be doing this better somewhere else
Expand Down Expand Up @@ -391,8 +400,8 @@ def test_missing_contig_vcf(self, ds, tmp_path, path):
ds_c1 = ds.isel(variants=ds["variant_contig"].values == id1)
id2 = contig_id_2.index(contig)
ds_c2 = ds2.isel(variants=ds2["variant_contig"].values == id2)
drop_vars = ["contig_id", "variant_contig"]
xt.assert_equal(ds_c1.drop_vars(drop_vars), ds_c2.drop_vars(drop_vars))
drop_vars = ["contig_id", "variant_contig", "region_index"]
assert_dataset_equal(ds_c1, ds_c2, drop_vars=drop_vars)

def test_vcf_dimensions(self, ds):
assert ds.call_genotype.dims == ("variants", "samples", "ploidy")
Expand Down Expand Up @@ -854,6 +863,7 @@ def test_info_fields(self, ds):
"contig_id",
"contig_length",
"filter_id",
"region_index",
"sample_id",
]
assert sorted(list(ds)) == sorted(info_vars + standard_vars)
Expand Down