Simplify our unpickling logic and add tests for Python 2 pickles
poodlewars committed Feb 3, 2025
1 parent 4a2fc47 commit c54bebe
Showing 4 changed files with 170 additions and 17 deletions.
21 changes: 4 additions & 17 deletions python/arcticdb/version_store/_normalization.py
@@ -1050,11 +1050,11 @@ def _ext_hook(self, code, data):
            # If stored in Python2 we want to use raw while unpacking.
            # https://github.com/msgpack/msgpack-python/blob/master/msgpack/_unpacker.pyx#L230
            data = unpackb(data, raw=True)
-            return Pickler.read(data, pickled_in_python2=True)
+            return Pickler.read(data)

        if code == MsgPackSerialization.PY_PICKLE_3:
            data = unpackb(data, raw=False)
-            return Pickler.read(data, pickled_in_python2=False)
+            return Pickler.read(data)

        return ExtType(code, data)
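
Both extension codes wrap the same layering: the pickle bytes are msgpack-packed and then boxed in an ExtType whose code records which Python major version produced the pickle (102 for Python 2, per python2_pickles.py below). A minimal sketch of that layering, not ArcticDB's exact code, assuming only msgpack-python and the standard library:

import pickle
import msgpack

# Build a payload the way the Python 2 writer does: pickle the object,
# msgpack-pack the pickle bytes, then wrap them in an ExtType carrying the code.
inner = msgpack.packb(pickle.dumps("bananas"))
wrapped = msgpack.packb(msgpack.ExtType(102, inner))

# Undoing it mirrors _ext_hook -> Pickler.read: strip the ExtType, unpack the
# inner msgpack layer (raw=True keeps old Python 2 str payloads as bytes),
# then unpickle.
def ext_hook(code, data):
    if code == 102:
        return pickle.loads(msgpack.unpackb(data, raw=True))
    return msgpack.ExtType(code, data)

assert msgpack.unpackb(wrapped, ext_hook=ext_hook, raw=False) == "bananas"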

@@ -1070,21 +1070,8 @@ def _msgpack_unpackb(self, buff, raw=False):

class Pickler(object):
    @staticmethod
-    def read(data, pickled_in_python2=False):
-        if isinstance(data, str):
-            return pickle.loads(data.encode("ascii"), encoding="bytes")
-        elif isinstance(data, str):
-            if not pickled_in_python2:
-                # Use the default encoding for python2 pickled objects similar to what's being done for PY2.
-                return pickle.loads(data, encoding="bytes")
-
-        try:
-            # This tries normal pickle.loads first then falls back to special Pandas unpickling. Pandas unpickling
-            # handles Pandas 1 vs Pandas 2 API breaks better.
-            return pd.read_pickle(io.BytesIO(data))
-        except UnicodeDecodeError as exc:
-            log.debug("Failed decoding with ascii, using latin-1.")
-            return pickle.loads(data, encoding="latin-1")
+    def read(data):
+        return pd.read_pickle(io.BytesIO(data))

    @staticmethod
    def write(obj):
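
The simplified read delegates everything to pd.read_pickle, which accepts a file-like buffer and, per the removed comment above, handles Pandas 1 vs Pandas 2 pickle differences via pandas' own compat fallbacks. A minimal sketch of the behaviour the one-liner relies on, assuming only pandas and the standard library:

import io
import pickle

import pandas as pd

# pd.read_pickle unpickles arbitrary objects from a buffer, so a single call
# now covers what the removed branches handled case by case.
payload = pickle.dumps({"a_key": "bananabread"})
assert pd.read_pickle(io.BytesIO(payload)) == {"a_key": "bananabread"}
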
@@ -33,6 +33,24 @@ def test_rt_df_with_small_meta(object_and_mem_and_lmdb_version_store):
    assert meta == vit.metadata


class A:
    def __init__(self, attrib):
        self.attrib = attrib

    def __eq__(self, other):
        return self.attrib == other.attrib


def test_rt_df_with_custom_meta(object_and_mem_and_lmdb_version_store):
    lib = object_and_mem_and_lmdb_version_store

    df = DataFrame(data=["A", "B", "C"])
    meta = {"a_key": A("bananabread")}
    lib.write("pandas", df, metadata=meta)
    vit = lib.read("pandas")
    assert meta == vit.metadata


@pytest.mark.parametrize("log_level", ("error", "warn", "debug", "info", "ERROR", "eRror", "", None))
def test_pickled_metadata_warning(lmdb_version_store_v1, log_level):
    import arcticdb.version_store._normalization as norm
@@ -0,0 +1,40 @@
"""How some of the Python 2 pickles in test_normalization.py were created.
Executed from a Python 2 env with msgpack 0.6.2
"""
from email import errors # arbitrary module with some custom types to pickle
import pickle
import msgpack
import sys

major_version = sys.version[0]


def custom_pack(obj):
# 102 is our extension code for pickled in Python 2
return msgpack.ExtType(102, msgpack.packb(pickle.dumps(obj)))


def msgpack_packb(obj):
return msgpack.packb(obj, use_bin_type=True, strict_types=True, default=custom_pack)


obj = errors.BoundaryError("bananas")
title = "py" + major_version + "_obj.bin"
with open(title, "wb") as f:
msgpack.dump(obj, f, default=custom_pack)

obj = {"dict_key": errors.BoundaryError("bananas")}
title = "py" + major_version + "_dict.bin"
with open(title, "wb") as f:
msgpack.dump(obj, f, default=custom_pack)

obj = "my_string"
title = "py" + major_version + "_str.bin"
with open(title, "wb") as f:
msgpack.dump(obj, f, default=custom_pack)

obj = b"my_bytes"
title = "py" + major_version + "_str_bytes.bin"
with open(title, "wb") as f:
msgpack.dump(obj, f, default=custom_pack)
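
The bytes these files contain are what the new tests below inline directly. A minimal sketch, assuming msgpack-python and the standard library, of unwrapping the py2_obj.bin payload by hand (the byte string is copied from test_decode_python2_pickle_in_msgpack_obj):

import pickle

import msgpack

# ExtType code 102 marks a Python 2 pickle; its data is itself msgpack-packed.
packed = b"\xc7:f\xda\x007cemail.errors\nBoundaryError\np0\n(S'bananas'\np1\ntp2\nRp3\n."
ext = msgpack.unpackb(packed, raw=True)      # ExtType(code=102, data=...)
assert ext.code == 102
inner = msgpack.unpackb(ext.data, raw=True)  # the raw protocol-0 pickle bytes
obj = pickle.loads(inner)                    # email.errors.BoundaryError("bananas")
assert obj.args[0] == "bananas"
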
108 changes: 108 additions & 0 deletions python/tests/unit/arcticdb/version_store/test_normalization.py
@@ -6,6 +6,7 @@
As of the Change Date specified in that file, in accordance with the Business Source License, use of this software will be governed by the Apache License, version 2.0.
"""
import datetime
from email import errors
import inspect
import itertools
import sys
@@ -92,6 +93,113 @@ def test_msg_pack_legacy_2():
    assert data == loc_dt


def test_decode_python2_pickle_in_msgpack_dict():
    """See python2_pickles.py for the generation steps. This is the py2_dict.bin case.
    This is to check that we can still deserialize pickles that were written with Python 2 correctly.
    """
    norm = test_msgpack_normalizer
    packed = b"\x81\xa8dict_key\xc7:f\xda\x007cemail.errors\nBoundaryError\np0\n(S'bananas'\np1\ntp2\nRp3\n."
    data = norm._msgpack_unpackb(packed)
    assert list(data.keys()) == ["dict_key"]
    assert isinstance(data["dict_key"], errors.BoundaryError)
    assert data["dict_key"].args[0] == "bananas"


def test_decode_python2_pickle_in_msgpack_obj():
    """See python2_pickles.py for the generation steps. This is the py2_obj.bin case.
    This is to check that we can still deserialize pickles that were written with Python 2 correctly.
    """
    norm = test_msgpack_normalizer
    packed = b"\xc7:f\xda\x007cemail.errors\nBoundaryError\np0\n(S'bananas'\np1\ntp2\nRp3\n."
    data = norm._msgpack_unpackb(packed)
    assert isinstance(data, errors.BoundaryError)
    assert data.args[0] == "bananas"


def test_decode_python2_str_in_msgpack():
    """See python2_pickles.py for the generation steps. This is the py2_str.bin case.
    This is to check that we can still deserialize strings that were written with Python 2 correctly.
    """
    norm = test_msgpack_normalizer
    packed = b'\xa9my_string'
    data = norm._msgpack_unpackb(packed)
    assert data == "my_string"
    assert isinstance(data, str)


def test_decode_python2_bytes_in_old_msgpack():
    """See python2_pickles.py for the generation steps. This is the py2_str_bytes.bin case.
    This is to check that we can still deserialize bytes that were written with Python 2 correctly.
    """
    norm = test_msgpack_normalizer
    packed = b'\xa8my_bytes'
    data = norm._msgpack_unpackb(packed)

    # We claim it's `str` upon decoding because the `\xa8` leading byte tells us this is a fixed string type.
    assert data == "my_bytes"
    assert isinstance(data, str)


def test_decode_python2_bytes_in_newer_msgpack():
    """See python2_pickles.py for the generation steps. This is the py2_str_bytes.bin case.
    This was written with msgpack 1.0.5 not 0.6.2 like the other examples. In this version, msgpack has
    a dedicated type for bytes.
    This is to check that we can still deserialize bytes that were written with Python 2 correctly.
    """
    norm = test_msgpack_normalizer
    packed = b'\xc4\x08my_bytes'
    data = norm._msgpack_unpackb(packed)
    assert data == b"my_bytes"
    assert isinstance(data, bytes)
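
The difference between the last two fixtures comes down to the msgpack format family used for a Python 2 str: the 0.6.2-era writer packs it as a fixstr, while msgpack 1.0+ (where use_bin_type defaults to True) uses the dedicated bin format. A small sketch of the two encodings, assuming a recent msgpack-python:

import msgpack

# Without use_bin_type (the 0.x default), bytes land in the str family:
# 0xa8 is a fixstr of length 8, so Python 3 decodes it back as str.
assert msgpack.packb(b"my_bytes", use_bin_type=False) == b"\xa8my_bytes"
assert msgpack.unpackb(b"\xa8my_bytes", raw=False) == "my_bytes"

# With use_bin_type=True (the default from msgpack 1.0), bytes use the
# dedicated bin format (0xc4 plus a length byte) and round-trip as bytes.
assert msgpack.packb(b"my_bytes", use_bin_type=True) == b"\xc4\x08my_bytes"
assert msgpack.unpackb(b"\xc4\x08my_bytes", raw=False) == b"my_bytes"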


def test_decode_python3_pickle_in_msgpack_dict():
    norm = test_msgpack_normalizer
    obj = {"dict_key": errors.BoundaryError("bananas")}
    packed = norm._msgpack_packb(obj)

    data = norm._msgpack_unpackb(packed)
    assert list(data.keys()) == ["dict_key"]
    assert isinstance(data["dict_key"], errors.BoundaryError)
    assert data["dict_key"].args[0] == "bananas"


def test_decode_python3_pickle_in_msgpack_obj():
    norm = test_msgpack_normalizer
    obj = errors.BoundaryError("bananas")
    packed = norm._msgpack_packb(obj)

    data = norm._msgpack_unpackb(packed)
    assert isinstance(data, errors.BoundaryError)
    assert data.args[0] == "bananas"


def test_decode_python3_pickle_in_msgpack_str():
    norm = test_msgpack_normalizer
    obj = "bananas"
    packed = norm._msgpack_packb(obj)

    data = norm._msgpack_unpackb(packed)
    assert isinstance(data, str)
    assert data == "bananas"


def test_decode_python3_pickle_in_msgpack_bytes():
    norm = test_msgpack_normalizer
    obj = b"bananas"
    packed = norm._msgpack_packb(obj)

    data = norm._msgpack_unpackb(packed)
    assert isinstance(data, bytes)
    assert data == b"bananas"


@param_dict("d", params)
def test_user_meta_and_msg_pack(d):
    n = normalize_metadata(d)
