Skip to content

Commit b0f5dc4

Browse files
Merge pull request #1036 from jverswijver/numpy_datetime
add np.datetime64 serialization and tests
2 parents 07c5553 + a37cb1e commit b0f5dc4

File tree

7 files changed

+101
-47
lines changed

7 files changed

+101
-47
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
## Release notes
22

3+
### 0.13.7 -- Jul 13, 2022
4+
* Bugfix - Fix networkx incompatible change by version pinning to 2.6.3 PR #1036 (#1035)
5+
* Add - Support for serializing numpy datetime64 types PR #1036 (#1022)
6+
* Update - Add traceback to default logging PR #1036
7+
38
### 0.13.6 -- Jun 13, 2022
49
* Add - Config option to set threshold for when to stop using checksums for filepath stores. PR #1025
510
* Add - Unified package level logger for package (#667) PR #1031

datajoint/blob.py

Lines changed: 61 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -14,32 +14,44 @@
1414
from .settings import config
1515

1616

17-
mxClassID = dict(
18-
(
19-
# see http://www.mathworks.com/help/techdoc/apiref/mxclassid.html
20-
("mxUNKNOWN_CLASS", None),
21-
("mxCELL_CLASS", None),
22-
("mxSTRUCT_CLASS", None),
23-
("mxLOGICAL_CLASS", np.dtype("bool")),
24-
("mxCHAR_CLASS", np.dtype("c")),
25-
("mxVOID_CLASS", np.dtype("O")),
26-
("mxDOUBLE_CLASS", np.dtype("float64")),
27-
("mxSINGLE_CLASS", np.dtype("float32")),
28-
("mxINT8_CLASS", np.dtype("int8")),
29-
("mxUINT8_CLASS", np.dtype("uint8")),
30-
("mxINT16_CLASS", np.dtype("int16")),
31-
("mxUINT16_CLASS", np.dtype("uint16")),
32-
("mxINT32_CLASS", np.dtype("int32")),
33-
("mxUINT32_CLASS", np.dtype("uint32")),
34-
("mxINT64_CLASS", np.dtype("int64")),
35-
("mxUINT64_CLASS", np.dtype("uint64")),
36-
("mxFUNCTION_CLASS", None),
37-
)
38-
)
39-
40-
rev_class_id = {dtype: i for i, dtype in enumerate(mxClassID.values())}
41-
dtype_list = list(mxClassID.values())
42-
type_names = list(mxClassID)
17+
deserialize_lookup = {
18+
0: {"dtype": None, "scalar_type": "UNKNOWN"},
19+
1: {"dtype": None, "scalar_type": "CELL"},
20+
2: {"dtype": None, "scalar_type": "STRUCT"},
21+
3: {"dtype": np.dtype("bool"), "scalar_type": "LOGICAL"},
22+
4: {"dtype": np.dtype("c"), "scalar_type": "CHAR"},
23+
5: {"dtype": np.dtype("O"), "scalar_type": "VOID"},
24+
6: {"dtype": np.dtype("float64"), "scalar_type": "DOUBLE"},
25+
7: {"dtype": np.dtype("float32"), "scalar_type": "SINGLE"},
26+
8: {"dtype": np.dtype("int8"), "scalar_type": "INT8"},
27+
9: {"dtype": np.dtype("uint8"), "scalar_type": "UINT8"},
28+
10: {"dtype": np.dtype("int16"), "scalar_type": "INT16"},
29+
11: {"dtype": np.dtype("uint16"), "scalar_type": "UINT16"},
30+
12: {"dtype": np.dtype("int32"), "scalar_type": "INT32"},
31+
13: {"dtype": np.dtype("uint32"), "scalar_type": "UINT32"},
32+
14: {"dtype": np.dtype("int64"), "scalar_type": "INT64"},
33+
15: {"dtype": np.dtype("uint64"), "scalar_type": "UINT64"},
34+
16: {"dtype": None, "scalar_type": "FUNCTION"},
35+
65_536: {"dtype": np.dtype("datetime64[Y]"), "scalar_type": "DATETIME64[Y]"},
36+
65_537: {"dtype": np.dtype("datetime64[M]"), "scalar_type": "DATETIME64[M]"},
37+
65_538: {"dtype": np.dtype("datetime64[W]"), "scalar_type": "DATETIME64[W]"},
38+
65_539: {"dtype": np.dtype("datetime64[D]"), "scalar_type": "DATETIME64[D]"},
39+
65_540: {"dtype": np.dtype("datetime64[h]"), "scalar_type": "DATETIME64[h]"},
40+
65_541: {"dtype": np.dtype("datetime64[m]"), "scalar_type": "DATETIME64[m]"},
41+
65_542: {"dtype": np.dtype("datetime64[s]"), "scalar_type": "DATETIME64[s]"},
42+
65_543: {"dtype": np.dtype("datetime64[ms]"), "scalar_type": "DATETIME64[ms]"},
43+
65_544: {"dtype": np.dtype("datetime64[us]"), "scalar_type": "DATETIME64[us]"},
44+
65_545: {"dtype": np.dtype("datetime64[ns]"), "scalar_type": "DATETIME64[ns]"},
45+
65_546: {"dtype": np.dtype("datetime64[ps]"), "scalar_type": "DATETIME64[ps]"},
46+
65_547: {"dtype": np.dtype("datetime64[fs]"), "scalar_type": "DATETIME64[fs]"},
47+
65_548: {"dtype": np.dtype("datetime64[as]"), "scalar_type": "DATETIME64[as]"},
48+
}
49+
serialize_lookup = {
50+
v["dtype"]: {"type_id": k, "scalar_type": v["scalar_type"]}
51+
for k, v in deserialize_lookup.items()
52+
if v["dtype"] is not None
53+
}
54+
4355

4456
compression = {b"ZL123\0": zlib.decompress}
4557

@@ -176,7 +188,7 @@ def pack_blob(self, obj):
176188
return self.pack_float(obj)
177189
if isinstance(obj, np.ndarray) and obj.dtype.fields:
178190
return self.pack_recarray(np.array(obj))
179-
if isinstance(obj, np.number):
191+
if isinstance(obj, (np.number, np.datetime64)):
180192
return self.pack_array(np.array(obj))
181193
if isinstance(obj, (bool, np.bool_)):
182194
return self.pack_array(np.array(obj))
@@ -211,14 +223,18 @@ def read_array(self):
211223
shape = self.read_value(count=n_dims)
212224
n_elem = np.prod(shape, dtype=int)
213225
dtype_id, is_complex = self.read_value("uint32", 2)
214-
dtype = dtype_list[dtype_id]
215226

216-
if type_names[dtype_id] == "mxVOID_CLASS":
227+
# Get dtype from type id
228+
dtype = deserialize_lookup[dtype_id]["dtype"]
229+
230+
# Check if name is void
231+
if deserialize_lookup[dtype_id]["scalar_type"] == "VOID":
217232
data = np.array(
218233
list(self.read_blob(self.read_value()) for _ in range(n_elem)),
219234
dtype=np.dtype("O"),
220235
)
221-
elif type_names[dtype_id] == "mxCHAR_CLASS":
236+
# Check if name is char
237+
elif deserialize_lookup[dtype_id]["scalar_type"] == "CHAR":
222238
# compensate for MATLAB packing of char arrays
223239
data = self.read_value(dtype, count=2 * n_elem)
224240
data = data[::2].astype("U1")
@@ -240,6 +256,8 @@ def pack_array(self, array):
240256
"""
241257
Serialize an np.ndarray into bytes. Scalars are encoded with ndim=0.
242258
"""
259+
if "datetime64" in array.dtype.name:
260+
self.set_dj0()
243261
blob = (
244262
b"A"
245263
+ np.uint64(array.ndim).tobytes()
@@ -248,22 +266,26 @@ def pack_array(self, array):
248266
is_complex = np.iscomplexobj(array)
249267
if is_complex:
250268
array, imaginary = np.real(array), np.imag(array)
251-
type_id = (
252-
rev_class_id[array.dtype]
253-
if array.dtype.char != "U"
254-
else rev_class_id[np.dtype("O")]
255-
)
256-
if dtype_list[type_id] is None:
257-
raise DataJointError("Type %s is ambiguous or unknown" % array.dtype)
269+
try:
270+
type_id = serialize_lookup[array.dtype]["type_id"]
271+
except KeyError:
272+
# U is for unicode string
273+
if array.dtype.char == "U":
274+
type_id = serialize_lookup[np.dtype("O")]["type_id"]
275+
else:
276+
raise DataJointError(f"Type {array.dtype} is ambiguous or unknown")
258277

259278
blob += np.array([type_id, is_complex], dtype=np.uint32).tobytes()
260-
if type_names[type_id] == "mxVOID_CLASS": # array of dtype('O')
279+
if (
280+
array.dtype.char == "U"
281+
or serialize_lookup[array.dtype]["scalar_type"] == "VOID"
282+
):
261283
blob += b"".join(
262284
len_u64(it) + it
263285
for it in (self.pack_blob(e) for e in array.flatten(order="F"))
264286
)
265287
self.set_dj0() # not supported by original mym
266-
elif type_names[type_id] == "mxCHAR_CLASS": # array of dtype('c')
288+
elif serialize_lookup[array.dtype]["scalar_type"] == "CHAR":
267289
blob += (
268290
array.view(np.uint8).astype(np.uint16).tobytes()
269291
) # convert to 16-bit chars for MATLAB

datajoint/logging.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,7 @@ def excepthook(exc_type, exc_value, exc_traceback):
2121
sys.__excepthook__(exc_type, exc_value, exc_traceback)
2222
return
2323

24-
if logger.getEffectiveLevel() == 10:
25-
logger.debug(
26-
"Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback)
27-
)
28-
else:
29-
logger.error(f"Uncaught exception: {exc_value}")
24+
logger.error("Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback))
3025

3126

3227
sys.excepthook = excepthook

datajoint/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
__version__ = "0.13.6"
1+
__version__ = "0.13.7"
22

33
assert len(__version__) <= 10 # The log table limits version to the 10 characters

docs-parts/intro/Releases_lang1.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
0.13.7 -- Jul 13, 2022
2+
----------------------
3+
* Bugfix - Fix networkx incompatible change by version pinning to 2.6.3 PR #1036 (#1035)
4+
* Add - Support for serializing numpy datetime64 types PR #1036 (#1022)
5+
* Update - Add traceback to default logging PR #1036
6+
17
0.13.6 -- Jun 13, 2022
28
----------------------
39
* Add - Config option to set threshold for when to stop using checksums for filepath stores. PR #1025

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ pyparsing
44
ipython
55
pandas
66
tqdm
7-
networkx
7+
networkx<=2.6.3
88
pydot
99
minio>=7.0.0
1010
matplotlib

tests/test_blob.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import datajoint as dj
2+
import timeit
23
import numpy as np
34
import uuid
45
from . import schema
@@ -149,6 +150,9 @@ def test_pack():
149150
x == unpack(pack(x)), "Numpy string array object did not pack/unpack correctly"
150151
)
151152

153+
x = np.datetime64("1998").astype("datetime64[us]")
154+
assert_true(x == unpack(pack(x)))
155+
152156

153157
def test_recarrays():
154158
x = np.array([(1.0, 2), (3.0, 4)], dtype=[("x", float), ("y", int)])
@@ -222,3 +226,25 @@ def test_insert_longblob():
222226
}
223227
(schema.Longblob & "id=1").delete()
224228
dj.blob.use_32bit_dims = False
229+
230+
231+
def test_datetime_serialization_speed():
232+
# If this fails that means for some reason deserializing/serializing
233+
# np arrays of np.datetime64 types is now slower than arrays of Python datetime objects
234+
235+
optimized_exe_time = timeit.timeit(
236+
setup="myarr=pack(np.array([np.datetime64('2022-10-13 03:03:13') for _ in range(0, 10000)]))",
237+
stmt="unpack(myarr)",
238+
number=10,
239+
globals=globals(),
240+
)
241+
print(f"np time {optimized_exe_time}")
242+
baseline_exe_time = timeit.timeit(
243+
setup="myarr2=pack(np.array([datetime(2022,10,13,3,3,13) for _ in range (0, 10000)]))",
244+
stmt="unpack(myarr2)",
245+
number=10,
246+
globals=globals(),
247+
)
248+
print(f"python time {baseline_exe_time}")
249+
250+
assert optimized_exe_time * 1000 < baseline_exe_time

0 commit comments

Comments
 (0)