diff --git a/flax/serialization.py b/flax/serialization.py index bd6d0853cc..8894fe66cd 100644 --- a/flax/serialization.py +++ b/flax/serialization.py @@ -259,10 +259,30 @@ def _ndarray_to_bytes(arr) -> bytes: return msgpack.packb(tpl, use_bin_type=True) -def _dtype_from_name(name: str): - """Handle JAX bfloat16 dtype correctly.""" +def _dtype_from_name(name: bytes): + """Handle JAX bfloat16 and other numpy fixed-width dtypes correctly.""" + + def _parse_bit_len(name: bytes, dtype_name: bytes): + return int(name.replace(dtype_name, b'')) + if name == b'bfloat16': return jax.numpy.bfloat16 + elif name.startswith(b'str'): + string_dtype = np.asarray('x').dtype + # Typically '