Skip to content

Commit

Permalink
Mod: Remove old code and update docstrings.
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Aug 10, 2023
1 parent 845303e commit a42eba0
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 17 deletions.
29 changes: 24 additions & 5 deletions src/aac_datasets/datasets/audiocaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ def _prepare_audiocaps_dataset(
start_time,
duration=audio_duration,
sr=sr,
youtube_dl_path=ytdl_path,
ytdl_path=ytdl_path,
ffmpeg_path=ffmpeg_path,
n_channels=n_channels,
)
Expand Down Expand Up @@ -694,15 +694,15 @@ def _download_and_extract_from_youtube(
n_channels: int = 1,
target_format: str = "flac",
acodec: str = "flac",
youtube_dl_path: str = "youtube-dl",
ytdl_path: str = "youtube-dl",
ffmpeg_path: str = "ffmpeg",
) -> bool:
"""Download audio from youtube with youtube-dl and ffmpeg."""

# Get audio download link with youtube-dl
link = f"https://www.youtube.com/watch?v={youtube_id}"
link = _get_youtube_link(youtube_id)
get_url_command = [
youtube_dl_path,
ytdl_path,
"--youtube-skip-dash-manifest",
"-g",
link,
Expand Down Expand Up @@ -737,7 +737,7 @@ def _download_and_extract_from_youtube(
str(start_time),
"-t",
str(duration),
# Resample to 16 kHz
# Resample to a specific rate (default to 32 kHz)
"-ar",
str(sr),
# Compute mean of 2 channels
Expand Down Expand Up @@ -777,6 +777,25 @@ def _check_file(fpath: str, expected_sr: Optional[int]) -> bool:
return True


def _get_youtube_link(youtube_id: str, start_time: Optional[int]) -> str:
link = f"https://www.youtube.com/watch?v={youtube_id}"
if start_time is None:
return link
else:
return f"{link}&t={start_time}s"


def _get_youtube_link_embed(
youtube_id: str, start_time: Optional[int], duration: float = 10.0
) -> str:
link = f"https://www.youtube.com/embed/{youtube_id}"
if start_time is None:
return link
else:
end_time = start_time + duration
return f"{link}?start={start_time}&end={end_time}"


# Audio directory names per subset
_AUDIOCAPS_AUDIO_DNAMES = {
"train": "train",
Expand Down
35 changes: 23 additions & 12 deletions src/aac_datasets/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,17 @@ class AACDataset(Generic[ItemType], Dataset[ItemType]):
# Initialization
def __init__(
self,
raw_data: Dict[str, List[Any]],
transform: Optional[Callable],
column_names: Iterable[str],
flat_captions: bool,
sr: Optional[int],
verbose: int,
raw_data: Optional[Dict[str, List[Any]]] = None,
transform: Optional[Callable] = None,
column_names: Optional[Iterable[str]] = None,
flat_captions: bool = False,
sr: Optional[int] = None,
verbose: int = 0,
) -> None:
if raw_data is None:
raw_data = {}
if column_names is None:
column_names = raw_data.keys()
column_names = list(column_names)

if len(raw_data) > 1:
Expand All @@ -77,17 +81,18 @@ def __init__(

@staticmethod
def new_empty() -> "AACDataset":
"""Create a new empty dataset."""
return AACDataset({}, None, (), False, None, 0)

# Properties
@property
def all_columns(self) -> List[str]:
"""The name of each column of the dataset."""
"""The name of all columns of the dataset."""
return list(self._raw_data | self._post_columns_fns)

@property
def column_names(self) -> List[str]:
"""The name of each column of the dataset."""
"""The name of all selected column of the dataset."""
return self._columns

@property
Expand Down Expand Up @@ -217,15 +222,19 @@ def at(
)

def has_raw_column(self, column: str) -> bool:
"""Returns True if column name exists in raw data."""
return column in self._raw_data

def has_post_column(self, column: str) -> bool:
"""Returns True if column name exists in post processed data."""
return column in self._post_columns_fns

def has_column(self, column: str) -> bool:
"""Returns True if column name exists in data."""
return self.has_raw_column(column) or self.has_post_column(column)

def remove_column(self, column: str) -> Union[List[Any], Callable]:
"""Removes a column from this dataset."""
if column in self._raw_data:
column_data = self._raw_data.pop(column, [])
return column_data
Expand All @@ -241,6 +250,7 @@ def rename_column(
new_column: str,
allow_replace: bool = False,
) -> None:
"""Renames a column from this dataset."""
column_data_or_fn = self.remove_column(old_column)

if isinstance(column_data_or_fn, List):
Expand All @@ -258,6 +268,7 @@ def add_raw_column(
column_data: List[Any],
allow_replace: bool = False,
) -> None:
"""Add a new raw column to this dataset."""
if not allow_replace and column in self._raw_data:
raise ValueError(
f"Column '{column}' already exists. Please choose another name or set allow_replace arg to True."
Expand All @@ -272,6 +283,7 @@ def add_post_column(
load_fn: Callable[[Any, int], Any],
allow_replace: bool = False,
) -> None:
"""Add a new post-processed column to this dataset."""
if not allow_replace and column in self._post_columns_fns:
raise ValueError(
f"Column '{column}' already exists in {self} and found argument allow_replace={allow_replace}."
Expand All @@ -283,6 +295,7 @@ def add_post_columns(
post_columns_fns: Dict[str, Callable[[Any, int], Any]],
allow_replace: bool = False,
) -> None:
"""Add several new post-processed columns to this dataset."""
for name, load_fn in post_columns_fns.items():
self.add_post_column(name, load_fn, allow_replace)

Expand All @@ -291,6 +304,7 @@ def load_post_column(
column: str,
allow_replace: bool = False,
) -> Callable[[Any, int], Any]:
"""Load all data from a post-column data into raw data."""
if column not in self._post_columns_fns:
raise ValueError(f"Invalid argument column={column}.")

Expand Down Expand Up @@ -373,14 +387,11 @@ def _check_columns(self, columns: List[str]) -> None:
msg = f"Invalid argument columns={columns}. (found {len(invalid_columns)} invalids column names for {self.__class__.__name__}: {invalid_columns})"
raise ValueError(msg)

invalid_columns = [name for name in columns if not self._can_be_loaded(name)]
invalid_columns = [name for name in columns if not self.has_column(name)]
if len(invalid_columns) > 0:
msg = f"Invalid argument columns={columns}. (found {len(invalid_columns)} invalids column names for {self.__class__.__name__}: {invalid_columns})"
raise ValueError(msg)

def _can_be_loaded(self, column: str) -> bool:
return self.has_raw_column(column) or self.has_post_column(column)

def _flat_raw_data(self) -> None:
raw_data, _ = _flat_raw_data(self._raw_data)
self._raw_data = raw_data
Expand Down

0 comments on commit a42eba0

Please sign in to comment.