14
14
from .settings import config
15
15
16
16
17
- mxClassID = dict (
18
- (
19
- # see http://www.mathworks.com/help/techdoc/apiref/mxclassid.html
20
- ("mxUNKNOWN_CLASS" , None ),
21
- ("mxCELL_CLASS" , None ),
22
- ("mxSTRUCT_CLASS" , None ),
23
- ("mxLOGICAL_CLASS" , np .dtype ("bool" )),
24
- ("mxCHAR_CLASS" , np .dtype ("c" )),
25
- ("mxVOID_CLASS" , np .dtype ("O" )),
26
- ("mxDOUBLE_CLASS" , np .dtype ("float64" )),
27
- ("mxSINGLE_CLASS" , np .dtype ("float32" )),
28
- ("mxINT8_CLASS" , np .dtype ("int8" )),
29
- ("mxUINT8_CLASS" , np .dtype ("uint8" )),
30
- ("mxINT16_CLASS" , np .dtype ("int16" )),
31
- ("mxUINT16_CLASS" , np .dtype ("uint16" )),
32
- ("mxINT32_CLASS" , np .dtype ("int32" )),
33
- ("mxUINT32_CLASS" , np .dtype ("uint32" )),
34
- ("mxINT64_CLASS" , np .dtype ("int64" )),
35
- ("mxUINT64_CLASS" , np .dtype ("uint64" )),
36
- ("mxFUNCTION_CLASS" , None ),
37
- )
38
- )
39
-
40
- rev_class_id = {dtype : i for i , dtype in enumerate (mxClassID .values ())}
41
- dtype_list = list (mxClassID .values ())
42
- type_names = list (mxClassID )
17
+ deserialize_lookup = {
18
+ 0 : {"dtype" : None , "scalar_type" : "UNKNOWN" },
19
+ 1 : {"dtype" : None , "scalar_type" : "CELL" },
20
+ 2 : {"dtype" : None , "scalar_type" : "STRUCT" },
21
+ 3 : {"dtype" : np .dtype ("bool" ), "scalar_type" : "LOGICAL" },
22
+ 4 : {"dtype" : np .dtype ("c" ), "scalar_type" : "CHAR" },
23
+ 5 : {"dtype" : np .dtype ("O" ), "scalar_type" : "VOID" },
24
+ 6 : {"dtype" : np .dtype ("float64" ), "scalar_type" : "DOUBLE" },
25
+ 7 : {"dtype" : np .dtype ("float32" ), "scalar_type" : "SINGLE" },
26
+ 8 : {"dtype" : np .dtype ("int8" ), "scalar_type" : "INT8" },
27
+ 9 : {"dtype" : np .dtype ("uint8" ), "scalar_type" : "UINT8" },
28
+ 10 : {"dtype" : np .dtype ("int16" ), "scalar_type" : "INT16" },
29
+ 11 : {"dtype" : np .dtype ("uint16" ), "scalar_type" : "UINT16" },
30
+ 12 : {"dtype" : np .dtype ("int32" ), "scalar_type" : "INT32" },
31
+ 13 : {"dtype" : np .dtype ("uint32" ), "scalar_type" : "UINT32" },
32
+ 14 : {"dtype" : np .dtype ("int64" ), "scalar_type" : "INT64" },
33
+ 15 : {"dtype" : np .dtype ("uint64" ), "scalar_type" : "UINT64" },
34
+ 16 : {"dtype" : None , "scalar_type" : "FUNCTION" },
35
+ 65_536 : {"dtype" : np .dtype ("datetime64[Y]" ), "scalar_type" : "DATETIME64[Y]" },
36
+ 65_537 : {"dtype" : np .dtype ("datetime64[M]" ), "scalar_type" : "DATETIME64[M]" },
37
+ 65_538 : {"dtype" : np .dtype ("datetime64[W]" ), "scalar_type" : "DATETIME64[W]" },
38
+ 65_539 : {"dtype" : np .dtype ("datetime64[D]" ), "scalar_type" : "DATETIME64[D]" },
39
+ 65_540 : {"dtype" : np .dtype ("datetime64[h]" ), "scalar_type" : "DATETIME64[h]" },
40
+ 65_541 : {"dtype" : np .dtype ("datetime64[m]" ), "scalar_type" : "DATETIME64[m]" },
41
+ 65_542 : {"dtype" : np .dtype ("datetime64[s]" ), "scalar_type" : "DATETIME64[s]" },
42
+ 65_543 : {"dtype" : np .dtype ("datetime64[ms]" ), "scalar_type" : "DATETIME64[ms]" },
43
+ 65_544 : {"dtype" : np .dtype ("datetime64[us]" ), "scalar_type" : "DATETIME64[us]" },
44
+ 65_545 : {"dtype" : np .dtype ("datetime64[ns]" ), "scalar_type" : "DATETIME64[ns]" },
45
+ 65_546 : {"dtype" : np .dtype ("datetime64[ps]" ), "scalar_type" : "DATETIME64[ps]" },
46
+ 65_547 : {"dtype" : np .dtype ("datetime64[fs]" ), "scalar_type" : "DATETIME64[fs]" },
47
+ 65_548 : {"dtype" : np .dtype ("datetime64[as]" ), "scalar_type" : "DATETIME64[as]" },
48
+ }
49
+ serialize_lookup = {
50
+ v ["dtype" ]: {"type_id" : k , "scalar_type" : v ["scalar_type" ]}
51
+ for k , v in deserialize_lookup .items ()
52
+ if v ["dtype" ] is not None
53
+ }
54
+
43
55
44
56
compression = {b"ZL123\0 " : zlib .decompress }
45
57
@@ -176,7 +188,7 @@ def pack_blob(self, obj):
176
188
return self .pack_float (obj )
177
189
if isinstance (obj , np .ndarray ) and obj .dtype .fields :
178
190
return self .pack_recarray (np .array (obj ))
179
- if isinstance (obj , np .number ):
191
+ if isinstance (obj , ( np .number , np . datetime64 ) ):
180
192
return self .pack_array (np .array (obj ))
181
193
if isinstance (obj , (bool , np .bool_ )):
182
194
return self .pack_array (np .array (obj ))
@@ -211,14 +223,18 @@ def read_array(self):
211
223
shape = self .read_value (count = n_dims )
212
224
n_elem = np .prod (shape , dtype = int )
213
225
dtype_id , is_complex = self .read_value ("uint32" , 2 )
214
- dtype = dtype_list [dtype_id ]
215
226
216
- if type_names [dtype_id ] == "mxVOID_CLASS" :
227
+ # Get dtype from type id
228
+ dtype = deserialize_lookup [dtype_id ]["dtype" ]
229
+
230
+ # Check if name is void
231
+ if deserialize_lookup [dtype_id ]["scalar_type" ] == "VOID" :
217
232
data = np .array (
218
233
list (self .read_blob (self .read_value ()) for _ in range (n_elem )),
219
234
dtype = np .dtype ("O" ),
220
235
)
221
- elif type_names [dtype_id ] == "mxCHAR_CLASS" :
236
+ # Check if name is char
237
+ elif deserialize_lookup [dtype_id ]["scalar_type" ] == "CHAR" :
222
238
# compensate for MATLAB packing of char arrays
223
239
data = self .read_value (dtype , count = 2 * n_elem )
224
240
data = data [::2 ].astype ("U1" )
@@ -240,6 +256,8 @@ def pack_array(self, array):
240
256
"""
241
257
Serialize an np.ndarray into bytes. Scalars are encoded with ndim=0.
242
258
"""
259
+ if "datetime64" in array .dtype .name :
260
+ self .set_dj0 ()
243
261
blob = (
244
262
b"A"
245
263
+ np .uint64 (array .ndim ).tobytes ()
@@ -248,22 +266,26 @@ def pack_array(self, array):
248
266
is_complex = np .iscomplexobj (array )
249
267
if is_complex :
250
268
array , imaginary = np .real (array ), np .imag (array )
251
- type_id = (
252
- rev_class_id [array .dtype ]
253
- if array .dtype .char != "U"
254
- else rev_class_id [np .dtype ("O" )]
255
- )
256
- if dtype_list [type_id ] is None :
257
- raise DataJointError ("Type %s is ambiguous or unknown" % array .dtype )
269
+ try :
270
+ type_id = serialize_lookup [array .dtype ]["type_id" ]
271
+ except KeyError :
272
+ # U is for unicode string
273
+ if array .dtype .char == "U" :
274
+ type_id = serialize_lookup [np .dtype ("O" )]["type_id" ]
275
+ else :
276
+ raise DataJointError (f"Type { array .dtype } is ambiguous or unknown" )
258
277
259
278
blob += np .array ([type_id , is_complex ], dtype = np .uint32 ).tobytes ()
260
- if type_names [type_id ] == "mxVOID_CLASS" : # array of dtype('O')
279
+ if (
280
+ array .dtype .char == "U"
281
+ or serialize_lookup [array .dtype ]["scalar_type" ] == "VOID"
282
+ ):
261
283
blob += b"" .join (
262
284
len_u64 (it ) + it
263
285
for it in (self .pack_blob (e ) for e in array .flatten (order = "F" ))
264
286
)
265
287
self .set_dj0 () # not supported by original mym
266
- elif type_names [ type_id ] == "mxCHAR_CLASS" : # array of dtype('c')
288
+ elif serialize_lookup [ array . dtype ][ "scalar_type" ] == "CHAR" :
267
289
blob += (
268
290
array .view (np .uint8 ).astype (np .uint16 ).tobytes ()
269
291
) # convert to 16-bit chars for MATLAB
0 commit comments