bgen_to_zarr implementation sgkit-dev#16
eric-czech committed Sep 3, 2020
1 parent 0153804 commit d132756
Showing 5 changed files with 269 additions and 26 deletions.
4 changes: 3 additions & 1 deletion setup.cfg
@@ -54,7 +54,7 @@ ignore =
[isort]
default_section = THIRDPARTY
known_first_party = sgkit
-known_third_party = bgen_reader,dask,numpy,pytest,setuptools,xarray
+known_third_party = bgen_reader,dask,numpy,pytest,setuptools,xarray,zarr
multi_line_output = 3
include_trailing_comma = True
force_grid_wrap = 0
@@ -71,5 +71,7 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-sgkit.*]
ignore_missing_imports = True
[mypy-zarr.*]
ignore_missing_imports = True
[mypy-sgkit_bgen.tests.*]
disallow_untyped_defs = False
4 changes: 2 additions & 2 deletions sgkit_bgen/__init__.py
@@ -1,3 +1,3 @@
-from .bgen_reader import read_bgen  # noqa: F401
+from .bgen_reader import read_bgen, rechunk_from_zarr, rechunk_to_zarr  # noqa: F401

-__all__ = ["read_bgen"]
+__all__ = ["read_bgen", "rechunk_from_zarr", "rechunk_to_zarr"]
150 changes: 133 additions & 17 deletions sgkit_bgen/bgen_reader.py
@@ -1,16 +1,19 @@
"""BGEN reader implementation (using bgen_reader)"""
from pathlib import Path
-from typing import Any, Dict, Tuple, Union
+from typing import Any, Dict, Hashable, MutableMapping, Optional, Tuple, Union

import dask.array as da
import dask.dataframe as dd
import numpy as np
import xarray as xr
import zarr
from bgen_reader._bgen_file import bgen_file
from bgen_reader._bgen_metafile import bgen_metafile
from bgen_reader._metafile import create_metafile
from bgen_reader._reader import infer_metafile_filepath
from bgen_reader._samples import generate_samples, read_samples_file
from xarray import Dataset
from xarray.backends.zarr import ZarrStore

from sgkit import create_genotype_dosage_dataset
from sgkit.typing import ArrayLike
@@ -38,6 +41,13 @@ def _to_dict(df: dd.DataFrame, dtype: Any = None) -> Dict[str, da.Array]:
VARIANT_DF_DTYPE = dict([(f[0], f[1]) for f in VARIANT_FIELDS])
VARIANT_ARRAY_DTYPE = dict([(f[0], f[2]) for f in VARIANT_FIELDS])

GT_DATA_VARS = [
    "call_genotype_probability",
    "call_genotype_probability_mask",
    "call_dosage",
    "call_dosage_mask",
]


class BgenReader:

@@ -79,15 +89,7 @@ def split(allele_row: np.ndarray) -> np.ndarray:

            return np.apply_along_axis(split, 1, alleles[:, np.newaxis])

-        variant_alleles = variant_arrs["allele_ids"].map_blocks(split_alleles)
-
-        def max_str_len(arr: ArrayLike) -> Any:
-            return arr.map_blocks(
-                lambda s: np.char.str_len(s.astype(str)), dtype=np.int8
-            ).max()
-
-        max_allele_length = max(max_str_len(variant_alleles).compute())
-        self.variant_alleles = variant_alleles.astype(f"S{max_allele_length}")
+        self.variant_alleles = variant_arrs["allele_ids"].map_blocks(split_alleles)

        with bgen_file(self.path) as bgen:
            sample_path = self.path.with_suffix(".sample")
@@ -172,6 +174,7 @@ def read_bgen(
    chunks: Union[str, int, Tuple[int, ...]] = "auto",
    lock: bool = False,
    persist: bool = True,
    dtype: Any = "float32",
) -> Dataset:
    """Read BGEN dataset.
@@ -194,23 +197,23 @@
        memory, by default True. This is an important performance
        consideration as the metadata file for this data will
        be read multiple times when False.
    dtype : Any
        Genotype probability array data type, by default float32.

    Warnings
    --------
    Only bi-allelic, diploid BGEN files are currently supported.
    """

-    bgen_reader = BgenReader(path, persist)
+    bgen_reader = BgenReader(path, persist, dtype=dtype)

    variant_contig, variant_contig_names = encode_array(bgen_reader.contig.compute())
    variant_contig_names = list(variant_contig_names)
    variant_contig = variant_contig.astype("int16")

-    variant_position = np.array(bgen_reader.pos, dtype=int)
-    variant_alleles = np.array(bgen_reader.variant_alleles, dtype="S1")
-    variant_id = np.array(bgen_reader.variant_id, dtype=str)
-
-    sample_id = np.array(bgen_reader.sample_id, dtype=str)
+    variant_position = np.asarray(bgen_reader.pos, dtype=int)
+    variant_alleles = np.asarray(bgen_reader.variant_alleles, dtype="S")
+    variant_id = np.asarray(bgen_reader.variant_id, dtype=str)
+    sample_id = np.asarray(bgen_reader.sample_id, dtype=str)

    call_genotype_probability = da.from_array(
        bgen_reader,
@@ -234,3 +237,116 @@
    )

    return ds
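
# Usage sketch for the new ``dtype`` argument above. "example.bgen" is a
# placeholder path, and this assumes the reader honors any float dtype string:
def _example_read_bgen() -> Dataset:
    ds = read_bgen("example.bgen", dtype="float16")
    # Probabilities (and derived dosages) come back at the requested precision
    assert ds["call_genotype_probability"].dtype == np.dtype("float16")
    return ds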


def encode_variables(
    ds: Dataset,
    compressor: Optional[Any] = zarr.Blosc(cname="zstd", clevel=7, shuffle=2),
    probability_dtype: Optional[Any] = "uint8",
) -> Dict[Hashable, Dict[str, Any]]:
    encoding = {}
    for v in ds:
        e = {}
        if compressor is not None:
            e.update({"compressor": compressor})
        if probability_dtype is not None and v == "call_genotype_probability":
            dtype = np.dtype(probability_dtype)
            # Xarray decodes these values into float32, so any integer type
            # wider than 16 bits would overflow/underflow on decoding.
            # See the "bits precision" column for single-precision floats at
            # https://en.wikipedia.org/wiki/Floating-point_arithmetic#Internal_representation
            if dtype not in [np.uint8, np.uint16]:
                raise ValueError(
                    "Probability integer dtype invalid, must "
                    f"be uint8 or uint16 not {probability_dtype}"
                )
            divisor = np.iinfo(dtype).max - 1
            e.update(
                {
                    "dtype": probability_dtype,
                    "add_offset": -1.0 / divisor,
                    "scale_factor": 1.0 / divisor,
                    "_FillValue": 0,
                }
            )
        if e:
            encoding[v] = e
    return encoding
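
# The encoding above quantizes probabilities: with uint8, divisor = 254 and a
# probability p maps to roughly p * 254 + 1 on disk, with 0 reserved as the
# _FillValue for missing data; xarray decodes a stored x back as
# scale_factor * x + add_offset = (x - 1) / 254. A small numpy sketch of that
# round trip (illustrative only, mirroring the parameters computed above):
def _example_probability_quantization() -> None:
    divisor = np.iinfo(np.uint8).max - 1  # 254, as in encode_variables
    p = np.array([0.0, 0.2, 0.5, 1.0])
    encoded = np.round(p * divisor + 1).astype(np.uint8)  # 0 stays free for NaN
    decoded = (1.0 / divisor) * encoded + (-1.0 / divisor)
    assert np.allclose(p, decoded, atol=0.5 / divisor)  # worst-case error bound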


def pack_variables(ds: Dataset) -> Dataset:
    # Remove dosage as it is unnecessary and should be redefined
    # based on encoded probabilities later (w/ reduced precision)
    ds = ds.drop_vars(["call_dosage", "call_dosage_mask"], errors="ignore")

    # Remove homozygous reference GP and redefine mask
    gp = ds["call_genotype_probability"][..., 1:]
    gp_mask = ds["call_genotype_probability_mask"].any(dim="genotypes")
    ds = ds.drop_vars(["call_genotype_probability", "call_genotype_probability_mask"])
    ds = ds.assign(call_genotype_probability=gp, call_genotype_probability_mask=gp_mask)
    return ds
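
# pack_variables can drop the first (homozygous reference) probability because
# bi-allelic, diploid genotype probabilities sum to one, so it is recoverable
# as 1 - (p_het + p_hom_alt). A minimal numeric illustration of that redundancy:
def _example_packed_redundancy() -> None:
    gp = np.array([0.1, 0.7, 0.2])  # hom-ref, het, hom-alt
    packed = gp[1:]  # what actually gets written to the store
    assert np.isclose(1.0 - packed.sum(), gp[0])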


def unpack_variables(ds: Dataset, dtype: Any = "float32") -> Dataset:
    # Restore homozygous reference GP
    gp = ds["call_genotype_probability"].astype(dtype)
    if gp.sizes["genotypes"] != 2:
        raise ValueError(
            "Expecting variable 'call_genotype_probability' to have genotypes "
            f"dimension of size 2 (received sizes = {dict(gp.sizes)})"
        )
    ds = ds.drop_vars("call_genotype_probability")
    ds["call_genotype_probability"] = xr.concat(  # type: ignore[no-untyped-call]
        [1 - gp.sum(dim="genotypes", skipna=False), gp], dim="genotypes"
    )

    # Restore dosage
    ds["call_dosage"] = gp[..., 0] + 2 * gp[..., 1]
    ds["call_dosage_mask"] = ds["call_genotype_probability_mask"]
    ds["call_genotype_probability_mask"] = ds[
        "call_genotype_probability_mask"
    ].broadcast_like(ds["call_genotype_probability"])
    return ds
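
# Round-trip sketch for the two helpers above on a tiny hand-built dataset
# (1 variant x 1 sample; dimension names follow the variables used here).
# Note the restored dosage is p_het + 2 * p_hom_alt = 0.7 + 2 * 0.2 = 1.1:
def _example_pack_unpack_round_trip() -> None:
    dims = ("variants", "samples", "genotypes")
    ds = xr.Dataset(
        {
            "call_genotype_probability": (
                dims,
                da.from_array(np.array([[[0.1, 0.7, 0.2]]])),
            ),
            "call_genotype_probability_mask": (dims, da.zeros((1, 1, 3), dtype=bool)),
        }
    )
    restored = unpack_variables(pack_variables(ds))
    gp = restored["call_genotype_probability"].transpose(*dims).values
    np.testing.assert_allclose(gp, [[[0.1, 0.7, 0.2]]], atol=1e-6)
    np.testing.assert_allclose(restored["call_dosage"].values, [[1.1]], atol=1e-6)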


def rechunk_to_zarr(
    ds: Dataset,
    store: Union[PathType, MutableMapping[str, bytes]],
    *,
    mode: str = "w",
    chunk_length: int = 10_000,
    chunk_width: int = 10_000,
    compressor: Optional[Any] = zarr.Blosc(cname="zstd", clevel=7, shuffle=2),
    probability_dtype: Optional[Any] = "uint8",
    pack: bool = True,
    compute: bool = True,
) -> ZarrStore:
    if pack:
        ds = pack_variables(ds)
    for v in set(GT_DATA_VARS) & set(ds):
        chunk_size = da.asarray(ds[v]).chunksize[0]
        if chunk_length % chunk_size != 0:
            raise ValueError(
                f"Chunk size in variant dimension for variable '{v}' ({chunk_size}) "
                f"must evenly divide target chunk size {chunk_length}"
            )
        ds[v] = ds[v].chunk(chunks=dict(samples=chunk_width))  # type: ignore[dict-item]
    encoding = encode_variables(
        ds, compressor=compressor, probability_dtype=probability_dtype
    )
    return ds.to_zarr(store, mode=mode, encoding=encoding or None, compute=compute)  # type: ignore[arg-type]
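
# Hedged usage sketch: the loop above only rechunks the samples dimension, so
# the existing variant chunking must evenly divide chunk_length — a dataset
# read with variant chunks of 1_000 can target chunk_length=10_000, while
# 1_500 would raise the ValueError above. Placeholder paths, and this assumes
# the (variants, samples, genotypes) chunk-tuple form of read_bgen's chunks:
def _example_rechunk_to_zarr() -> ZarrStore:
    ds = read_bgen("example.bgen", chunks=(1_000, -1, 3))
    return rechunk_to_zarr(ds, "example.zarr", chunk_length=10_000, chunk_width=10_000)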


def rechunk_from_zarr(
    store: Union[PathType, MutableMapping[str, bytes]],
    chunk_length: int = 10_000,
    chunk_width: int = 10_000,
    mask_and_scale: bool = True,
) -> Dataset:
    # Always use concat_characters=False to avoid https://github.com/pydata/xarray/issues/4405
    ds = xr.open_zarr(store, mask_and_scale=mask_and_scale, concat_characters=False)  # type: ignore[no-untyped-call]
    for v in set(GT_DATA_VARS) & set(ds):
        ds[v] = ds[v].chunk(chunks=dict(variants=chunk_length, samples=chunk_width))
        # Workaround for https://github.com/pydata/xarray/issues/4380
        del ds[v].encoding["chunks"]
    return ds  # type: ignore[no-any-return]
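
# End-to-end sketch of the intended bgen -> zarr workflow, combining the
# pieces above (placeholder paths; other parameters left at module defaults):
def _example_bgen_to_zarr_pipeline() -> Dataset:
    ds = read_bgen("example.bgen", chunks=(1_000, -1, 3))
    rechunk_to_zarr(ds, "example.zarr", chunk_length=10_000, chunk_width=10_000)
    rechunked = rechunk_from_zarr("example.zarr")
    return unpack_variables(rechunked)  # restore full probabilities and dosage
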
2 changes: 2 additions & 0 deletions sgkit_bgen/tests/data/.gitignore
@@ -0,0 +1,2 @@
*.metadata2.mmm
*.metafile