Skip to content

Commit

Permalink
Added serialization of events
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Mar 21, 2024
1 parent 1403f74 commit 8f927e3
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/random_events/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.0.5'
__version__ = '2.0.6'
32 changes: 29 additions & 3 deletions src/random_events/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing_extensions import Set, Union, Any, TYPE_CHECKING, Iterable, List, Self, Dict, Tuple

from .variables import Variable, Continuous, Discrete
from .utils import SubclassJSONSerializer


# Type hinting for Python 3.7 to 3.9
Expand Down Expand Up @@ -109,7 +110,7 @@ def is_empty(self) -> bool:
raise NotImplementedError


class Event(SupportsSetOperations, EventMapType):
class Event(SupportsSetOperations, EventMapType, SubclassJSONSerializer):
"""
A map of variables to values of their respective domains.
"""
Expand Down Expand Up @@ -434,6 +435,21 @@ def marginal_event(self, variables: Iterable[Variable]) -> Self:
"""
return self.__class__({variable: self[variable] for variable in variables if variable in self})

def to_json(self) -> Dict[str, Any]:
result = super().to_json()
event = [(variable.to_json(), variable.assignment_to_json(assignment)) for variable, assignment in self.items()]
result["event"] = event
return result

@classmethod
def _from_json(cls, data: Dict[str, Any]) -> Self:
result = cls()
for variable_json, assignment_json in data["event"]:
variable = Variable.from_json(variable_json)
assignment = variable.assignment_from_json(assignment_json)
result[variable] = assignment
return result


class EncodedEvent(Event):
"""
Expand Down Expand Up @@ -483,8 +499,7 @@ def encode(self) -> Self:
return self.__copy__()



class ComplexEvent(SupportsSetOperations):
class ComplexEvent(SupportsSetOperations, SubclassJSONSerializer):
"""
A complex event is a set of mutually exclusive events.
"""
Expand Down Expand Up @@ -708,5 +723,16 @@ def merge_if_one_dimensional(self) -> Self:
value = variable.union_of_assignments(value, event[variable])
return ComplexEvent([Event({variable: value})])

def to_json(self) -> Dict[str, Any]:
result = super().to_json()
events = [event.to_json() for event in self.events]
result["events"] = events
return result

@classmethod
def _from_json(cls, data: Dict[str, Any]) -> Self:
events = [Event.from_json(event) for event in data["events"]]
return cls(events)


EventType = Union[Event, EncodedEvent, ComplexEvent]
40 changes: 40 additions & 0 deletions src/random_events/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing_extensions import Dict, Any, Self


def get_full_class_name(cls):
"""
Returns the full name of a class, including the module name.
Expand All @@ -14,3 +17,40 @@ def recursive_subclasses(cls):
:return: A list of the classes subclasses.
"""
return cls.__subclasses__() + [g for s in cls.__subclasses__() for g in recursive_subclasses(s)]


class SubclassJSONSerializer:
"""
Class for automatic (de)serialization of subclasses.
Classes that inherit from this class can be serialized and deserialized automatically by calling this classes
'from_json' method.
"""

def to_json(self) -> Dict[str, Any]:
return {"type": get_full_class_name(self.__class__)}

@classmethod
def _from_json(cls, data: Dict[str, Any]) -> Self:
"""
Create a variable from a json dict.
This method is called from the from_json method after the correct subclass is determined and should be
overwritten by the respective subclass.
:param data: The json dict
:return: The deserialized object
"""
raise NotImplementedError()

@classmethod
def from_json(cls, data: Dict[str, Any]) -> Self:
"""
Create the correct instanceof the subclass from a json dict.
:param data: The json dict
:return: The correct instance of the subclass
"""
for subclass in recursive_subclasses(SubclassJSONSerializer):
if get_full_class_name(subclass) == data["type"]:
return subclass._from_json(data)

raise ValueError("Unknown type {}".format(data["type"]))
40 changes: 25 additions & 15 deletions src/random_events/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
AssignmentType = Union[portion.Interval, Tuple]


