diff --git a/muspy/base.py b/muspy/base.py index 42313cf4..719b4fb9 100644 --- a/muspy/base.py +++ b/muspy/base.py @@ -612,25 +612,26 @@ def _remove_invalid(self, attr: str, recursive: bool): return attr_type = self._attributes[attr] - value = getattr(self, attr) is_class = isclass(attr_type) + is_base = is_class and issubclass(attr_type, Base) + is_complexbase = is_class and issubclass(attr_type, ComplexBase) + value = getattr(self, attr) # NOTE: The ordering mathers here. We first apply recursively - # and later check to the currect object so that something that - # can be fixed in a lower level would not make the high-level - # object to be removed. + # and later check the currect object so that something that can + # be fixed in a lower level would not make the high-level object + # removed. # Apply recursively - if recursive and is_class and issubclass(attr_type, ComplexBase): - for value in getattr(self, attr): - value.remove_invalid(recursive=recursive) + if recursive and is_complexbase: + for item in value: + item.remove_invalid(recursive=recursive) # Replace the old list with a new list of only valid items - if is_class and issubclass(attr_type, Base): - new_value = [item for item in value if item.is_valid()] + if is_base: + value[:] = [item for item in value if item.is_valid()] else: - new_value = [item for item in value if isinstance(item, attr_type)] - setattr(self, attr, new_value) + value[:] = [item for item in value if isinstance(item, attr_type)] def remove_invalid( self: ComplexBaseType, @@ -665,24 +666,29 @@ def _remove_duplicate(self, attr: str, recursive: bool): if not getattr(self, attr): return - # Replace the old list with a new list without duplicates - # TODO: Speed this up by grouping by time. attr_type = self._attributes[attr] + is_complexbase = isclass(attr_type) and issubclass( + attr_type, ComplexBase + ) value = getattr(self, attr) + + # NOTE: The ordering mathers here. We first apply recursively + # and later check the currect object so that something that can + # be fixed in a lower level would not make the high-level object + # removed. + + # Apply recursively + if recursive and is_complexbase: + for item in value: + item.remove_duplicate(recursive=recursive) + + # Replace the old list with a new list without duplicates + # TODO: Speed this up by grouping by time. new_value = [] for item in value: if item not in new_value: new_value.append(item) - setattr(self, attr, new_value) - - # Apply recursively - if ( - recursive - and isclass(attr_type) - and issubclass(attr_type, ComplexBase) - ): - for value in getattr(self, attr): - value.remove_duplicate(recursive=recursive) + value[:] = new_value def remove_duplicate( self: ComplexBaseType, diff --git a/tests/test_classes.py b/tests/test_classes.py index edae5342..27617757 100644 --- a/tests/test_classes.py +++ b/tests/test_classes.py @@ -90,6 +90,17 @@ def test_remove_invalid(): assert len(track) == 1 +def test_remove_invalid_recursive(): + notes = [ + Note(time=-1, duration=1, pitch=60), + Note(time=0, duration=1, pitch=60), + ] + music = Music(tracks=[Track(notes=notes)]) + music.remove_invalid() + assert len(music) == 1 + assert len(music[0]) == 1 + + def test_remove_duplicate(): notes = [ Note(time=0, duration=1, pitch=60), @@ -100,6 +111,17 @@ def test_remove_duplicate(): assert len(track) == 1 +def test_remove_duplicate_recursive(): + notes = [ + Note(time=0, duration=1, pitch=60), + Note(time=0, duration=1, pitch=60), + ] + music = Music(tracks=[Track(notes=notes)]) + music.remove_duplicate() + assert len(music) == 1 + assert len(music[0]) == 1 + + def test_sort_track(): notes = [ Note(time=2, pitch=64, duration=1),