Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add _apply_list_op, each, map, filter to ComplexBase #47

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 88 additions & 94 deletions muspy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,22 @@
Any,
Callable,
Iterable,
Iterator,
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
)

from .utils import yaml_dump

__all__ = ["Base", "ComplexBase"]

T = TypeVar("T")
BaseType = TypeVar("BaseType", bound="Base")
ComplexBaseType = TypeVar("ComplexBaseType", bound="ComplexBase")

Expand Down Expand Up @@ -513,15 +517,19 @@ class ComplexBase(Base):
"""Base class that supports advanced operations on list attributes.

This class extends the Base class with advanced operations on list
attributes, including `append`, `remove_invalid`, `remove_duplicate`
and `sort`.
attributes, including `append`, `remove_invalid`,
`remove_duplicate` and `sort`.

See Also
--------
:class:`muspy.Base` : Base class for MusPy classes.

"""

def __init__(self, **kwargs):
    """Initialize the base object and cache a flat-item generator.

    Parameters
    ----------
    **kwargs
        Keyword arguments, forwarded unchanged to `Base.__init__`.

    """
    # Forward every keyword argument to the parent initializer.
    Base.__init__(self, **kwargs)
    # NOTE(review): `_flat_generator` is a generator function, so this
    # stores one generator object created at construction time. A
    # generator can be consumed only once, so anything iterating
    # `self._flat` directly gets the items a single time.
    # NOTE(review): generator objects generally cannot be deep-copied;
    # confirm this attribute does not break `deepcopy()` of the object.
    self._flat = self._flat_generator()
cifkao marked this conversation as resolved.
Show resolved Hide resolved

def __iadd__(
self: ComplexBaseType, other: Union[ComplexBaseType, Iterable]
) -> ComplexBaseType:
Expand All @@ -537,6 +545,35 @@ def __add__(
)
return self.deepcopy().extend(other, deepcopy=True)

def _traverse_lists(
    self,
    attr: Optional[str],
    recursive: bool
) -> Iterator[Tuple[str, Type[T], List[T]]]:
    """Walk the list attributes of this object.

    Yields one tuple `(attr, attr_type, value)` per visited list
    attribute, where `attr_type` is the entry found in
    `self._attributes` and `value` is the list stored on the object
    (the list itself, not a copy).

    Parameters
    ----------
    attr : str, optional
        Name of the list attribute to visit. If None, every name in
        `self._list_attributes` is visited in order.
    recursive : bool
        Whether to descend into the items of a list whose declared
        type is a `ComplexBase` subclass.

    Raises
    ------
    TypeError
        If `attr` is given but is not one of the list attributes.

    """
    # No attribute selected: delegate to one call per list attribute.
    if attr is None:
        for name in self._list_attributes:
            yield from self._traverse_lists(  # type: ignore
                name, recursive)
        return

    if attr not in self._list_attributes:
        raise TypeError("`attr` must be a list attribute.")

    item_type = self._attributes[attr]
    descend = (
        recursive
        and isclass(item_type)
        and issubclass(item_type, ComplexBase)
    )

    # Post-order traversal: the lists owned by the children are
    # yielded before the list that contains those children.
    if descend:
        for child in getattr(self, attr):
            yield from child._traverse_lists(  # type: ignore
                attr=None, recursive=recursive)

    yield (attr, item_type, getattr(self, attr))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A return signature of either the value (like dict.values) or a tuple of the key and the value (attr, value) (like dict.items) would be more straightforward. Maybe we can make them two separate methods.

But if we do this, we can no longer make remove_invalid, remove_duplicate and sort work using this, right?

Or we could keep this unchanged and have two public wrappers lists and lists_with_names.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I would definitely keep this as an internal function.