class Variable:
class Variable(utils.SubclassJSONSerializer):
"""
Abstract base class for all variables.
"""
Expand Down Expand Up @@ -101,20 +101,6 @@ def _from_json(cls, data: Dict[str, Any]) -> 'Variable':
"""
return cls(name=data["name"], domain=data["domain"])

@classmethod
def from_json(cls, data: Dict[str, Any]) -> 'Variable':
"""
Create the correct instanceof the subclass from a json dict.
:param data: The json dict
:return: The correct instance of the subclass
"""
for subclass in utils.recursive_subclasses(Variable):
if utils.get_full_class_name(subclass) == data["type"]:
return subclass._from_json(data)

raise ValueError("Unknown type for variable. Type is {}".format(data["type"]))

def complement_of_assignment(self, assignment: AssignmentType, encoded: bool = False) -> AssignmentType:
"""
Returns the complement of the assignment for the variable.
Expand Down Expand Up @@ -153,6 +139,18 @@ def union_of_assignments(assignment1: AssignmentType,
"""
raise NotImplementedError

def assignment_to_json(self, assignment: AssignmentType) -> Any:
"""
Convert an assignment to a json serializable object.
"""
raise NotImplementedError

def assignment_from_json(self, data: Any) -> AssignmentType:
"""
Convert an assignment from a json serializable object.
"""
raise NotImplementedError


class Continuous(Variable):
"""
Expand Down Expand Up @@ -187,6 +185,12 @@ def union_of_assignments(assignment1: portion.Interval,
encoded: bool = False) -> portion.Interval:
return assignment1 | assignment2

def assignment_to_json(self, assignment: portion.Interval) -> Any:
return portion.to_data(assignment)

def assignment_from_json(self, data: Any) -> portion.Interval:
return portion.from_data(data)


class Discrete(Variable):
"""
Expand Down Expand Up @@ -252,6 +256,12 @@ def union_of_assignments(assignment1: Tuple,
encoded: bool = False) -> Tuple:
return tuple(sorted(set(assignment1) | set(assignment2)))

def assignment_to_json(self, assignment: Tuple) -> Tuple:
return assignment

def assignment_from_json(self, data: Any) -> AssignmentType:
return tuple(data)


class Symbolic(Discrete):
"""
Expand Down
27 changes: 27 additions & 0 deletions test/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,17 @@ def test_raises_on_operation_with_different_types(self):
with self.assertRaises(TypeError):
self.event - self.event.encode()

def test_serialization(self):
json = self.event.to_json()
event = Event.from_json(json)
self.assertEqual(event, self.event)

def test_serialization_with_complex_interval(self):
event = Event({self.real: portion.closed(0, 1) | portion.closed(2, 3)})
json = event.to_json()
event_ = Event.from_json(json)
self.assertEqual(event_, event)


class EncodedEventTestCase(unittest.TestCase):

Expand Down Expand Up @@ -244,6 +255,16 @@ def test_intersection_with_empty(self):
self.assertIn(self.integer, intersection.keys())
self.assertTrue(intersection.is_empty())

def test_serialization(self):
event = EncodedEvent()
event[self.integer] = (1, 2)
event[self.symbol] = {1, 0}
event[self.real] = portion.open(0, 1)

json = event.to_json()
event_ = EncodedEvent.from_json(json)
self.assertEqual(event, event_)


class ComplexEventTestCase(unittest.TestCase):

Expand Down Expand Up @@ -393,6 +414,12 @@ def test_merge_if_1d(self):
self.assertEqual(len(merged.events), 1)
self.assertEqual(merged.events[0][self.x], portion.closed(0, 1) | portion.closed(3, 4))

def test_serialization(self):
event = Event({self.x: portion.closed(0, 1), self.y: portion.closed(0, 1)})
complement = event.complement()
json = complement.to_json()
complement_ = ComplexEvent.from_json(json)
self.assertEqual(complement, complement_)


class PlottingTestCase(unittest.TestCase):
Expand Down

0 comments on commit 8f927e3

Please sign in to comment.