Commit f5f12dc

Bgen to zarr implementation sgkit-dev#16

1 parent 4ee4ee2 commit f5f12dc

4 files changed: 133 additions and 95 deletions

setup.cfg
Lines changed: 5 additions & 1 deletion

@@ -54,7 +54,7 @@ ignore =
 [isort]
 default_section = THIRDPARTY
 known_first_party = sgkit
-known_third_party = bgen_reader,dask,numpy,pytest,setuptools,xarray,zarr
+known_third_party = bgen_reader,dask,numpy,pytest,rechunker,setuptools,xarray,zarr
 multi_line_output = 3
 include_trailing_comma = True
 force_grid_wrap = 0
@@ -67,6 +67,8 @@ ignore_missing_imports = True
 ignore_missing_imports = True
 [mypy-setuptools.*]
 ignore_missing_imports = True
+[mypy-rechunker.*]
+ignore_missing_imports = True
 [mypy-bgen_reader.*]
 ignore_missing_imports = True
 [mypy-sgkit.*]
@@ -75,3 +77,5 @@ ignore_missing_imports = True
 ignore_missing_imports = True
 [mypy-sgkit_bgen.tests.*]
 disallow_untyped_defs = False
+[mypy-sgkit_bgen.*]
+allow_redefinition = True

sgkit_bgen/__init__.py
Lines changed: 2 additions & 2 deletions

@@ -1,3 +1,3 @@
-from .bgen_reader import read_bgen, rechunk_from_zarr, rechunk_to_zarr  # noqa: F401
+from .bgen_reader import bgen_to_zarr, read_bgen, rechunk_bgen  # noqa: F401

-__all__ = ["read_bgen", "rechunk_from_zarr", "rechunk_to_zarr"]
+__all__ = ["read_bgen", "rechunk_bgen", "bgen_to_zarr"]

sgkit_bgen/bgen_reader.py
Lines changed: 73 additions & 34 deletions

@@ -1,6 +1,7 @@
 """BGEN reader implementation (using bgen_reader)"""
+import tempfile
 from pathlib import Path
-from typing import Any, Dict, Hashable, MutableMapping, Optional, Tuple, Union
+from typing import Any, Dict, Hashable, Mapping, MutableMapping, Optional, Tuple, Union

 import dask.array as da
 import dask.dataframe as dd
@@ -12,11 +13,11 @@
 from bgen_reader._metafile import create_metafile
 from bgen_reader._reader import infer_metafile_filepath
 from bgen_reader._samples import generate_samples, read_samples_file
+from rechunker import api as rechunker_api
 from xarray import Dataset
-from xarray.backends.zarr import ZarrStore

 from sgkit import create_genotype_dosage_dataset
-from sgkit.typing import ArrayLike
+from sgkit.typing import ArrayLike, DType
 from sgkit.utils import encode_array

 PathType = Union[str, Path]
@@ -241,6 +242,8 @@ def read_bgen(

 def encode_variables(
     ds: Dataset,
+    chunk_length: int,
+    chunk_width: int,
     compressor: Optional[Any] = zarr.Blosc(cname="zstd", clevel=7, shuffle=2),
     probability_dtype: Optional[Any] = "uint8",
 ) -> Dict[Hashable, Dict[str, Any]]:
@@ -249,6 +252,8 @@ def encode_variables(
         e = {}
         if compressor is not None:
             e.update({"compressor": compressor})
+        if v in GT_DATA_VARS:
+            e.update({"chunks": (chunk_length, chunk_width) + ds[v].shape[2:]})
         if probability_dtype is not None and v == "call_genotype_probability":
             dtype = np.dtype(probability_dtype)
             # Xarray will decode into float32 so any int greater than
@@ -287,16 +292,16 @@ def pack_variables(ds: Dataset) -> Dataset:
     return ds


-def unpack_variables(ds: Dataset, dtype: Any = "float32") -> Dataset:
+def unpack_variables(ds: Dataset, dtype: DType = "float32") -> Dataset:
     # Restore homozygous reference GP
-    gp = ds["call_genotype_probability"].astype(dtype)
+    gp = ds["call_genotype_probability"].astype(dtype)  # type: ignore[no-untyped-call]
     if gp.sizes["genotypes"] != 2:
         raise ValueError(
             "Expecting variable 'call_genotype_probability' to have genotypes "
             f"dimension of size 2 (received sizes = {dict(gp.sizes)})"
         )
     ds = ds.drop_vars("call_genotype_probability")
-    ds["call_genotype_probability"] = xr.concat(  # type: ignore[no-untyped-call]
+    ds["call_genotype_probability"] = xr.concat(
         [1 - gp.sum(dim="genotypes", skipna=False), gp], dim="genotypes"
     )

@@ -309,44 +314,78 @@ def unpack_variables(ds: Dataset, dtype: Any = "float32") -> Dataset:
     return ds


-def rechunk_to_zarr(
+def rechunk_bgen(
     ds: Dataset,
-    store: Union[PathType, MutableMapping[str, bytes]],
+    output: Union[PathType, MutableMapping[str, bytes]],
     *,
-    mode: str = "w",
     chunk_length: int = 10_000,
-    chunk_width: int = 10_000,
+    chunk_width: int = 1_000,
     compressor: Optional[Any] = zarr.Blosc(cname="zstd", clevel=7, shuffle=2),
-    probability_dtype: Optional[Any] = "uint8",
+    probability_dtype: Optional[DType] = "uint8",
+    max_mem: str = "4GB",
     pack: bool = True,
-    compute: bool = True,
-) -> ZarrStore:
+    tempdir: Optional[PathType] = None,
+) -> Dataset:
+    if isinstance(output, Path):
+        output = str(output)
+
+    chunk_length = min(chunk_length, ds.dims["variants"])
+    chunk_width = min(chunk_width, ds.dims["samples"])
+
     if pack:
         ds = pack_variables(ds)
-    for v in set(GT_DATA_VARS) & set(ds):
-        chunk_size = da.asarray(ds[v]).chunksize[0]
-        if chunk_length % chunk_size != 0:
-            raise ValueError(
-                f"Chunk size in variant dimension for variable '{v}' ({chunk_size}) "
-                f"must evenly divide target chunk size {chunk_length}"
-            )
-        ds[v] = ds[v].chunk(chunks=dict(samples=chunk_width))  # type: ignore[dict-item]
+
     encoding = encode_variables(
-        ds, compressor=compressor, probability_dtype=probability_dtype
+        ds,
+        chunk_length=chunk_length,
+        chunk_width=chunk_width,
+        compressor=compressor,
+        probability_dtype=probability_dtype,
     )
-    return ds.to_zarr(store, mode=mode, encoding=encoding or None, compute=compute)  # type: ignore[arg-type]
+    with tempfile.TemporaryDirectory(
+        prefix="bgen_to_zarr_", suffix=".zarr", dir=tempdir
+    ) as tmpdir:
+        rechunked = rechunker_api.rechunk_dataset(
+            ds,
+            encoding=encoding,
+            max_mem=max_mem,
+            target_store=output,
+            temp_store=tmpdir,
+            executor="dask",
+        )
+        rechunked.execute()
+
+    ds: Dataset = xr.open_zarr(output, concat_characters=False)  # type: ignore[no-untyped-call]
+    if pack:
+        ds = unpack_variables(ds)
+
+    return ds


-def rechunk_from_zarr(
-    store: Union[PathType, MutableMapping[str, bytes]],
+def bgen_to_zarr(
+    input: PathType,
+    output: Union[PathType, MutableMapping[str, bytes]],
+    region: Optional[Mapping[Hashable, Any]] = None,
     chunk_length: int = 10_000,
-    chunk_width: int = 10_000,
-    mask_and_scale: bool = True,
+    chunk_width: int = 1_000,
+    temp_chunk_length: int = 100,
+    compressor: Optional[Any] = zarr.Blosc(cname="zstd", clevel=7, shuffle=2),
+    probability_dtype: Optional[DType] = "uint8",
+    max_mem: str = "4GB",
+    pack: bool = True,
+    tempdir: Optional[PathType] = None,
 ) -> Dataset:
-    # Always use concat_characters=False to avoid https://github.com/pydata/xarray/issues/4405
-    ds = xr.open_zarr(store, mask_and_scale=mask_and_scale, concat_characters=False)  # type: ignore[no-untyped-call]
-    for v in set(GT_DATA_VARS) & set(ds):
-        ds[v] = ds[v].chunk(chunks=dict(variants=chunk_length, samples=chunk_width))
-        # Workaround for https://github.com/pydata/xarray/issues/4380
-        del ds[v].encoding["chunks"]
-    return ds  # type: ignore[no-any-return]
+    ds = read_bgen(input, chunks=(temp_chunk_length, -1, -1))
+    if region is not None:
+        ds = ds.isel(indexers=region)
+    return rechunk_bgen(
+        ds,
+        output,
+        chunk_length=chunk_length,
+        chunk_width=chunk_width,
+        compressor=compressor,
+        probability_dtype=probability_dtype,
+        max_mem=max_mem,
+        pack=pack,
+        tempdir=tempdir,
+    )
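
For orientation, a minimal usage sketch of the new top-level entry point follows. This is not part of the commit: the file paths are hypothetical, and the keyword values simply echo the defaults and test values visible in the diff above.

# Hypothetical usage sketch: convert a local BGEN file to a
# rechunked Zarr store using the new public API.
from sgkit_bgen import bgen_to_zarr

ds = bgen_to_zarr(
    "example.bgen",                       # hypothetical input path
    "example.zarr",                       # target Zarr store (path or mapping)
    region=dict(variants=slice(0, 100)),  # optional subset, applied via ds.isel
    chunk_length=10_000,                  # target chunks in the variants dimension
    chunk_width=1_000,                    # target chunks in the samples dimension
    max_mem="4GB",                        # memory budget passed through to rechunker
)
print(dict(ds.dims))  # dimensions reflect the selected region, as in the tests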

sgkit_bgen/tests/test_bgen_reader.py
Lines changed: 53 additions & 58 deletions

@@ -9,8 +9,8 @@
 from sgkit_bgen.bgen_reader import (
     GT_DATA_VARS,
     BgenReader,
-    rechunk_from_zarr,
-    rechunk_to_zarr,
+    bgen_to_zarr,
+    rechunk_bgen,
     unpack_variables,
 )

@@ -44,19 +44,27 @@
     [np.nan, 1.018, 0.010, 0.160, 0.991]  # Generated using bgen-reader directly
 )

+EXPECTED_DIMS = dict(variants=199, samples=500, genotypes=3, alleles=2)
+
+
+def _shape(*dims: str) -> Tuple[int, ...]:
+    return tuple(EXPECTED_DIMS[d] for d in dims)
+

 @pytest.mark.parametrize("chunks", CHUNKS)
 def test_read_bgen(shared_datadir, chunks):
     path = shared_datadir / "example.bgen"
     ds = read_bgen(path, chunks=chunks)

     # check some of the data (in different chunks)
-    assert ds["call_dosage"].shape == (199, 500)
+    assert ds["call_dosage"].shape == _shape("variants", "samples")
     npt.assert_almost_equal(ds["call_dosage"].values[1][0], 1.987, decimal=3)
     npt.assert_almost_equal(ds["call_dosage"].values[100][0], 0.160, decimal=3)
     npt.assert_array_equal(ds["call_dosage_mask"].values[0, 0], [True])
     npt.assert_array_equal(ds["call_dosage_mask"].values[0, 1], [False])
-    assert ds["call_genotype_probability"].shape == (199, 500, 3)
+    assert ds["call_genotype_probability"].shape == _shape(
+        "variants", "samples", "genotypes"
+    )
     npt.assert_almost_equal(
         ds["call_genotype_probability"].values[1][0], [0.005, 0.002, 0.992], decimal=3
     )
@@ -137,39 +145,45 @@ def test_read_bgen__raise_on_invalid_indexers(shared_datadir):
         reader[([0], [0], [0])]


-def _rechunk_to_zarr(
+def _rechunk_bgen(
     shared_datadir: Path, tmp_path: Path, **kwargs: Any
-) -> Tuple[xr.Dataset, str]:
+) -> Tuple[xr.Dataset, xr.Dataset, str]:
     path = shared_datadir / "example.bgen"
     ds = read_bgen(path, chunks=(10, -1, -1))
     store = tmp_path / "example.zarr"
-    rechunk_to_zarr(ds, store, **kwargs)
-    return ds, str(store)
+    dsr = rechunk_bgen(ds, store, **kwargs)
+    return ds, dsr, str(store)


 def _open_zarr(store: str, **kwargs: Any) -> xr.Dataset:
     # Force concat_characters False to avoid https://github.com/pydata/xarray/issues/4405
     return xr.open_zarr(store, concat_characters=False, **kwargs)  # type: ignore[no-any-return,no-untyped-call]


-@pytest.mark.parametrize("chunk_width", [10, 50, 500])
-def test_rechunk_to_zarr__chunk_size(shared_datadir, tmp_path, chunk_width):
-    _, store = _rechunk_to_zarr(
-        shared_datadir, tmp_path, chunk_width=chunk_width, pack=False
+@pytest.mark.parametrize("target_chunks", [(10, 10), (50, 50), (100, 50), (50, 100)])
+def test_rechunk_bgen__target_chunks(shared_datadir, tmp_path, target_chunks):
+    _, dsr, store = _rechunk_bgen(
+        shared_datadir,
+        tmp_path,
+        chunk_length=target_chunks[0],
+        chunk_width=target_chunks[1],
+        pack=False,
     )
-    dsr = _open_zarr(store)
     for v in GT_DATA_VARS:
-        # Chunks shape should equal (
-        #   length of chunks on read,
-        #   width of chunks on rechunk
-        # )
-        assert dsr[v].data.chunksize[0] == 10
-        assert dsr[v].data.chunksize[1] == chunk_width
+        assert dsr[v].data.chunksize[:2] == target_chunks
+
+
+def test_rechunk_bgen__self_consistent(shared_datadir, tmp_path):
+    # With no probability dtype or packing, rechunk_bgen is a noop
+    ds, dsr, store = _rechunk_bgen(
+        shared_datadir, tmp_path, probability_dtype=None, pack=False
+    )
+    xr.testing.assert_allclose(ds.compute(), dsr.compute())  # type: ignore[no-untyped-call]


 @pytest.mark.parametrize("dtype", ["uint8", "uint16"])
-def test_rechunk_to_zarr__probability_encoding(shared_datadir, tmp_path, dtype):
-    ds, store = _rechunk_to_zarr(
+def test_rechunk_bgen__probability_encoding(shared_datadir, tmp_path, dtype):
+    ds, _, store = _rechunk_bgen(
         shared_datadir, tmp_path, probability_dtype=dtype, pack=False
     )
     dsr = _open_zarr(store, mask_and_scale=False)
@@ -184,61 +198,42 @@ def test_rechunk_to_zarr__probability_encoding(shared_datadir, tmp_path, dtype):
         np.testing.assert_allclose(ds[v], dsr[v], atol=tolerance)


-def test_rechunk_to_zarr__variable_packing(shared_datadir, tmp_path):
-    ds, store = _rechunk_to_zarr(
+def test_rechunk_bgen__variable_packing(shared_datadir, tmp_path):
+    ds, dsr, store = _rechunk_bgen(
         shared_datadir, tmp_path, probability_dtype=None, pack=True
     )
-    dsr = _open_zarr(store, mask_and_scale=True)
-    dsr = unpack_variables(dsr)
     # A minor tolerance is necessary here when packing is enabled
     # because one of the genotype probabilities is constructed from the others
     xr.testing.assert_allclose(ds.compute(), dsr.compute(), atol=1e-6)  # type: ignore[no-untyped-call]


-def test_rechunk_to_zarr__raise_on_invalid_chunk_length(shared_datadir, tmp_path):
-    with pytest.raises(
-        ValueError,
-        match="Chunk size in variant dimension for variable .* must evenly divide target chunk size",
-    ):
-        _rechunk_to_zarr(shared_datadir, tmp_path, chunk_length=11)
-
-
-@pytest.mark.parametrize("chunks", [(10, 10), (50, 50), (100, 50), (50, 100)])
-def test_rechunk_from_zarr__target_chunks(shared_datadir, tmp_path, chunks):
-    ds, store = _rechunk_to_zarr(
-        shared_datadir,
-        tmp_path,
-        chunk_length=chunks[0],
-        chunk_width=chunks[1],
-        pack=False,
-    )
-    ds = rechunk_from_zarr(store, chunk_length=chunks[0], chunk_width=chunks[1])
-    for v in GT_DATA_VARS:
-        assert ds[v].data.chunksize[:2] == chunks
-
-
 @pytest.mark.parametrize("dtype", ["uint32", "int8", "float32"])
-def test_rechunk_from_zarr__invalid_probability_type(shared_datadir, tmp_path, dtype):
+def test_rechunk_bgen__invalid_probability_type(shared_datadir, tmp_path, dtype):
     with pytest.raises(ValueError, match="Probability integer dtype invalid"):
-        _rechunk_to_zarr(shared_datadir, tmp_path, probability_dtype=dtype)
+        _rechunk_bgen(shared_datadir, tmp_path, probability_dtype=dtype)


 def test_unpack_variables__invalid_gp_dims(shared_datadir, tmp_path):
     # Validate that an error is thrown when variables are
     # unpacked without being packed in the first place
-    _, store = _rechunk_to_zarr(shared_datadir, tmp_path, pack=False)
-    dsr = _open_zarr(store, mask_and_scale=True)
+    _, dsr, store = _rechunk_bgen(shared_datadir, tmp_path, pack=False)
     with pytest.raises(
         ValueError,
         match="Expecting variable 'call_genotype_probability' to have genotypes dimension of size 2",
     ):
         unpack_variables(dsr)


-def test_rechunk_from_zarr__self_consistent(shared_datadir, tmp_path):
-    # With no probability dtype or packing, rechunk_{to,from}_zarr is a noop
-    ds, store = _rechunk_to_zarr(
-        shared_datadir, tmp_path, probability_dtype=None, pack=False
-    )
-    dsr = rechunk_from_zarr(store)
-    xr.testing.assert_allclose(ds.compute(), dsr.compute())  # type: ignore[no-untyped-call]
+@pytest.mark.parametrize("region", [None, dict(variants=slice(0, 100))])
+def test_bgen_to_zarr(shared_datadir, tmp_path, region):
+    input = shared_datadir / "example.bgen"
+    output = tmp_path / "example.zarr"
+    ds = bgen_to_zarr(input, output, region=region)
+    expected_dims = {
+        k: EXPECTED_DIMS[k]
+        if region is None or k not in region
+        else region[k].stop - region[k].start
+        for k in EXPECTED_DIMS
+    }
+    actual_dims = {k: v for k, v in ds.dims.items() if k in expected_dims}
+    assert actual_dims == expected_dims
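
A sketch of the lower-level path these tests exercise may also help: read with small read-time chunks, then rechunk into a Zarr store with rechunk_bgen. The paths and chunk sizes are illustrative assumptions mirroring the _rechunk_bgen helper above, not part of the commit.

# Hypothetical sketch mirroring the _rechunk_bgen test helper.
from sgkit_bgen import read_bgen, rechunk_bgen

ds = read_bgen("example.bgen", chunks=(10, -1, -1))  # small chunks on read
dsr = rechunk_bgen(ds, "example.zarr", chunk_length=50, chunk_width=50, pack=False)
# The genotype variables in the result carry the requested target chunking
assert dsr["call_genotype_probability"].data.chunksize[:2] == (50, 50)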
