diff --git a/muspy/base.py b/muspy/base.py index 9151a1e..cee7bc3 100644 --- a/muspy/base.py +++ b/muspy/base.py @@ -528,7 +528,7 @@ class ComplexBase(Base): def __init__(self, **kwargs): Base.__init__(self, **kwargs) - self._flat = self._flat_generator() + self._flat = _GeneratorIterable(self._flat_generator) def __iadd__( self: ComplexBaseType, other: Union[ComplexBaseType, Iterable] @@ -755,3 +755,13 @@ def _flat_generator(self) -> Iterable: value: list for _, _, value in self._traverse_lists(attr=None, recursive=True): yield from value + + +class _GeneratorIterable: + """Turns a generator function into a reusable iterable.""" + + def __init__(self, generator_fn: Callable[[], Iterable]): + self._generator_fn = generator_fn + + def __iter__(self): + return self._generator_fn()