From 06667de81808cc779aeb523ef02b827b383be64e Mon Sep 17 00:00:00 2001 From: Labbeti Date: Fri, 26 Apr 2024 14:57:27 +0200 Subject: [PATCH] Fix: Internal typing in base class. --- src/aac_datasets/datasets/base.py | 9 ++++++--- src/aac_datasets/utils/type_checks.py | 14 +++++++++++--- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/aac_datasets/datasets/base.py b/src/aac_datasets/datasets/base.py index e6b0765..0a19aa2 100644 --- a/src/aac_datasets/datasets/base.py +++ b/src/aac_datasets/datasets/base.py @@ -36,6 +36,8 @@ is_iterable_bool, is_iterable_int, is_iterable_str, + is_list_bool, + is_list_int, ) pylog = logging.getLogger(__name__) @@ -250,13 +252,14 @@ def at( if isinstance(index, Iterable): index = list(index) - if is_iterable_bool(index): + if is_list_bool(index): if len(index) != len(self): raise IndexError( f"The length of the mask ({len(index)}) does not match the length of the dataset ({len(self)})." ) index = [i for i, idx_i in enumerate(index) if idx_i] - elif __debug__ and not is_iterable_int(index): + + elif __debug__ and not is_list_int(index): raise TypeError( f"Invalid input type for index={index}. (expected Iterable[int], not Iterable[{index[0].__class__.__name__}])" ) @@ -387,7 +390,7 @@ def to_dict(self, load_online_values: bool = False) -> Dict[str, List[Any]]: def to_list(self, load_online_values: bool = False) -> List[ItemType]: raw_data = self.to_dict(load_online_values) - return dict_list_to_list_dict(raw_data, key_mode="same") + return dict_list_to_list_dict(raw_data, key_mode="same") # type: ignore # Magic methods @overload diff --git a/src/aac_datasets/utils/type_checks.py b/src/aac_datasets/utils/type_checks.py index 83ac86a..10e6e57 100644 --- a/src/aac_datasets/utils/type_checks.py +++ b/src/aac_datasets/utils/type_checks.py @@ -1,11 +1,15 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -from typing import Any, Iterable +from typing import Any, Iterable, List from typing_extensions import TypeGuard +def is_iterable_bool(x: Any) -> TypeGuard[Iterable[bool]]: + return isinstance(x, Iterable) and all(isinstance(xi, bool) for xi in x) + + def is_iterable_int(x: Any) -> TypeGuard[Iterable[int]]: return isinstance(x, Iterable) and all(isinstance(xi, int) for xi in x) @@ -18,5 +22,9 @@ def is_iterable_str(x: Any, *, accept_str: bool) -> TypeGuard[Iterable[str]]: ) -def is_iterable_bool(x: Any) -> TypeGuard[Iterable[bool]]: - return isinstance(x, Iterable) and all(isinstance(xi, bool) for xi in x) +def is_list_bool(x: Any) -> TypeGuard[List[bool]]: + return isinstance(x, list) and all(isinstance(xi, bool) for xi in x) + + +def is_list_int(x: Any) -> TypeGuard[List[int]]: + return isinstance(x, list) and all(isinstance(xi, int) for xi in x)