Skip to content

Commit 3cba1a7

Browse files
authored
Merge pull request #16 from janelia-cellmap/groupspec_from_zarr_fix
groupspec from zarr fix
2 parents 4c8fba6 + 095a052 commit 3cba1a7

File tree

2 files changed

+25
-7
lines changed

2 files changed

+25
-7
lines changed

src/pydantic_zarr/core.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Optional,
99
TypeVar,
1010
Union,
11+
overload,
1112
)
1213
from pydantic import BaseModel, root_validator, validator
1314
from pydantic.generics import GenericModel
@@ -141,7 +142,7 @@ def from_zarr(cls, zarray: zarr.Array):
141142
filters=zarray.filters,
142143
dimension_separator=zarray._dimension_separator,
143144
compressor=zarray.compressor,
144-
attrs=dict(zarray.attrs),
145+
attrs=zarray.attrs.asdict(),
145146
)
146147

147148
def to_zarr(
@@ -209,7 +210,7 @@ def from_zarr(cls, group: zarr.Group) -> "GroupSpec[TAttr, TItem]":
209210
if isinstance(member, zarr.Array):
210211
_item = ArraySpec.from_zarr(member)
211212
elif isinstance(member, zarr.Group):
212-
_item = cls.from_zarr(member)
213+
_item = GroupSpec.from_zarr(member)
213214
else:
214215
msg = f"""
215216
Unparseable object encountered: {type(member)}. Expected zarr.Array or
@@ -218,7 +219,7 @@ def from_zarr(cls, group: zarr.Group) -> "GroupSpec[TAttr, TItem]":
218219
raise ValueError(msg)
219220
members[name] = _item
220221

221-
result = cls(attrs=dict(group.attrs), members=members)
222+
result = cls(attrs=group.attrs.asdict(), members=members)
222223
return result
223224

224225
def to_zarr(self, store: BaseStore, path: str, overwrite: bool = False):
@@ -259,6 +260,16 @@ def to_zarr(self, store: BaseStore, path: str, overwrite: bool = False):
259260
return result
260261

261262

263+
@overload
264+
def from_zarr(element: zarr.Array) -> ArraySpec:
265+
...
266+
267+
268+
@overload
269+
def from_zarr(element: zarr.Group) -> GroupSpec:
270+
...
271+
272+
262273
def from_zarr(element: Union[zarr.Array, zarr.Group]) -> Union[ArraySpec, GroupSpec]:
263274
"""
264275
Recursively parse a Zarr group or Zarr array into an ArraySpec or GroupSpec.
@@ -290,7 +301,7 @@ def from_zarr(element: Union[zarr.Array, zarr.Group]) -> Union[ArraySpec, GroupS
290301
raise ValueError(msg)
291302
members[name] = _item
292303

293-
result = GroupSpec(attrs=dict(element.attrs), members=members)
304+
result = GroupSpec(attrs=element.attrs.asdict(), members=members)
294305
return result
295306
else:
296307
msg = f"""

tests/test_core.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pytest
33
import zarr
44
from zarr.errors import ContainsGroupError
5-
from typing import Any, Literal, TypedDict
5+
from typing import Any, Literal, TypedDict, Union
66
import numcodecs
77
from pydantic_zarr.core import ArraySpec, GroupSpec, to_zarr, from_zarr
88
import numpy as np
@@ -150,12 +150,14 @@ class SubGroupAttrs(TypedDict):
150150
a: str
151151
b: float
152152

153+
SubGroup = GroupSpec[SubGroupAttrs, Any]
154+
153155
class ArrayAttrs(TypedDict):
154156
scale: list[float]
155157

156158
store = zarr.MemoryStore()
157159

158-
spec = GroupSpec(
160+
spec = GroupSpec[RootAttrs, Union[ArraySpec, SubGroup]](
159161
attrs=RootAttrs(foo=10, bar=[0, 1, 2]),
160162
members={
161163
"s0": ArraySpec(
@@ -178,7 +180,7 @@ class ArrayAttrs(TypedDict):
178180
dimension_separator=dimension_separator,
179181
attrs=ArrayAttrs(scale=[2.0]),
180182
),
181-
"subgroup": GroupSpec(attrs=SubGroupAttrs(a="foo", b=1.0)),
183+
"subgroup": SubGroup(attrs=SubGroupAttrs(a="foo", b=1.0)),
182184
},
183185
)
184186
# materialize a zarr group, based on the spec
@@ -195,6 +197,11 @@ class ArrayAttrs(TypedDict):
195197
group2 = to_zarr(spec, store, "/group_a", overwrite=True)
196198
assert group2 == group
197199

200+
# again with class methods
201+
group3 = spec.to_zarr(store, "/group_b")
202+
observed = spec.from_zarr(group3)
203+
assert observed == spec
204+
198205

199206
def test_shape_chunks():
200207
for a, b in zip(range(1, 5), range(2, 6)):

0 commit comments

Comments
 (0)