diff --git a/ome_zarr/writer.py b/ome_zarr/writer.py index ba4ff916..8e9ca9c4 100644 --- a/ome_zarr/writer.py +++ b/ome_zarr/writer.py @@ -248,44 +248,37 @@ def write_multiscale( if chunks_opt is not None: chunks_opt = _retuple(chunks_opt, data.shape) + # v2 arguments + if fmt.zarr_format == 2: + options["chunks"] = chunks_opt + options["dimension_separator"] = "/" + # default to zstd compression + options["compressor"] = options.get( + "compressor", Blosc(cname="zstd", clevel=5, shuffle=Blosc.SHUFFLE) + ) + else: + if axes is not None: + options["dimension_names"] = [ + axis["name"] for axis in axes if isinstance(axis, dict) + ] + if isinstance(data, da.Array): + options["zarr_format"] = fmt.zarr_format if chunks_opt is not None: data = da.array(data).rechunk(chunks=chunks_opt) - options["chunks"] = chunks_opt da_delayed = da.to_zarr( arr=data, url=group.store, component=str(Path(group.path, str(path))), - storage_options=options, - # by default we use Blosc with zstd compression - compressor=options.get( - "compressor", Blosc(cname="zstd", clevel=5, shuffle=Blosc.SHUFFLE) - ), - # TODO: default dimension_separator? Not set in store for zarr v3 - # dimension_separator=group.store.dimension_separator, - dimension_separator="/", compute=compute, zarr_format=fmt.zarr_format, + **options, ) if not compute: dask_delayed.append(da_delayed) else: - # v2 arguments - if fmt.zarr_format == 2: - options["chunks"] = chunks_opt - options["dimension_separator"] = "/" - # default to zstd compression - options["compressor"] = options.get( - "compressor", Blosc(cname="zstd", clevel=5, shuffle=Blosc.SHUFFLE) - ) - else: - if axes is not None: - options["dimension_names"] = [ - axis["name"] for axis in axes if isinstance(axis, dict) - ] - options["shape"] = data.shape # otherwise we get 'null' options["fill_value"] = 0 @@ -649,21 +642,21 @@ def _write_dask_image( LOGGER.debug( "write dask.array to_zarr shape: %s, dtype: %s", image.shape, image.dtype ) + if fmt.zarr_format == 2: + options["dimension_separator"] = "/" + if options["compressor"] is None: + options["compressor"] = Blosc( + cname="zstd", clevel=5, shuffle=Blosc.SHUFFLE + ) + delayed.append( da.to_zarr( arr=image, url=group.store, component=str(Path(group.path, str(path))), - storage_options=options, compute=False, - compressor=options.pop( - "compressor", Blosc(cname="zstd", clevel=5, shuffle=Blosc.SHUFFLE) - ), - # TODO: default dimension_separator? Not set in store for zarr v3 - # dimension_separator=group.store.dimension_separator, - dimension_separator="/", - # TODO: hard-coded zarr_format for now. Needs to be set by the format.py - zarr_format=2, + zarr_format=fmt.zarr_format, + **options, ) ) datasets.append({"path": str(path)}) diff --git a/tests/test_writer.py b/tests/test_writer.py index d82cab23..7f7d8917 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -39,7 +39,8 @@ class TestWriter: @pytest.fixture(autouse=True) def initdir(self, tmpdir): self.path = pathlib.Path(tmpdir.mkdir("data")) - self.store = parse_url(self.path, mode="w").store + # All Zarr v2 formats tested below can use this store + self.store = parse_url(self.path, mode="w", fmt=FormatV04()).store self.root = zarr.group(store=self.store) self.group = self.root.create_group("test") @@ -140,12 +141,19 @@ def test_writer( assert np.allclose(data, node.data[0][...].compute()) @pytest.mark.parametrize("array_constructor", [np.array, da.from_array]) - def test_write_image_current(self, array_constructor): + def test_write_image_current(self, array_constructor, tmpdir): shape = (64, 64, 64) data = self.create_data(shape) data = array_constructor(data) - write_image(data, self.group, axes="zyx") - reader = Reader(parse_url(f"{self.path}/test")) + # don't use self.store etc as that is not current zarr format (v3) + test_path = pathlib.Path(tmpdir.mkdir("current")) + store = parse_url(test_path, mode="w").store + print("test_path", test_path) + root = zarr.group(store=store) + group = root.create_group("test") + write_image(data, group, axes="zyx") + # assert group is None + reader = Reader(parse_url(f"{test_path}/test")) image_node = list(reader())[0] for transfs in image_node.metadata["coordinateTransformations"]: assert len(transfs) == 1