Skip to content

Commit

Permalink
Mod: Update getitem typing for datasets classes.
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed May 10, 2023
1 parent 3dd4eeb commit 5426a6b
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 18 deletions.
21 changes: 15 additions & 6 deletions src/aac_datasets/datasets/audiocaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from functools import lru_cache
from subprocess import CalledProcessError
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, overload

import torch
import torchaudio
Expand Down Expand Up @@ -49,7 +49,7 @@ class AudioCapsItem(TypedDict):
)


class AudioCaps(Dataset[AudioCapsItem]):
class AudioCaps(Dataset):
r"""Unofficial AudioCaps PyTorch dataset.
Subsets available are 'train', 'val' and 'test'.
Expand Down Expand Up @@ -309,10 +309,19 @@ def set_transform(
self._transform = transform

# Magic methods
def __getitem__(
self,
idx: Any,
) -> AudioCapsItem:
@overload
def __getitem__(self, idx: int) -> AudioCapsItem:
...

@overload
def __getitem__(self, idx: Union[Iterable[int], slice, None]) -> dict[str, list]:
...

@overload
def __getitem__(self, idx: Any) -> Any:
...

def __getitem__(self, idx: Any) -> Any:
if (
isinstance(idx, tuple)
and len(idx) == 2
Expand Down
21 changes: 15 additions & 6 deletions src/aac_datasets/datasets/clotho.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import os.path as osp

from functools import lru_cache
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, overload
from zipfile import ZipFile

import torchaudio
Expand Down Expand Up @@ -254,7 +254,7 @@ class ClothoItem(TypedDict):
)


class Clotho(Dataset[ClothoItem]):
class Clotho(Dataset):
r"""Unofficial Clotho PyTorch dataset.
Subsets available are 'train', 'val', 'eval', 'test' and 'analysis'.
Expand Down Expand Up @@ -509,10 +509,19 @@ def set_transform(
self._transform = transform

# Magic methods
def __getitem__(
self,
idx: Any,
) -> ClothoItem:
@overload
def __getitem__(self, idx: int) -> ClothoItem:
...

@overload
def __getitem__(self, idx: Union[Iterable[int], slice, None]) -> dict[str, list]:
...

@overload
def __getitem__(self, idx: Any) -> Any:
...

def __getitem__(self, idx: Any) -> Any:
if (
isinstance(idx, tuple)
and len(idx) == 2
Expand Down
21 changes: 15 additions & 6 deletions src/aac_datasets/datasets/macs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import zipfile

from functools import lru_cache
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, overload

import torchaudio
import yaml
Expand Down Expand Up @@ -48,7 +48,7 @@ class MACSItem(TypedDict):
MACS_ALL_COLUMNS = tuple(MACSItem.__required_keys__ | MACSItem.__optional_keys__)


class MACS(Dataset[MACSItem]):
class MACS(Dataset):
r"""Unofficial MACS PyTorch dataset.
.. code-block:: text
Expand Down Expand Up @@ -277,10 +277,19 @@ def set_transform(
self._transform = transform

# Magic methods
def __getitem__(
self,
idx: Any,
) -> MACSItem:
@overload
def __getitem__(self, idx: int) -> MACSItem:
...

@overload
def __getitem__(self, idx: Union[Iterable[int], slice, None]) -> dict[str, list]:
...

@overload
def __getitem__(self, idx: Any) -> Any:
...

def __getitem__(self, idx: Any) -> Any:
if (
isinstance(idx, tuple)
and len(idx) == 2
Expand Down

0 comments on commit 5426a6b

Please sign in to comment.