diff --git a/flax/serialization.py b/flax/serialization.py index bd6d0853cc..e75e362c89 100644 --- a/flax/serialization.py +++ b/flax/serialization.py @@ -263,6 +263,9 @@ def _dtype_from_name(name: str): """Handle JAX bfloat16 dtype correctly.""" if name == b'bfloat16': return jax.numpy.bfloat16 + elif name[0:4] == b'void': + # decode voidNNNN, where NNNN is the length in bits + return np.void(int(name[4:]) // 8) else: return np.dtype(name) @@ -343,6 +346,10 @@ def _np_convert_in_place(d): def _chunk(arr) -> dict[str, Any]: """Convert array to a canonical dictionary of chunked arrays.""" + if isinstance(arr.dtype, np.dtypes.VoidDType): + # The chunking strategy below doesn't work for a large scalar np.void. + # Implement an alternative strategy later if needed. + raise NotImplementedError('Chunking not implemented for np.void dtype') chunksize = max(1, int(MAX_CHUNK_SIZE / arr.dtype.itemsize)) data = {'__msgpack_chunked_array__': True, 'shape': _tuple_to_dict(arr.shape)} flatarr = arr.reshape(-1) diff --git a/tests/serialization_test.py b/tests/serialization_test.py index 05c319b6e8..ed05d0c486 100644 --- a/tests/serialization_test.py +++ b/tests/serialization_test.py @@ -264,6 +264,7 @@ def __call__(self): 'float', 'complex', 'bool', + 'void', ]) def test_numpy_serialization(self, dtype): np.random.seed(0)