
Commit

Mod: Supports multiple sr in dataset class base and add setter for verbose level.
Labbeti committed Apr 26, 2024
1 parent e119950 commit 30c5563
Showing 2 changed files with 38 additions and 30 deletions.
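
The change below lets the base dataset class accept either a single sample rate or an iterable of sample rates for sr, and adds a setter for verbose. A minimal usage sketch, assuming the constructor shown in the diff belongs to AACDataset (importable from aac_datasets.datasets.base) and that it can be built directly from a raw_data mapping; the column names and values here are hypothetical, and a real dataset would normally be a subclass such as Clotho rather than a hand-built AACDataset:

    from typing import Any, Dict, List

    from aac_datasets.datasets.base import AACDataset

    # Hypothetical raw data; concrete datasets (e.g. Clotho) build this themselves.
    raw_data: Dict[str, List[Any]] = {
        "fname": ["a.wav", "b.wav"],
        "captions": [["a dog barks"], ["rain falls on a roof"]],
    }

    dataset = AACDataset(
        raw_data=raw_data,
        sr=[44100, 48000],  # several accepted sample rates (previously a single int or None)
        verbose=0,
    )

    # New verbose setter: the level can be changed after construction.
    dataset.verbose = 2

    print(dataset.sr)  # -> [44100, 48000] (iterables are converted to a list)
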
67 changes: 38 additions & 29 deletions src/aac_datasets/datasets/base.py
@@ -44,6 +44,8 @@
IndexType = Union[int, Iterable[int], Iterable[bool], Tensor, slice, None]
ColumnType = Union[str, Iterable[str], None]

_IDX_TYPES = ("int", "Iterable[int]", "Iterable[bool]", "Tensor", "slice", "None")


def _is_index(index: Any) -> TypeGuard[IndexType]:
return (
@@ -70,14 +72,16 @@ def __init__(
transform: Optional[Callable[[ItemType], Any]] = None,
column_names: Optional[Iterable[str]] = None,
flat_captions: bool = False,
sr: Optional[int] = None,
sr: Union[int, Iterable[int], None] = None,
verbose: int = 0,
) -> None:
if raw_data is None:
raw_data = {}
if column_names is None:
column_names = raw_data.keys()
column_names = list(column_names)
if isinstance(sr, Iterable):
sr = list(sr)

if len(raw_data) > 1:
size = len(next(iter(raw_data.values())))
@@ -149,7 +153,7 @@ def shape(self) -> Tuple[int, int]:
return len(self), len(self.column_names)

@property
def sr(self) -> Optional[int]:
def sr(self) -> Union[int, List[int], None]:
return self._sr

@property
@@ -173,6 +177,10 @@ def column_names(
def transform(self, transform: Optional[Callable[[ItemType], Any]]) -> None:
self._transform = transform

@verbose.setter
def verbose(self, verbose: int) -> None:
self._verbose = verbose

# Public methods
@overload
def at(self, index: int) -> ItemType:
@@ -184,17 +192,11 @@ def at(
) -> List:
...

@overload
def at(
self, index: Union[Iterable[int], Iterable[bool], slice, None]
) -> Dict[str, List]:
...

@overload
def at(
self,
index: Union[Iterable[int], Iterable[bool], slice, None],
column: Union[Iterable[str], None],
column: Union[Iterable[str], None] = None,
) -> Dict[str, List]:
...

@@ -215,15 +217,21 @@ def at(
"""
if index is None:
index = slice(None)

elif isinstance(index, Tensor):
if index.ndim not in (0, 1):
raise ValueError(
f"Invalid number of dimensions for index argument. (found index.ndim={index.ndim} but expected 0 or 1)"
)
elif index.is_floating_point():
raise TypeError(
"Invalid tensor dtype. (found floating-point tensor but expected integer tensor)"
)
if __debug__:
if index.ndim not in (0, 1):
raise ValueError(
f"Invalid number of dimensions for index argument. (found index.ndim={index.ndim} but expected 0 or 1)"
)
elif index.is_floating_point():
raise TypeError(
"Invalid tensor dtype. (found floating-point tensor but expected integer or bool tensor)"
)
elif index.is_complex():
raise TypeError(
"Invalid tensor dtype. (found complex tensor but expected integer or bool tensor)"
)
index = index.tolist()

if column is None:
@@ -248,8 +256,7 @@ def at(
f"The length of the mask ({len(index)}) does not match the length of the dataset ({len(self)})."
)
index = [i for i, idx_i in enumerate(index) if idx_i]

elif not is_iterable_int(index):
elif __debug__ and not is_iterable_int(index):
raise TypeError(
f"Invalid input type for index={index}. (expected Iterable[int], not Iterable[{index[0].__class__.__name__}])"
)
@@ -264,14 +271,13 @@ def at(
]
return values

if isinstance(index, int):
return self._load_online_value(column, index)
else:
IDX_TYPES = ("int", "Iterable[int]", "None", "slice", "Tensor")
if __debug__ and not isinstance(index, int):
raise TypeError(
f"Invalid argument type {type(index)}. (expected one of {IDX_TYPES})"
f"Invalid argument type {type(index)}. (expected one of {_IDX_TYPES})"
)

return self._load_online_value(column, index)

def has_raw_column(self, column: str) -> bool:
"""Returns True if column name exists in raw data."""
return column in self._raw_data
@@ -315,18 +321,18 @@ def rename_column(

def add_raw_column(
self,
column: str,
column_name: str,
column_data: List[Any],
allow_replace: bool = False,
) -> None:
"""Add a new raw column to this dataset."""
if not allow_replace and column in self._raw_data:
if not allow_replace and column_name in self._raw_data:
raise ValueError(
f"Column '{column}' already exists. Please choose another name or set allow_replace arg to True."
f"Column '{column_name}' already exists. Please choose another name or set allow_replace arg to True."
)
if len(self._raw_data) > 0 and len(column_data) != len(self):
raise ValueError(f"Invalid number of rows in column '{column}'.")
self._raw_data[column] = column_data
raise ValueError(f"Invalid number of rows in column '{column_name}'.")
self._raw_data[column_name] = column_data

def add_online_column(
self,
@@ -488,6 +494,9 @@ def _load_audio(self, index: int) -> Tensor:
audio_and_sr: Tuple[Tensor, int] = torchaudio.load(fpath) # type: ignore
audio, sr = audio_and_sr

if not __debug__:
return audio

# Sanity check
if audio.nelement() == 0:
raise RuntimeError(
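
Note that the new sanity checks in base.py are wrapped in "if __debug__:" blocks (and _load_audio returns the audio early when __debug__ is false), so they are skipped entirely when Python is started with the -O flag. A standalone illustration of that pattern, not taken from the repository:

    import torch
    from torch import Tensor


    def check_index(index: Tensor) -> Tensor:
        # Same guard style as this commit: validation only runs when __debug__
        # is True, i.e. when Python was NOT started with the -O flag.
        if __debug__:
            if index.ndim not in (0, 1):
                raise ValueError(f"Invalid number of dimensions. (found {index.ndim})")
            if index.is_floating_point() or index.is_complex():
                raise TypeError("Expected an integer or bool tensor.")
        return index


    check_index(torch.tensor([0, 2, 3]))   # passes
    # check_index(torch.tensor([0.5]))     # raises TypeError on a normal run,
    #                                      # but the check is skipped under python -O
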
1 change: 0 additions & 1 deletion src/aac_datasets/datasets/clotho.py
@@ -83,7 +83,6 @@ class Clotho(AACDataset[ClothoItem]):
├── clotho_metadata_validation.csv
├── retrieval_audio_metadata.csv
└── retrieval_captions.csv
"""

# Common globals
