Skip to content

Commit

Permalink
Mod: Refactor collections functions, update ffmpeg_path and ytdl_path…
Browse files Browse the repository at this point in the history
… arguments for AC.
  • Loading branch information
Labbeti committed Aug 8, 2023
1 parent b630476 commit 72469eb
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 64 deletions.
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,17 @@ This package has been developped for Ubuntu 20.04, and it is expected to work on

### External requirements (AudioCaps only)

The external requirements needed to download **AudioCaps** are **ffmpeg** and **youtube-dl**.
The external requirements needed to download **AudioCaps** are **ffmpeg** and **youtube-dl** (yt-dlp should work too).
These two programs can be download on Ubuntu using `sudo apt install ffmpeg youtube-dl`.

You can also override their paths for AudioCaps:
```python
from aac_datasets import AudioCaps
AudioCaps.FFMPEG_PATH = "/my/path/to/ffmpeg"
AudioCaps.YOUTUBE_DL_PATH = "/my/path/to/youtube_dl"
dataset = AudioCaps(root=".", download=True)
dataset = AudioCaps(
download=True,
ffmpeg_path="/my/path/to/ffmpeg",
ytdl_path="/my/path/to/youtube_dl",
)
```

## Download datasets
Expand Down
14 changes: 8 additions & 6 deletions src/aac_datasets/datasets/audiocaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ class AudioCaps(AACDataset[AudioCapsItem]):
Audio is a waveform tensor of shape (1, n_times) of 10 seconds max, sampled at 32kHz by default.
Target is a list of strings containing the captions.
The 'train' subset has only 1 caption per sample and 'val' and 'test' have 5 captions.
Download requires 'youtube-dl' and 'ffmpeg' commands.
You can change the default path with :attr:`~AudioCaps.YOUTUBE_DL_PATH` or :attr:`~AudioCaps.FFMPEG_PATH` global variables.
AudioCaps paper : https://www.aclweb.org/anthology/N19-1011.pdf
Expand Down Expand Up @@ -125,8 +123,6 @@ class AudioCaps(AACDataset[AudioCapsItem]):
AUDIO_DURATION: ClassVar[float] = 10.0
AUDIO_FORMAT: ClassVar[str] = "flac"
AUDIO_N_CHANNELS: ClassVar[int] = 1
FFMPEG_PATH: ClassVar[str] = "ffmpeg"
YOUTUBE_DL_PATH: ClassVar[str] = "youtube-dl"

