diff --git a/muspy/base.py b/muspy/base.py index 719b4fb..a22137a 100644 --- a/muspy/base.py +++ b/muspy/base.py @@ -17,18 +17,22 @@ Any, Callable, Iterable, + Iterator, List, Mapping, Optional, + Tuple, Type, TypeVar, Union, + cast, ) from .utils import yaml_dump __all__ = ["Base", "ComplexBase"] +T = TypeVar("T") BaseType = TypeVar("BaseType", bound="Base") ComplexBaseType = TypeVar("ComplexBaseType", bound="ComplexBase") @@ -513,8 +517,8 @@ class ComplexBase(Base): """Base class that supports advanced operations on list attributes. This class extend the Base class with advanced operations on list - attributes, including `append`, `remove_invalid`, `remove_duplicate` - and `sort`. + attributes, including `append`, `remove_invalid`, + `remove_duplicate` and `sort`. See Also -------- @@ -537,6 +541,35 @@ def __add__( ) return self.deepcopy().extend(other, deepcopy=True) + def _traverse_lists( + self, + attr: Optional[str], + recursive: bool + ) -> Iterator[Tuple[str, Type[T], List[T]]]: + """A generator yielding a tuple `(attr, attr_type, value)` + for all list attributes.""" + if attr is None: + for attribute in self._list_attributes: + yield from self._traverse_lists( # type: ignore + attribute, recursive) + return + if attr not in self._list_attributes: + raise TypeError("`attr` must be a list attribute.") + attr_type = self._attributes[attr] + + # Apply recursively using post-order traversal + # (first children, then self) + if ( + recursive + and isclass(attr_type) + and issubclass(attr_type, ComplexBase) + ): + for item in getattr(self, attr): + yield from item._traverse_lists( # type: ignore + attr=None, recursive=recursive) + + yield (attr, attr_type, getattr(self, attr)) + def _append(self, obj): for attr in self._list_attributes: attr_type = self._attributes[attr] @@ -606,33 +639,6 @@ def extend( self._append(copy.deepcopy(item) if deepcopy else item) return self - def _remove_invalid(self, attr: str, recursive: bool): - # Skip it if empty - if not getattr(self, attr): - return - - attr_type = self._attributes[attr] - is_class = isclass(attr_type) - is_base = is_class and issubclass(attr_type, Base) - is_complexbase = is_class and issubclass(attr_type, ComplexBase) - value = getattr(self, attr) - - # NOTE: The ordering mathers here. We first apply recursively - # and later check the currect object so that something that can - # be fixed in a lower level would not make the high-level object - # removed. - - # Apply recursively - if recursive and is_complexbase: - for item in value: - item.remove_invalid(recursive=recursive) - - # Replace the old list with a new list of only valid items - if is_base: - value[:] = [item for item in value if item.is_valid()] - else: - value[:] = [item for item in value if isinstance(item, attr_type)] - def remove_invalid( self: ComplexBaseType, attr: Optional[str] = None, @@ -652,43 +658,28 @@ def remove_invalid( Object itself. """ - if attr is None: - for attribute in self._list_attributes: - self._remove_invalid(attribute, recursive) - elif attr in self._list_attributes: - self._remove_invalid(attr, recursive) - else: - raise TypeError("`{}` must be a list attribute.") - return self + # NOTE: We depend on the fact that _traverse_lists uses + # post-order traversal. This way, we first apply recursively + # and later check the currect object so that something that + # can be fixed in a lower level would not make the high-level + # object to be removed. + + attr_type: Type + value: list + for _, attr_type, value in self._traverse_lists(attr, recursive): + # Skip it if empty + if not value: + continue - def _remove_duplicate(self, attr: str, recursive: bool): - # Skip it if empty - if not getattr(self, attr): - return - - attr_type = self._attributes[attr] - is_complexbase = isclass(attr_type) and issubclass( - attr_type, ComplexBase - ) - value = getattr(self, attr) - - # NOTE: The ordering mathers here. We first apply recursively - # and later check the currect object so that something that can - # be fixed in a lower level would not make the high-level object - # removed. - - # Apply recursively - if recursive and is_complexbase: - for item in value: - item.remove_duplicate(recursive=recursive) + if isclass(attr_type) and issubclass(attr_type, Base): + value[:] = [item for item in value # type: ignore + if (isinstance(item, attr_type) + and cast(Base, item).is_valid())] + else: + value[:] = [item for item in value # type: ignore + if isinstance(item, attr_type)] - # Replace the old list with a new list without duplicates - # TODO: Speed this up by grouping by time. - new_value = [] - for item in value: - if item not in new_value: - new_value.append(item) - value[:] = new_value + return self def remove_duplicate( self: ComplexBaseType, @@ -709,30 +700,17 @@ def remove_duplicate( Object itself. """ - if attr is None: - for attribute in self._list_attributes: - self._remove_duplicate(attribute, recursive) - elif attr in self._list_attributes: - self._remove_duplicate(attr, recursive) - else: - raise TypeError("`{}` must be a list attribute.") - return self + value: list + for _, _, value in self._traverse_lists(attr, recursive): + # Replace the old list with a new list without duplicates + # TODO: Speed this up by grouping by time. + new_value = [] + for item in value: # type: ignore + if item not in new_value: + new_value.append(item) + value[:] = new_value - def _sort(self, attr: str, recursive: bool): - # Skip it if empty - if not getattr(self, attr): - return - - # Sort the list - attr_type = self._attributes[attr] - if isclass(attr_type) and issubclass(attr_type, Base): - # pylint: disable=protected-access - if "time" in attr_type._attributes: - getattr(self, attr).sort(key=attrgetter("time")) - # Apply recursively - if recursive and issubclass(attr_type, ComplexBase): - for value in getattr(self, attr): - value.sort(recursive=recursive) + return self def sort( self: ComplexBaseType, @@ -753,11 +731,54 @@ def sort( Object itself. """ - if attr is None: - for attribute in self._list_attributes: - self._sort(attribute, recursive) - elif attr in self._list_attributes: - self._sort(attr, recursive) - else: - raise TypeError("`{}` must be a list attribute.") + attr_type: Type + value: list + for _, attr_type, value in self._traverse_lists(attr, recursive): + if value and "time" in getattr(attr_type, "_attributes"): + value.sort(key=attrgetter("time")) + return self + + def lists(self, recursive: bool = False) -> Iterator[Tuple[str, list]]: + """Return an iterator over the names and values of all list + attributes. + + Parameters + ---------- + recursive: bool + Whether to apply recursively to all list attributes. + Defaults to False. + + Returns + ------- + An iterator over `(attr_name, attr_value)` tuples. + """ + def generator(): + attr: str + value: list + for attr, _, value in self._traverse_lists(attr=None, + recursive=recursive): + yield attr, value + + return generator() + + def list_items(self, recursive: bool = False) -> Iterator: + """Return a flat iterator over the items of all list attributes. + + Parameters + ---------- + recursive: bool + Whether to apply recursively to all list attributes. + Defaults to False. + + Returns + ------- + A flat iterator over list items. + """ + def generator(): + value: list + for _, _, value in self._traverse_lists(attr=None, + recursive=recursive): + yield from value + + return generator() diff --git a/muspy/classes.py b/muspy/classes.py index 5087f9b..ba94492 100644 --- a/muspy/classes.py +++ b/muspy/classes.py @@ -618,6 +618,7 @@ def __init__( lyrics: Optional[List[Lyric]] = None, annotations: Optional[List[Annotation]] = None, ): + ComplexBase.__init__(self) self.program = program if program is not None else 0 self.is_drum = is_drum if program is not None else False self.name = name diff --git a/muspy/music.py b/muspy/music.py index aa093a7..b1caa67 100644 --- a/muspy/music.py +++ b/muspy/music.py @@ -139,6 +139,7 @@ def __init__( annotations: Optional[List[Annotation]] = None, tracks: Optional[List[Track]] = None, ): + ComplexBase.__init__(self) self.metadata = metadata if metadata is not None else Metadata() self.resolution = ( resolution if resolution is not None else DEFAULT_RESOLUTION