Skip to content

Commit

Permalink
Fixed bug in intersection of encoded events.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Mar 21, 2024
1 parent 8f927e3 commit f1b1fd0
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/random_events/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.0.6'
__version__ = '2.0.7'
4 changes: 2 additions & 2 deletions src/random_events/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ def intersection(self, other: EventType) -> EventType:

variables = set(self.keys()) | set(other.keys())
for variable in variables:
assignment1 = self.get(variable, variable.domain)
assignment2 = other.get(variable, variable.domain)
assignment1 = self.get(variable, variable.encoded_domain if isinstance(self, EncodedEvent) else variable.domain)
assignment2 = other.get(variable, variable.encoded_domain if isinstance(self, EncodedEvent) else variable.domain)
intersection = variable.intersection_of_assignments(assignment1, assignment2)
result[variable] = intersection

Expand Down
4 changes: 4 additions & 0 deletions src/random_events/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ def assignment_from_json(self, data: Any) -> AssignmentType:
"""
raise NotImplementedError

@property
def encoded_domain(self):
return self.encode_many(self.domain)


class Continuous(Variable):
"""
Expand Down
9 changes: 9 additions & 0 deletions test/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,15 @@ def test_serialization(self):
complement_ = ComplexEvent.from_json(json)
self.assertEqual(complement, complement_)

def test_intersection_symbol_and_real(self):
event = ComplexEvent([EncodedEvent({self.x: portion.closed(0, 1)})])
event2 = EncodedEvent({self.a: (0, )})
result = event & event2
self.assertEqual(len(result.events), 1)
event_ = result.events[0]
self.assertEqual(event_[self.x], portion.closed(0, 1))
self.assertEqual(event_[self.a], (0, ))


class PlottingTestCase(unittest.TestCase):
x: Continuous = Continuous("x")
Expand Down

0 comments on commit f1b1fd0

Please sign in to comment.