diff --git a/src/aac_datasets/datasets/base.py b/src/aac_datasets/datasets/base.py index f601f49..c75b058 100644 --- a/src/aac_datasets/datasets/base.py +++ b/src/aac_datasets/datasets/base.py @@ -240,14 +240,14 @@ def at( if isinstance(index, Iterable): index = list(index) - if all(isinstance(idx_i, bool) for idx_i in index): + if is_iterable_bool(index): if len(index) != len(self): raise IndexError( f"The length of the mask ({len(index)}) does not match the length of the dataset ({len(self)})." ) index = [i for i, idx_i in enumerate(index) if idx_i] - elif not all(isinstance(idx_i, int) for idx_i in index): + elif not is_iterable_int(index): raise TypeError( f"Invalid input type for index={index}. (expected Iterable[int], not Iterable[{index[0].__class__.__name__}])" )