Skip to content

Commit

Permalink
Add serialization support for np.void datatype
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 706058750
  • Loading branch information
Flax Team committed Dec 14, 2024
1 parent fc38f21 commit 74a4c58
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
7 changes: 7 additions & 0 deletions flax/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ def _dtype_from_name(name: str):
"""Handle JAX bfloat16 dtype correctly."""
if name == b'bfloat16':
return jax.numpy.bfloat16
elif name[0:4] == b'void':
# decode voidNNNN, where NNNN is the length in bits
return np.void(int(name[4:]) // 8)
else:
return np.dtype(name)

Expand Down Expand Up @@ -343,6 +346,10 @@ def _np_convert_in_place(d):

def _chunk(arr) -> dict[str, Any]:
"""Convert array to a canonical dictionary of chunked arrays."""
if isinstance(arr.dtype, np.dtypes.VoidDType):
# The chunking strategy below doesn't work for a large scalar np.void.
# Implement an alternative strategy later if needed.
raise NotImplementedError('Chunking not implemented for np.void dtype')
chunksize = max(1, int(MAX_CHUNK_SIZE / arr.dtype.itemsize))
data = {'__msgpack_chunked_array__': True, 'shape': _tuple_to_dict(arr.shape)}
flatarr = arr.reshape(-1)
Expand Down
1 change: 1 addition & 0 deletions tests/serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def __call__(self):
'float',
'complex',
'bool',
'void',
])
def test_numpy_serialization(self, dtype):
np.random.seed(0)
Expand Down

0 comments on commit 74a4c58

Please sign in to comment.