Skip to content

Commit

Permalink
Mod: Update transform typing with item typeddict, dataloader example in notebook and version in init.
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Feb 12, 2024
1 parent c276294 commit 1fbbd28
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 35 deletions.
8 changes: 4 additions & 4 deletions examples/dataloader.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"aac-datasets version: 0.5.0\n"
"aac-datasets version: 0.5.1\n"
]
}
],
Expand Down Expand Up @@ -108,7 +108,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -123,7 +123,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -188,7 +188,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
"version": "3.10.11"
},
"orig_nbformat": 4,
"vscode": {
Expand Down
3 changes: 1 addition & 2 deletions src/aac_datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
__maintainer__ = "Etienne Labbé (Labbeti)"
__name__ = "aac-datasets"
__status__ = "Development"
__version__ = "0.5.0"
__version__ = "0.5.1"


from .datasets.audiocaps import AudioCaps
Expand All @@ -28,7 +28,6 @@
set_default_ytdlp_path,
)


__all__ = [
"AudioCaps",
"Clotho",
Expand Down
2 changes: 1 addition & 1 deletion src/aac_datasets/datasets/audiocaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(
root: Union[str, Path, None] = None,
subset: str = AudioCapsCard.DEFAULT_SUBSET,
download: bool = False,
transform: Optional[Callable[[Dict[str, Any]], Any]] = None,
transform: Optional[Callable[[AudioCapsItem], Any]] = None,
verbose: int = 0,
force_download: bool = False,
verify_files: bool = False,
Expand Down
17 changes: 10 additions & 7 deletions src/aac_datasets/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import torchaudio
import tqdm
from typing_extensions import TypedDict

try:
# To support torchaudio >= 2.1.0
Expand All @@ -33,7 +32,7 @@
pylog = logging.getLogger(__name__)


ItemType = TypeVar("ItemType", bound=TypedDict, covariant=True)
ItemType = TypeVar("ItemType", covariant=True)


class AACDataset(Generic[ItemType], Dataset[ItemType]):
Expand All @@ -43,7 +42,7 @@ class AACDataset(Generic[ItemType], Dataset[ItemType]):
def __init__(
self,
raw_data: Optional[Dict[str, List[Any]]] = None,
transform: Optional[Callable] = None,
transform: Optional[Callable[[ItemType], Any]] = None,
column_names: Optional[Iterable[str]] = None,
flat_captions: bool = False,
sr: Optional[int] = None,
Expand Down Expand Up @@ -146,7 +145,7 @@ def column_names(
self._columns = columns

@transform.setter
def transform(self, transform: Optional[Callable]) -> None:
def transform(self, transform: Optional[Callable[[ItemType], Any]]) -> None:
self._transform = transform

# Public methods
Expand All @@ -158,6 +157,10 @@ def at(self, index: int) -> ItemType:
def at(self, index: Union[Iterable[int], slice, None], column: str) -> List:
...

@overload
def at(self, index: Union[Iterable[int], slice, None]) -> Dict[str, List]:
...

@overload
def at(
self,
Expand Down Expand Up @@ -376,10 +379,10 @@ def __getitem__(self, index: Any) -> Any:
item = self.at(index, column)
if (
isinstance(index, int)
and (column is None or column == self._columns)
and self._transform is not None
and (column is None or set(column) == set(self._columns))
):
item = self._transform(item)
item = self._transform(item) # type: ignore
return item

def __len__(self) -> int:
Expand All @@ -394,7 +397,7 @@ def __len__(self) -> int:
def __repr__(self) -> str:
info = {
"size": len(self),
"num_columns": len(self.column_names),
"num_columns": self.num_columns,
}
repr_str = ", ".join(f"{k}={v}" for k, v in info.items())
return f"{self.__class__.__name__}({repr_str})"
Expand Down
18 changes: 5 additions & 13 deletions src/aac_datasets/datasets/clotho.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,21 @@

import logging
import os.path as osp

from pathlib import Path
from typing import (
Callable,
ClassVar,
List,
Optional,
Union,
)
from typing import Any, Callable, ClassVar, List, Optional, Union

from torch import Tensor
from typing_extensions import TypedDict, NotRequired
from typing_extensions import NotRequired, TypedDict

from aac_datasets.datasets.base import AACDataset
from aac_datasets.datasets.functional.clotho import (
ClothoCard,
load_clotho_dataset,
download_clotho_dataset,
_get_audio_subset_dpath,
download_clotho_dataset,
load_clotho_dataset,
)
from aac_datasets.utils.globals import _get_root


pylog = logging.getLogger(__name__)


Expand Down Expand Up @@ -108,7 +100,7 @@ def __init__(
root: Union[str, Path, None] = None,
subset: str = ClothoCard.DEFAULT_SUBSET,
download: bool = False,
transform: Optional[Callable] = None,
transform: Optional[Callable[[ClothoItem], Any]] = None,
verbose: int = 0,
force_download: bool = False,
verify_files: bool = False,
Expand Down
4 changes: 2 additions & 2 deletions src/aac_datasets/datasets/macs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import os.path as osp
from pathlib import Path
from typing import Callable, ClassVar, Dict, List, Optional, Union
from typing import Any, Callable, ClassVar, Dict, List, Optional, Union

from torch import Tensor
from typing_extensions import TypedDict
Expand Down Expand Up @@ -71,7 +71,7 @@ def __init__(
root: Union[str, Path, None] = None,
subset: str = MACSCard.DEFAULT_SUBSET,
download: bool = False,
transform: Optional[Callable] = None,
transform: Optional[Callable[[MACSItem], Any]] = None,
verbose: int = 0,
force_download: bool = False,
verify_files: bool = False,
Expand Down
10 changes: 4 additions & 6 deletions src/aac_datasets/datasets/wavcaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,21 @@

import logging
import os.path as osp

from pathlib import Path
from typing import Callable, ClassVar, List, Optional, Union
from typing import Any, Callable, ClassVar, List, Optional, Union

from torch import Tensor
from typing_extensions import TypedDict

from aac_datasets.datasets.base import AACDataset
from aac_datasets.datasets.functional.wavcaps import (
WavCapsCard,
load_wavcaps_dataset,
download_wavcaps_dataset,
_get_audio_subset_dpath,
download_wavcaps_dataset,
load_wavcaps_dataset,
)
from aac_datasets.utils.globals import _get_root, _get_zip_path


pylog = logging.getLogger(__name__)


Expand Down Expand Up @@ -110,7 +108,7 @@ def __init__(
root: Union[str, Path, None] = None,
subset: str = WavCapsCard.DEFAULT_SUBSET,
download: bool = False,
transform: Optional[Callable] = None,
transform: Optional[Callable[[WavCapsItem], Any]] = None,
verbose: int = 0,
force_download: bool = False,
verify_files: bool = False,
Expand Down

0 comments on commit 1fbbd28

Please sign in to comment.