Skip to content

Commit

Permalink
Fix: Internal typing in base class.
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Apr 26, 2024
1 parent 30c5563 commit 06667de
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
9 changes: 6 additions & 3 deletions src/aac_datasets/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
is_iterable_bool,
is_iterable_int,
is_iterable_str,
is_list_bool,
is_list_int,
)

pylog = logging.getLogger(__name__)
Expand Down Expand Up @@ -250,13 +252,14 @@ def at(

if isinstance(index, Iterable):
index = list(index)
if is_iterable_bool(index):
if is_list_bool(index):
if len(index) != len(self):
raise IndexError(
f"The length of the mask ({len(index)}) does not match the length of the dataset ({len(self)})."
)
index = [i for i, idx_i in enumerate(index) if idx_i]
elif __debug__ and not is_iterable_int(index):

elif __debug__ and not is_list_int(index):
raise TypeError(
f"Invalid input type for index={index}. (expected Iterable[int], not Iterable[{index[0].__class__.__name__}])"
)
Expand Down Expand Up @@ -387,7 +390,7 @@ def to_dict(self, load_online_values: bool = False) -> Dict[str, List[Any]]:

def to_list(self, load_online_values: bool = False) -> List[ItemType]:
raw_data = self.to_dict(load_online_values)
return dict_list_to_list_dict(raw_data, key_mode="same")
return dict_list_to_list_dict(raw_data, key_mode="same") # type: ignore

# Magic methods
@overload
Expand Down
14 changes: 11 additions & 3 deletions src/aac_datasets/utils/type_checks.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Any, Iterable
from typing import Any, Iterable, List

from typing_extensions import TypeGuard


def is_iterable_bool(x: Any) -> TypeGuard[Iterable[bool]]:
return isinstance(x, Iterable) and all(isinstance(xi, bool) for xi in x)


def is_iterable_int(x: Any) -> TypeGuard[Iterable[int]]:
return isinstance(x, Iterable) and all(isinstance(xi, int) for xi in x)

Expand All @@ -18,5 +22,9 @@ def is_iterable_str(x: Any, *, accept_str: bool) -> TypeGuard[Iterable[str]]:
)


def is_iterable_bool(x: Any) -> TypeGuard[Iterable[bool]]:
return isinstance(x, Iterable) and all(isinstance(xi, bool) for xi in x)
def is_list_bool(x: Any) -> TypeGuard[List[bool]]:
return isinstance(x, list) and all(isinstance(xi, bool) for xi in x)


def is_list_int(x: Any) -> TypeGuard[List[int]]:
return isinstance(x, list) and all(isinstance(xi, int) for xi in x)

0 comments on commit 06667de

Please sign in to comment.