Skip to content

Commit

Permalink
Mod: Update is_index type guard to avoid complex tensors and tensors …
Browse files Browse the repository at this point in the history
…with more than 1 dim.
  • Loading branch information
Labbeti committed Jun 14, 2024
1 parent 5d889d2 commit 9f438c0
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/aac_datasets/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
IndexType = Union[int, Iterable[int], Iterable[bool], Tensor, slice, None]
ColumnType = Union[str, Iterable[str], None]

_IDX_TYPES = ("int", "Iterable[int]", "Iterable[bool]", "Tensor", "slice", "None")
_INDEX_TYPES = ("int", "Iterable[int]", "Iterable[bool]", "Tensor", "slice", "None")


def _is_index(index: Any) -> TypeGuard[IndexType]:
Expand All @@ -56,7 +56,12 @@ def _is_index(index: Any) -> TypeGuard[IndexType]:
or is_iterable_bool(index)
or isinstance(index, slice)
or index is None
or (isinstance(index, Tensor) and not index.is_floating_point())
or (
isinstance(index, Tensor)
and not index.is_floating_point()
and not index.is_complex()
and index.ndim in (0, 1)
)
)


Expand Down Expand Up @@ -276,7 +281,7 @@ def at(

if __debug__ and not isinstance(index, int):
raise TypeError(
f"Invalid argument type {type(index)}. (expected one of {_IDX_TYPES})"
f"Invalid argument type {type(index)}. (expected one of {_INDEX_TYPES})"
)

return self._load_online_value(column, index)
Expand Down

0 comments on commit 9f438c0

Please sign in to comment.