Skip to content

Commit

Permalink
Fix serialization/deserialization of weakrefs (#61)
Browse files Browse the repository at this point in the history
* add __getstate__ __setstate__

* override pickle and eq methods for XMLAnnotation

* relink refs only in OME.__setstate__

* remove weakref import

* revert import change
  • Loading branch information
tlambert03 authored Dec 24, 2020
1 parent 8c73195 commit 8528da0
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 1 deletion.
29 changes: 29 additions & 0 deletions src/ome_autogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,17 @@ class UUID:
""",
body="""
def __post_init_post_parse__(self: Any, *args: Any) -> None:
self._link_refs()
def _link_refs(self):
ids = util.collect_ids(self)
for ref in util.collect_references(self):
ref.ref_ = weakref.ref(ids[ref.id])
def __setstate__(self: Any, state: Dict[str, Any]) -> None:
'''Support unpickle of our weakref references.'''
self.__dict__.update(state)
self._link_refs()
""",
),
"Reference": ClassOverride(
Expand All @@ -296,6 +304,27 @@ def ref(self) -> Any:
return self.ref_()
""",
),
"XMLAnnotation": ClassOverride(
body='''
def __getstate__(self: Any):
"""Support pickle of our weakref references."""
from ome_types.schema import ElementTree
d = self.__dict__.copy()
d["value"] = ElementTree.tostring(d.pop("value")).strip()
return d
def __setstate__(self: Any, state) -> None:
"""Support unpickle of our weakref references."""
from ome_types.schema import ElementTree
self.__dict__.update(state)
self.value = ElementTree.fromstring(self.value)
def __eq__(self, o: "XMLAnnotation") -> bool:
return self.__getstate__() == o.__getstate__()
'''
),
"BinData": ClassOverride(base_type="object", fields="value: str"),
"Map": ClassOverride(fields_suppress={"K"}),
"M": ClassOverride(base_type="object", fields="value: str"),
Expand Down
14 changes: 13 additions & 1 deletion src/ome_types/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from datetime import datetime
from enum import Enum
from textwrap import indent
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Type, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Type, Union

import pint
from pydantic import validator
Expand Down Expand Up @@ -158,6 +158,16 @@ def new_repr(self: Any) -> str:
setattr(_cls, "__repr__", new_repr)


def __getstate__(self: Any) -> Dict[str, Any]:
"""Support pickle of our weakref references."""
# don't do copy unless necessary
if "ref_" in self.__dict__:
d = self.__dict__.copy()
del d["ref_"] # remove weakref
return d
return self.__dict__


def ome_dataclass(
_cls: Optional[Type[Any]] = None,
*,
Expand All @@ -178,6 +188,8 @@ def wrap(cls: Type[Any]) -> DataclassType:
if getattr(cls, "id", None) is AUTO_SEQUENCE:
setattr(cls, "validate_id", validate_id)
modify_post_init(cls)
if not hasattr(cls, "__getstate__"):
cls.__getstate__ = __getstate__
add_quantities(cls)
if not repr:
modify_repr(cls)
Expand Down
13 changes: 13 additions & 0 deletions testing/test_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pickle
import re
from pathlib import Path
from xml.dom import minidom
Expand Down Expand Up @@ -119,6 +120,18 @@ def canonicalize(xml, strip_empty):
assert ours == original


@pytest.mark.parametrize("xml", xml_read, ids=true_stem)
def test_serialization(xml):
"""Test pickle serialization and reserialization."""
if true_stem(xml) in SHOULD_RAISE_READ:
pytest.skip("Can't pickle unreadable xml")

ome = from_xml(xml)
serialized = pickle.dumps(ome)
deserialized = pickle.loads(serialized)
assert ome == deserialized


def test_no_id():
"""Test that ids are optional, and auto-increment."""
i = model.Instrument(id=20)
Expand Down

0 comments on commit 8528da0

Please sign in to comment.