Skip to content

Commit

Permalink
Fix remove_invalid and remove_duplicate (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
salu133445 committed Jan 13, 2021
1 parent 89af9f1 commit 8899576
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 23 deletions.
52 changes: 29 additions & 23 deletions muspy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,25 +612,26 @@ def _remove_invalid(self, attr: str, recursive: bool):
return

attr_type = self._attributes[attr]
value = getattr(self, attr)
is_class = isclass(attr_type)
is_base = is_class and issubclass(attr_type, Base)
is_complexbase = is_class and issubclass(attr_type, ComplexBase)
value = getattr(self, attr)

# NOTE: The ordering mathers here. We first apply recursively
# and later check to the currect object so that something that
# can be fixed in a lower level would not make the high-level
# object to be removed.
# and later check the currect object so that something that can
# be fixed in a lower level would not make the high-level object
# removed.

# Apply recursively
if recursive and is_class and issubclass(attr_type, ComplexBase):
for value in getattr(self, attr):
value.remove_invalid(recursive=recursive)
if recursive and is_complexbase:
for item in value:
item.remove_invalid(recursive=recursive)

# Replace the old list with a new list of only valid items
if is_class and issubclass(attr_type, Base):
new_value = [item for item in value if item.is_valid()]
if is_base:
value[:] = [item for item in value if item.is_valid()]
else:
new_value = [item for item in value if isinstance(item, attr_type)]
setattr(self, attr, new_value)
value[:] = [item for item in value if isinstance(item, attr_type)]

def remove_invalid(
self: ComplexBaseType,
Expand Down Expand Up @@ -665,24 +666,29 @@ def _remove_duplicate(self, attr: str, recursive: bool):
if not getattr(self, attr):
return

# Replace the old list with a new list without duplicates
# TODO: Speed this up by grouping by time.
attr_type = self._attributes[attr]
is_complexbase = isclass(attr_type) and issubclass(
attr_type, ComplexBase
)
value = getattr(self, attr)

# NOTE: The ordering mathers here. We first apply recursively
# and later check the currect object so that something that can
# be fixed in a lower level would not make the high-level object
# removed.

# Apply recursively
if recursive and is_complexbase:
for item in value:
item.remove_duplicate(recursive=recursive)

# Replace the old list with a new list without duplicates
# TODO: Speed this up by grouping by time.
new_value = []
for item in value:
if item not in new_value:
new_value.append(item)
setattr(self, attr, new_value)

# Apply recursively
if (
recursive
and isclass(attr_type)
and issubclass(attr_type, ComplexBase)
):
for value in getattr(self, attr):
value.remove_duplicate(recursive=recursive)
value[:] = new_value

def remove_duplicate(
self: ComplexBaseType,
Expand Down
22 changes: 22 additions & 0 deletions tests/test_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,17 @@ def test_remove_invalid():
assert len(track) == 1


def test_remove_invalid_recursive():
notes = [
Note(time=-1, duration=1, pitch=60),
Note(time=0, duration=1, pitch=60),
]
music = Music(tracks=[Track(notes=notes)])
music.remove_invalid()
assert len(music) == 1
assert len(music[0]) == 1


def test_remove_duplicate():
notes = [
Note(time=0, duration=1, pitch=60),
Expand All @@ -100,6 +111,17 @@ def test_remove_duplicate():
assert len(track) == 1


def test_remove_duplicate_recursive():
notes = [
Note(time=0, duration=1, pitch=60),
Note(time=0, duration=1, pitch=60),
]
music = Music(tracks=[Track(notes=notes)])
music.remove_duplicate()
assert len(music) == 1
assert len(music[0]) == 1


def test_sort_track():
notes = [
Note(time=2, pitch=64, duration=1),
Expand Down

0 comments on commit 8899576

Please sign in to comment.