Skip to content

Commit

Permalink
Mod: Rename idx to index in the whole package to match the argument name of…
Browse files Browse the repository at this point in the history
… the PyTorch Dataset class.
  • Loading branch information
Labbeti committed Feb 6, 2024
1 parent 4de2bb6 commit c276294
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 162 deletions.
9 changes: 3 additions & 6 deletions src/aac_datasets/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,18 @@
import logging
import os.path as osp
import random

from argparse import ArgumentParser, Namespace
from typing import Dict, Iterable, Union

import yaml

import aac_datasets

from aac_datasets.datasets.audiocaps import AudioCaps, AudioCapsCard
from aac_datasets.datasets.clotho import Clotho, ClothoCard
from aac_datasets.datasets.macs import MACS, MACSCard
from aac_datasets.datasets.wavcaps import WavCaps, WavCapsCard
from aac_datasets.utils.globals import get_default_root
from aac_datasets.download import _setup_logging

from aac_datasets.utils.globals import get_default_root

DATASETS_NAMES = (AudioCapsCard.NAME, ClothoCard.NAME, MACSCard.NAME, WavCapsCard.NAME)

Expand Down Expand Up @@ -66,8 +63,8 @@ def check_directory(
ds = ds_class(root, subset, verbose=0)
if len(ds) > 0:
# Try to load a random item
idx = random.randint(0, len(ds) - 1)
_item = ds[idx]
index = random.randint(0, len(ds) - 1)
ds[index]
found_dsets[subset] = ds

except RuntimeError:
Expand Down
31 changes: 10 additions & 21 deletions src/aac_datasets/datasets/audiocaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,13 @@

import logging
import os.path as osp

from pathlib import Path
from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
Optional,
Union,
)
from typing import Any, Callable, ClassVar, Dict, List, Optional, Union

import torch
import torchaudio

from torch import Tensor
from typing_extensions import TypedDict, NotRequired
from typing_extensions import NotRequired, TypedDict

try:
# To support torchaudio >= 2.1.0
Expand All @@ -30,12 +20,11 @@
from aac_datasets.datasets.base import AACDataset
from aac_datasets.datasets.functional.audiocaps import (
AudioCapsCard,
_get_audio_subset_dpath,
download_audiocaps_dataset,
load_audiocaps_dataset,
_get_audio_subset_dpath,
)
from aac_datasets.utils.globals import _get_root, _get_ffmpeg_path, _get_ytdlp_path

from aac_datasets.utils.globals import _get_ffmpeg_path, _get_root, _get_ytdlp_path

pylog = logging.getLogger(__name__)

Expand Down Expand Up @@ -284,10 +273,10 @@ def __repr__(self) -> str:
return f"{AudioCapsCard.PRETTY_NAME}({repr_str})"

# Private methods
def _load_audio(self, idx: int) -> Tensor:
if not self._raw_data["is_on_disk"][idx]:
def _load_audio(self, index: int) -> Tensor:
if not self._raw_data["is_on_disk"][index]:
return torch.empty((0,))
fpath = self.at(idx, "fpath")
fpath = self.at(index, "fpath")
audio, sr = torchaudio.load(fpath) # type: ignore

# Sanity check
Expand All @@ -302,9 +291,9 @@ def _load_audio(self, idx: int) -> Tensor:
)
return audio

def _load_audio_metadata(self, idx: int) -> AudioMetaData:
if not self._raw_data["is_on_disk"][idx]:
def _load_audio_metadata(self, index: int) -> AudioMetaData:
if not self._raw_data["is_on_disk"][index]:
return AudioMetaData(-1, -1, -1, -1, "unknown_encoding")
fpath = self.at(idx, "fpath")
fpath = self.at(index, "fpath")
audio_metadata = torchaudio.info(fpath) # type: ignore
return audio_metadata
125 changes: 63 additions & 62 deletions src/aac_datasets/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import logging
import os.path as osp