# Initialization
def __init__(
Expand All @@ -140,6 +136,8 @@ def __init__(
exclude_removed_audio: bool = True,
with_tags: bool = False,
sr: int = 32_000,
ffmpeg_path: str = "ffmpeg",
ytdl_path: str = "youtube-dl",
) -> None:
"""
:param root: Dataset root directory.
Expand All @@ -164,6 +162,10 @@ def __init__(
:param sr: The sample rate used for audio files in the dataset (in Hz).
Since original YouTube videos are recorded in various settings, this parameter allow to download allow audio files with a specific sample rate.
defaults to 32000.
:param ffmpeg_path: Path to ffmpeg executable file.
defaults to "ffmpeg".
:param ytdl_path: Path to youtube-dl or ytdlp executable.
defaults to "youtube-dl".
"""
if subset not in AudioCapsCard.SUBSETS:
raise ValueError(
Expand All @@ -178,8 +180,8 @@ def __init__(
with_tags,
verbose,
AudioCaps.FORCE_PREPARE_DATA,
AudioCaps.YOUTUBE_DL_PATH,
AudioCaps.FFMPEG_PATH,
ytdl_path,
ffmpeg_path,
AudioCaps.AUDIO_FORMAT,
AudioCaps.AUDIO_DURATION,
AudioCaps.AUDIO_N_CHANNELS,
Expand Down
2 changes: 1 addition & 1 deletion src/aac_datasets/datasets/wavcaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from typing_extensions import TypedDict

from aac_datasets.datasets.base import AACDataset, DatasetCard
from aac_datasets.utils.collate import list_dict_to_dict_list
from aac_datasets.utils.collections import list_dict_to_dict_list
from aac_datasets.utils.download import safe_rmdir


Expand Down
23 changes: 13 additions & 10 deletions src/aac_datasets/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,23 @@ def download_audiocaps(
verbose: int = 1,
force: bool = False,
download: bool = True,
ffmpeg: str = "ffmpeg",
youtube_dl: str = "youtube-dl",
ffmpeg_path: str = "ffmpeg",
ytdl_path: str = "youtube-dl",
with_tags: bool = False,
subsets: Iterable[str] = AudioCapsCard.SUBSETS,
) -> Dict[str, AudioCaps]:
"""Download :class:`~aac_datasets.datasets.audiocaps.AudioCaps` dataset subsets."""
AudioCaps.FORCE_PREPARE_DATA = force
AudioCaps.FFMPEG_PATH = ffmpeg
AudioCaps.YOUTUBE_DL_PATH = youtube_dl

datasets = {}
for subset in subsets:
datasets[subset] = AudioCaps(
root, subset, download=download, verbose=verbose, with_tags=with_tags
root,
subset,
download=download,
verbose=verbose,
with_tags=with_tags,
ffmpeg_path=ffmpeg_path,
ytdl_path=ytdl_path,
)
return datasets

Expand Down Expand Up @@ -163,13 +166,13 @@ def _get_main_download_args() -> Namespace:

audiocaps_subparser = subparsers.add_parser(AudioCapsCard.NAME)
audiocaps_subparser.add_argument(
"--ffmpeg",
"--ffmpeg_path",
type=str,
default="ffmpeg",
help="Path to ffmpeg used to download audio from youtube.",
)
audiocaps_subparser.add_argument(
"--youtube_dl",
"--ytdl_path",
type=str,
default="youtube-dl",
help="Path to youtube-dl used to extract metadata from a youtube video.",
Expand Down Expand Up @@ -272,8 +275,8 @@ def _main_download() -> None:
verbose=args.verbose,
force=args.force,
download=True,
ffmpeg=args.ffmpeg,
youtube_dl=args.youtube_dl,
ffmpeg_path=args.ffmpeg_path,
ytdl_path=args.ytdl_path,
with_tags=args.with_tags,
subsets=args.subsets,
)
Expand Down
45 changes: 2 additions & 43 deletions src/aac_datasets/utils/collate.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Any, Dict, List, TypeVar, Union
from typing import Any, Dict, List, Union

import torch

from torch import Tensor
from torch.nn import functional as F


T = TypeVar("T")
from aac.utils.collections import list_dict_to_dict_list


class BasicCollate:
Expand Down Expand Up @@ -98,43 +97,3 @@ def pad_last_dim(tensor: Tensor, target_length: int, pad_value: float) -> Tensor
"""
pad_len = max(target_length - tensor.shape[-1], 0)
return F.pad(tensor, [0, pad_len], value=pad_value)


def list_dict_to_dict_list(
lst: List[Dict[str, T]],
key_mode: str = "intersect",
) -> Dict[str, List[T]]:
"""Convert list of dicts to dict of lists.
:param lst: The list of dict to merge.
:param key_mode: Can be "same" or "intersect".
If "same", all the dictionaries must contains the same keys otherwise a ValueError will be raised.
If "intersect", only the intersection of all keys will be used in output.
:returns: The dictionary of lists.
"""
if len(lst) == 0:
return {}
keys = set(lst[0].keys())
if key_mode == "same":
if not all(keys == set(item.keys()) for item in lst[1:]):
raise ValueError("Invalid keys for batch.")
elif key_mode == "intersect":
keys = intersect_lists([list(item.keys()) for item in lst])
else:
KEY_MODES = ("same", "intersect")
raise ValueError(
f"Invalid argument key_mode={key_mode}. (expected one of {KEY_MODES})"
)

return {key: [item[key] for item in lst] for key in keys}


def intersect_lists(lst_of_lst: List[List[T]]) -> List[T]:
if len(lst_of_lst) <= 0:
return []
out = lst_of_lst[0]
for lst_i in lst_of_lst[1:]:
out = [name for name in out if name in lst_i]
if len(out) == 0:
break
return out
48 changes: 48 additions & 0 deletions src/aac_datasets/utils/collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Dict, List, TypeVar


T = TypeVar("T")


def list_dict_to_dict_list(
lst: List[Dict[str, T]],
key_mode: str = "intersect",
) -> Dict[str, List[T]]:
"""Convert list of dicts to dict of lists.
:param lst: The list of dict to merge.
:param key_mode: Can be "same" or "intersect".
If "same", all the dictionaries must contains the same keys otherwise a ValueError will be raised.
If "intersect", only the intersection of all keys will be used in output.
:returns: The dictionary of lists.
"""
if len(lst) == 0:
return {}
keys = set(lst[0].keys())
if key_mode == "same":
if not all(keys == set(item.keys()) for item in lst[1:]):
raise ValueError("Invalid keys for batch.")
elif key_mode == "intersect":
keys = intersect_lists([list(item.keys()) for item in lst])
else:
KEY_MODES = ("same", "intersect")
raise ValueError(
f"Invalid argument key_mode={key_mode}. (expected one of {KEY_MODES})"
)

return {key: [item[key] for item in lst] for key in keys}


def intersect_lists(lst_of_lst: List[List[T]]) -> List[T]:
"""Performs intersection of elements in lists (like set intersection), but keep their original order."""
if len(lst_of_lst) <= 0:
return []
out = list(dict.fromkeys(lst_of_lst[0]))
for lst_i in lst_of_lst[1:]:
out = [name for name in out if name in lst_i]
if len(out) == 0:
break
return out
21 changes: 21 additions & 0 deletions tests/test_utils_collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-

import unittest

from unittest import TestCase

from aac_datasets.utils.collections import intersect_lists


class TestCollections(TestCase):
def test_intersect_lists(self) -> None:
input_ = [["a", "b", "b", "c"], ["c", "d", "b", "a"], ["b", "a", "a", "e"]]
expected = ["a", "b"]

output = intersect_lists(input_)
self.assertListEqual(output, expected)


if __name__ == "__main__":
unittest.main()

0 comments on commit 72469eb

Please sign in to comment.