def _append(self, obj):
for attr in self._list_attributes:
attr_type = self._attributes[attr]
Expand Down Expand Up @@ -606,33 +643,6 @@ def extend(
self._append(copy.deepcopy(item) if deepcopy else item)
return self

def _remove_invalid(self, attr: str, recursive: bool):
    """Remove invalid items from one list attribute, in place.

    An item is kept only if it is an instance of the type declared
    for `attr` in `self._attributes` and, when that type is a `Base`
    subclass, its `is_valid()` method returns a true value. The list
    object stored on `self` is modified through slice assignment, so
    existing references to it stay valid.

    Parameters
    ----------
    attr : str
        Name of the list attribute to clean.
    recursive : bool
        Whether to first call `remove_invalid` on each item when the
        declared type is a `ComplexBase` subclass.

    """
    value = getattr(self, attr)

    # Skip it if empty
    if not value:
        return

    attr_type = self._attributes[attr]
    is_class = isclass(attr_type)
    is_base = is_class and issubclass(attr_type, Base)
    is_complexbase = is_class and issubclass(attr_type, ComplexBase)

    # NOTE: The ordering matters here. We first apply recursively
    # and later check the current object, so that something that can
    # be fixed at a lower level does not cause the higher-level
    # object to be removed.

    # Apply recursively. Items of the wrong type have no
    # `remove_invalid` method; they are skipped here and dropped by
    # the filter below instead of raising AttributeError.
    if recursive and is_complexbase:
        for item in value:
            if isinstance(item, attr_type):
                item.remove_invalid(recursive=recursive)

    # Replace the old content with only the valid items. The
    # isinstance check comes first so that `is_valid` is never called
    # on an object of the wrong type.
    if is_base:
        value[:] = [
            item
            for item in value
            if isinstance(item, attr_type) and item.is_valid()
        ]
    else:
        value[:] = [item for item in value if isinstance(item, attr_type)]

def remove_invalid(
self: ComplexBaseType,
attr: Optional[str] = None,
Expand All @@ -652,43 +662,28 @@ def remove_invalid(
Object itself.

"""
if attr is None:
for attribute in self._list_attributes:
self._remove_invalid(attribute, recursive)
elif attr in self._list_attributes:
self._remove_invalid(attr, recursive)
else:
raise TypeError("`{}` must be a list attribute.")
return self

def _remove_duplicate(self, attr: str, recursive: bool):
# Skip it if empty
if not getattr(self, attr):
return
# NOTE: We depend on the fact that _traverse_lists uses
# post-order traversal. This way, we first apply recursively
# and later check the current object, so that something that
# can be fixed at a lower level does not cause the high-level
# object to be removed.

attr_type: Type
value: list
for _, attr_type, value in self._traverse_lists(attr, recursive):
# Skip it if empty
if not value:
continue

attr_type = self._attributes[attr]
is_complexbase = isclass(attr_type) and issubclass(
attr_type, ComplexBase
)
value = getattr(self, attr)

# NOTE: The ordering matters here. We first apply recursively
# and later check the current object so that something that can
# be fixed at a lower level does not cause the high-level object
# to be removed.

# Apply recursively
if recursive and is_complexbase:
for item in value:
item.remove_duplicate(recursive=recursive)
if isclass(attr_type) and issubclass(attr_type, Base):
value[:] = [item for item in value # type: ignore
if (isinstance(item, attr_type)
and cast(Base, item).is_valid())]
else:
value[:] = [item for item in value # type: ignore
if isinstance(item, attr_type)]
Comment on lines +674 to +680
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type checking can be simplified if we change the return signature into Iterator[Tuple[str, List]].

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there must be a bug in mypy, it was complaining about things whose type is actually known...


# Replace the old list with a new list without duplicates
# TODO: Speed this up by grouping by time.
new_value = []
for item in value:
if item not in new_value:
new_value.append(item)
value[:] = new_value
return self

def remove_duplicate(
self: ComplexBaseType,
Expand All @@ -709,30 +704,17 @@ def remove_duplicate(
Object itself.

"""
if attr is None:
for attribute in self._list_attributes:
self._remove_duplicate(attribute, recursive)
elif attr in self._list_attributes:
self._remove_duplicate(attr, recursive)
else:
raise TypeError("`{}` must be a list attribute.")
return self

def _sort(self, attr: str, recursive: bool):
# Skip it if empty
if not getattr(self, attr):
return
value: list
for _, _, value in self._traverse_lists(attr, recursive):
# Replace the old list with a new list without duplicates
# TODO: Speed this up by grouping by time.
new_value = []
for item in value: # type: ignore
if item not in new_value:
new_value.append(item)
value[:] = new_value

# Sort the list
attr_type = self._attributes[attr]
if isclass(attr_type) and issubclass(attr_type, Base):
# pylint: disable=protected-access
if "time" in attr_type._attributes:
getattr(self, attr).sort(key=attrgetter("time"))
# Apply recursively
if recursive and issubclass(attr_type, ComplexBase):
for value in getattr(self, attr):
value.sort(recursive=recursive)
return self

def sort(
self: ComplexBaseType,
Expand All @@ -753,11 +735,23 @@ def sort(
Object itself.

"""
if attr is None:
for attribute in self._list_attributes:
self._sort(attribute, recursive)
elif attr in self._list_attributes:
self._sort(attr, recursive)
else:
raise TypeError("`{}` must be a list attribute.")
attr_type: Type
value: list
for _, attr_type, value in self._traverse_lists(attr, recursive):
if value and "time" in getattr(attr_type, "_attributes"):
value.sort(key=attrgetter("time"))

return self

@property
def flat(self) -> Iterable:
    """A flat view of the items held in this object's lists.

    Iterating over the returned value yields all items in all list
    attributes inside this object (recursively). Non-list attributes
    are not included.

    Each access returns a fresh iterator built from the current
    content of the object, so the items can be iterated any number
    of times. Returning one generator created at construction time
    would leave it exhausted after the first pass.

    Returns
    -------
    Iterable
        A single-pass iterator over the items; access the property
        again to iterate once more.

    """
    return self._flat_generator()

def _flat_generator(self) -> Iterable:
    """Yield every item of every list attribute, recursively.

    The lists are visited in the order produced by
    `self._traverse_lists(attr=None, recursive=True)`, and the items
    of each list are yielded in list order.

    """
    walker = self._traverse_lists(attr=None, recursive=True)
    for _name, _type, items in walker:
        for element in items:
            yield element
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this might be problematic as it will keep reusing the generator created at __init__. We probably don't want this behavior. Also, I am not a fan of having flat as a property. Could we simply make it Music.lists(), which is rather straightforward from its name except that it's recursive?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I didn't realize that.

How about this? flat is now reusable, and it's always the same object, so it's not a problem to have it as a property.

If you still insist on it being a function, why not call it e.g. walk_lists?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just want it to behave more like dict.items() or dict.values() rather than numpy.ndarray.flat or music21.stream.Stream.flat. It's quite common in Python to have a function that returns an iterator.

Copy link
Contributor Author

@cifkao cifkao Jan 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True. Is walk_lists() OK then? The name should give a hint that the return value is good only for (one) iteration.

BTW dict.items() and dict.values() also return objects that support more than just iteration (e.g. len()).

Copy link
Contributor Author

@cifkao cifkao Jan 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, lists() suggests that it would iterate over lists. But my flat iterates over list items, not lists themselves. So we need a better name. I think flat() (just changing the property to a function) or walk_lists() is OK.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made it list_items() and added list() to iterate over the actual lists.

1 change: 1 addition & 0 deletions muspy/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,7 @@ def __init__(
lyrics: Optional[List[Lyric]] = None,
annotations: Optional[List[Annotation]] = None,
):
ComplexBase.__init__(self)
self.program = program if program is not None else 0
self.is_drum = is_drum if program is not None else False
self.name = name
Expand Down
1 change: 1 addition & 0 deletions muspy/music.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def __init__(
annotations: Optional[List[Annotation]] = None,
tracks: Optional[List[Track]] = None,
):
ComplexBase.__init__(self)
self.metadata = metadata if metadata is not None else Metadata()
self.resolution = (
resolution if resolution is not None else DEFAULT_RESOLUTION
Expand Down