from typing import (
Any,
Callable,
Expand All @@ -20,7 +19,6 @@

import torchaudio
import tqdm

from typing_extensions import TypedDict

try:
Expand All @@ -32,7 +30,6 @@
from torch import Tensor
from torch.utils.data.dataset import Dataset


pylog = logging.getLogger(__name__)


Expand Down Expand Up @@ -154,26 +151,28 @@ def transform(self, transform: Optional[Callable]) -> None:

# Public methods
@overload
def at(self, idx: int) -> ItemType:
def at(self, index: int) -> ItemType:
...

@overload
def at(self, idx: Union[Iterable[int], slice, None], column: str) -> List:
def at(self, index: Union[Iterable[int], slice, None], column: str) -> List:
...

@overload
def at(
self, idx: Union[Iterable[int], slice, None], column: Union[Iterable[str], None]
self,
index: Union[Iterable[int], slice, None],
column: Union[Iterable[str], None],
) -> Dict[str, List]:
...

@overload
def at(self, idx: Any, column: Any) -> Any:
def at(self, index: Any, column: Any) -> Any:
...

def at(
self,
idx: Union[int, Iterable[int], None, slice] = None,
index: Union[int, Iterable[int], None, slice] = None,
column: Union[str, Iterable[str], None] = None,
) -> Any:
"""Get a specific data field.
Expand All @@ -182,63 +181,63 @@ def at(
:param column: The name(s) of the column. Can be any value of :meth:`~Clotho.columns`.
:returns: The field value. The type depends of the column.
"""
if idx is None:
idx = slice(None)
elif isinstance(idx, Tensor):
if idx.ndim not in (0, 1):
if index is None:
index = slice(None)
elif isinstance(index, Tensor):
if index.ndim not in (0, 1):
raise ValueError(
f"Invalid number of dimensions for idx argument. (found idx.ndim={idx.ndim} but expected 0 or 1)"
f"Invalid number of dimensions for index argument. (found index.ndim={index.ndim} but expected 0 or 1)"
)
elif idx.is_floating_point():
elif index.is_floating_point():
raise TypeError(
"Invalid tensor dtype. (found floating-point tensor but expected integer tensor)"
)
idx = idx.tolist()
index = index.tolist()

if column is None:
column = self.column_names

if not isinstance(column, str) and isinstance(column, Iterable):
return {column_i: self.at(idx, column_i) for column_i in column}
return {column_i: self.at(index, column_i) for column_i in column}

if isinstance(idx, (int, slice)) and (
if isinstance(index, (int, slice)) and (
column in self._raw_data.keys() and column not in self._online_fns
):
return self._raw_data[column][idx] # type: ignore
return self._raw_data[column][index] # type: ignore

if isinstance(idx, slice):
idx = range(len(self))[idx]
if isinstance(index, slice):
index = range(len(self))[index]

if isinstance(idx, Iterable):
idx = list(idx)
if all(isinstance(idx_i, bool) for idx_i in idx):
if len(idx) != len(self):
if isinstance(index, Iterable):
index = list(index)
if all(isinstance(idx_i, bool) for idx_i in index):
if len(index) != len(self):
raise IndexError(
f"The length of the mask ({len(idx)}) does not match the length of the dataset ({len(self)})."
f"The length of the mask ({len(index)}) does not match the length of the dataset ({len(self)})."
)
idx = [i for i, idx_i in enumerate(idx) if idx_i]
index = [i for i, idx_i in enumerate(index) if idx_i]

elif not all(isinstance(idx_i, int) for idx_i in idx):
elif not all(isinstance(idx_i, int) for idx_i in index):
raise TypeError(
f"Invalid input type for idx={idx}. (expected Iterable[int], not Iterable[{idx[0].__class__.__name__}])"
f"Invalid input type for index={index}. (expected Iterable[int], not Iterable[{index[0].__class__.__name__}])"
)

values = [
self.at(idx_i, column)
for idx_i in tqdm.tqdm(
idx,
index,
desc=f"Loading column '{column}'...",
disable=self._verbose < 2,
)
]
return values

if isinstance(idx, int):
return self._load_online_value(column, idx)
if isinstance(index, int):
return self._load_online_value(column, index)
else:
IDX_TYPES = ("int", "Iterable[int]", "None", "slice", "Tensor")
raise TypeError(
f"Invalid argument type {type(idx)}. (expected one of {IDX_TYPES})"
f"Invalid argument type {type(index)}. (expected one of {IDX_TYPES})"
)

def has_raw_column(self, column: str) -> bool:
Expand Down Expand Up @@ -342,40 +341,41 @@ def preload_online_column(

# Magic methods
@overload
def __getitem__(self, idx: int) -> ItemType:
def __getitem__(self, index: int) -> ItemType:
...

@overload
def __getitem__(self, idx: Tuple[Union[Iterable[int], slice, None], str]) -> List:
def __getitem__(self, index: Tuple[Union[Iterable[int], slice, None], str]) -> List:
...

@overload
def __getitem__(self, idx: Union[Iterable[int], slice, None]) -> Dict[str, List]:
def __getitem__(self, index: Union[Iterable[int], slice, None]) -> Dict[str, List]:
...

@overload
def __getitem__(
self, idx: Tuple[Union[Iterable[int], slice, None], Union[Iterable[str], None]]
self,
index: Tuple[Union[Iterable[int], slice, None], Union[Iterable[str], None]],
) -> Dict[str, List]:
...

@overload
def __getitem__(self, idx: Any) -> Any:
def __getitem__(self, index: Any) -> Any:
...

def __getitem__(self, idx: Any) -> Any:
def __getitem__(self, index: Any) -> Any:
if (
isinstance(idx, tuple)
and len(idx) == 2
and (isinstance(idx[1], (str, Iterable)) or idx[1] is None)
isinstance(index, tuple)
and len(index) == 2
and (isinstance(index[1], (str, Iterable)) or index[1] is None)
):
idx, column = idx
index, column = index
else:
column = None

item = self.at(idx, column)
item = self.at(index, column)
if (
isinstance(idx, int)
isinstance(index, int)
and (column is None or column == self._columns)
and self._transform is not None
):
Expand Down Expand Up @@ -421,17 +421,17 @@ def _unflat_raw_data(self) -> None:
raw_data = _unflat_raw_data(self._raw_data, self._sizes)
self._raw_data = raw_data

def _load_online_value(self, column: str, idx: int) -> Any:
def _load_online_value(self, column: str, index: int) -> Any:
if column in self._online_fns:
fn = self._online_fns[column]
return fn(self, idx)
return fn(self, index)
else:
raise ValueError(
f"Invalid argument column={column} at idx={idx}. (expected one of {self.all_columns})"
f"Invalid argument column={column} at index={index}. (expected one of {self.all_columns})"
)

def _load_audio(self, idx: int) -> Tensor:
fpath = self.at(idx, "fpath")
def _load_audio(self, index: int) -> Tensor:
fpath = self.at(index, "fpath")
audio_and_sr: Tuple[Tensor, int] = torchaudio.load(fpath) # type: ignore
audio, sr = audio_and_sr

Expand All @@ -447,33 +447,33 @@ def _load_audio(self, idx: int) -> Tensor:
)
return audio

def _load_audio_metadata(self, idx: int) -> AudioMetaData:
fpath = self.at(idx, "fpath")
def _load_audio_metadata(self, index: int) -> AudioMetaData:
fpath = self.at(index, "fpath")
audio_metadata = torchaudio.info(fpath) # type: ignore
return audio_metadata

def _load_duration(self, idx: int) -> float:
audio_metadata: AudioMetaData = self.at(idx, "audio_metadata")
def _load_duration(self, index: int) -> float:
audio_metadata: AudioMetaData = self.at(index, "audio_metadata")
duration = audio_metadata.num_frames / audio_metadata.sample_rate
return duration

def _load_fname(self, idx: int) -> str:
fpath = self.at(idx, "fpath")
def _load_fname(self, index: int) -> str:
fpath = self.at(index, "fpath")
fname = osp.basename(fpath)
return fname

def _load_num_channels(self, idx: int) -> int:
audio_metadata = self.at(idx, "audio_metadata")
def _load_num_channels(self, index: int) -> int:
audio_metadata = self.at(index, "audio_metadata")
num_channels = audio_metadata.num_channels
return num_channels

def _load_num_frames(self, idx: int) -> int:
audio_metadata = self.at(idx, "audio_metadata")
def _load_num_frames(self, index: int) -> int:
audio_metadata = self.at(index, "audio_metadata")
num_frames = audio_metadata.num_frames
return num_frames

def _load_sr(self, idx: int) -> int:
audio_metadata = self.at(idx, "audio_metadata")
def _load_sr(self, index: int) -> int:
audio_metadata = self.at(index, "audio_metadata")
sr = audio_metadata.sample_rate
return sr

Expand Down Expand Up @@ -525,7 +525,8 @@ def _unflat_raw_data(
for key in raw_data.keys():
if key == caps_column:
caps = [
raw_data_flat[key][idx][0] for idx in range(cumsize, cumsize + size)
raw_data_flat[key][index][0]
for index in range(cumsize, cumsize + size)
]
raw_data[key].append(caps)
else:
Expand Down
Loading

0 comments on commit c276294

Please sign in to comment.