Skip to content

Commit cf55cab

Browse files
committed
chore: add overloaded type annotation and change some instances of dict(attrs) to attrs.asdict()
1 parent affb65d commit cf55cab

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

src/pydantic_zarr/core.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Optional,
99
TypeVar,
1010
Union,
11+
overload,
1112
)
1213
from pydantic import BaseModel, root_validator, validator
1314
from pydantic.generics import GenericModel
@@ -141,7 +142,7 @@ def from_zarr(cls, zarray: zarr.Array):
141142
filters=zarray.filters,
142143
dimension_separator=zarray._dimension_separator,
143144
compressor=zarray.compressor,
144-
attrs=dict(zarray.attrs),
145+
attrs=zarray.attrs.asdict(),
145146
)
146147

147148
def to_zarr(
@@ -209,7 +210,7 @@ def from_zarr(cls, group: zarr.Group) -> "GroupSpec[TAttr, TItem]":
209210
if isinstance(member, zarr.Array):
210211
_item = ArraySpec.from_zarr(member)
211212
elif isinstance(member, zarr.Group):
212-
_item = cls.from_zarr(member)
213+
_item = GroupSpec.from_zarr(member)
213214
else:
214215
msg = f"""
215216
Unparseable object encountered: {type(member)}. Expected zarr.Array or
@@ -218,7 +219,7 @@ def from_zarr(cls, group: zarr.Group) -> "GroupSpec[TAttr, TItem]":
218219
raise ValueError(msg)
219220
members[name] = _item
220221

221-
result = cls(attrs=dict(group.attrs), members=members)
222+
result = cls(attrs=group.attrs.asdict(), members=members)
222223
return result
223224

224225
def to_zarr(self, store: BaseStore, path: str, overwrite: bool = False):
@@ -259,6 +260,16 @@ def to_zarr(self, store: BaseStore, path: str, overwrite: bool = False):
259260
return result
260261

261262

263+
@overload
264+
def from_zarr(element: zarr.Array) -> ArraySpec:
265+
...
266+
267+
268+
@overload
269+
def from_zarr(element: zarr.Group) -> GroupSpec:
270+
...
271+
272+
262273
def from_zarr(element: Union[zarr.Array, zarr.Group]) -> Union[ArraySpec, GroupSpec]:
263274
"""
264275
Recursively parse a Zarr group or Zarr array into an ArraySpec or GroupSpec.
@@ -290,7 +301,7 @@ def from_zarr(element: Union[zarr.Array, zarr.Group]) -> Union[ArraySpec, GroupS
290301
raise ValueError(msg)
291302
members[name] = _item
292303

293-
result = GroupSpec(attrs=dict(element.attrs), members=members)
304+
result = GroupSpec(attrs=element.attrs.asdict(), members=members)
294305
return result
295306
else:
296307
msg = f"""

0 commit comments

Comments
 (0)