Skip to content

Commit

Permalink
Add/Fix: Unittests for flat raw data func, fix flat raw data.
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Aug 5, 2023
1 parent e089277 commit f2140d6
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 23 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/python-package-pip.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,7 @@ jobs:
- name: Print install info
run: |
aac-datasets-info
- name: Test with pytest
run: |
python -m pytest -v
6 changes: 6 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[pytest]
testpaths = tests/
filterwarnings =
error
ignore::FutureWarning
ignore::DeprecationWarning
50 changes: 43 additions & 7 deletions src/aac_datasets/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(
self._verbose = verbose

self._post_columns_fns = {}
self._flat_indices = []

if self._flat_captions:
self._flat_raw_data()
Expand All @@ -91,7 +92,7 @@ def column_names(self) -> List[str]:

@property
def flat_captions(self) -> bool:
"""The name of each column of the dataset."""
"""Returns true if captions has been flattened."""
return self._flat_captions

@property
Expand Down Expand Up @@ -381,7 +382,8 @@ def _can_be_loaded(self, column: str) -> bool:
return self.has_raw_column(column) or self.has_post_column(column)

def _flat_raw_data(self) -> None:
self._raw_data = _flat_raw_data(self._raw_data)
raw_data, _ = _flat_raw_data(self._raw_data)
self._raw_data = raw_data

def _load_auto_value(self, column: str, idx: int) -> Any:
if column in self._post_columns_fns:
Expand Down Expand Up @@ -442,12 +444,14 @@ def _load_sr(self, idx: int) -> int:
def _flat_raw_data(
raw_data: Dict[str, List[Any]],
caps_column: str = "captions",
) -> Dict[str, List[Any]]:
) -> Tuple[Dict[str, List[Any]], List[int]]:
if caps_column not in raw_data:
raise ValueError(f"Cannot flat raw data without '{caps_column}' column.")
raise ValueError(
f"Cannot flat raw data without '{caps_column}' column. (found only columns {tuple(raw_data.keys())})"
)

mcaps: List[List[str]] = raw_data[caps_column]
raw_data_flat = {key: [] for key in raw_data.keys()}
mcaps = raw_data[caps_column]

for i, caps in enumerate(mcaps):
if len(caps) == 0:
Expand All @@ -456,7 +460,39 @@ def _flat_raw_data(
else:
for cap in caps:
for key in raw_data.keys():
if key == caps_column:
continue
raw_data_flat[key].append(raw_data[key][i])
raw_data_flat[caps_column] = [cap]

return raw_data_flat
# Overwrite cap
raw_data_flat[caps_column].append([cap])

sizes = [len(caps) for caps in mcaps]
return raw_data_flat, sizes


def _unflat_raw_data(
raw_data_flat: Dict[str, List[Any]],
sizes: List[int],
caps_column: str = "captions",
) -> Dict[str, List[Any]]:
if caps_column not in raw_data_flat:
raise ValueError(
f"Cannot flat raw data without '{caps_column}' column. (found only columns {tuple(raw_data.keys())})"
)

raw_data = {key: [] for key in raw_data_flat.keys()}

cumsize = 0
for size in sizes:
for key in raw_data.keys():
if key == caps_column:
caps = [
raw_data_flat[key][idx][0] for idx in range(cumsize, cumsize + size)
]
raw_data[key].append(caps)
else:
raw_data[key].append(raw_data_flat[key][cumsize])
cumsize += size

return raw_data
31 changes: 16 additions & 15 deletions src/aac_datasets/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
pylog = logging.getLogger(__name__)


_TRUE_VALUES = ("true", "1", "t")
_FALSE_VALUES = ("false", "0", "f")


def download_audiocaps(
root: str = ".",
verbose: int = 1,
Expand Down Expand Up @@ -115,14 +119,16 @@ def download_wavcaps(
return datasets


def _to_bool(s: str) -> bool:
s = s.lower()
if s in ("true",):
def _str_to_bool(s: str) -> bool:
s = str(s).strip().lower()
if s in _TRUE_VALUES:
return True
elif s in ("false",):
elif s in _FALSE_VALUES:
return False
else:
raise ValueError(f"Invalid argument value {s}. (not a boolean)")
raise ValueError(
f"Invalid argument {s=}. (expected one of {_TRUE_VALUES + _FALSE_VALUES})"
)


def _get_main_download_args() -> Namespace:
Expand All @@ -144,9 +150,8 @@ def _get_main_download_args() -> Namespace:
)
parser.add_argument(
"--force",
type=_to_bool,
type=_str_to_bool,
default=False,
choices=(False, True),
help="Force download of files, even if they are already downloaded.",
)

Expand All @@ -171,9 +176,8 @@ def _get_main_download_args() -> Namespace:
)
audiocaps_subparser.add_argument(
"--with_tags",
type=_to_bool,
type=_str_to_bool,
default=True,
choices=(False, True),
help="Download additional audioset tags corresponding to audiocaps audio.",
)
audiocaps_subparser.add_argument(
Expand All @@ -195,9 +199,8 @@ def _get_main_download_args() -> Namespace:
)
clotho_subparser.add_argument(
"--clean_archives",
type=_to_bool,
type=_str_to_bool,
default=False,
choices=(False, True),
help="Remove archives files after extraction.",
)
clotho_subparser.add_argument(
Expand All @@ -212,19 +215,17 @@ def _get_main_download_args() -> Namespace:
macs_subparser = subparsers.add_parser(MACSCard.NAME)
macs_subparser.add_argument(
"--clean_archives",
type=_to_bool,
type=_str_to_bool,
default=False,
choices=(False, True),
help="Remove archives files after extraction.",
)
# Note : MACS only have 1 subset, so we do not add MACS subsets arg

wavcaps_subparser = subparsers.add_parser(WavCapsCard.NAME)
wavcaps_subparser.add_argument(
"--clean_archives",
type=_to_bool,
type=_str_to_bool,
default=False,
choices=(False, True),
help="Remove archives files after extraction.",
)
wavcaps_subparser.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion src/aac_datasets/utils/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


HASH_TYPES = ("sha256", "md5")
DEFAULT_CHUNK_SIZE = 256 * 1024**2
DEFAULT_CHUNK_SIZE = 256 * 1024**2 # 256 MiB


def safe_rmdir(
Expand Down
32 changes: 32 additions & 0 deletions tests/test_datasets_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-

import unittest

from unittest import TestCase

from aac_datasets.datasets.base import _flat_raw_data, _unflat_raw_data


class TestDatasetBase(TestCase):
def test_flat_raw_data(self) -> None:
raw_data = {
"captions": [["a1", "a2", "a3"], ["b1"], ["c1", "c2"], []],
"idx": list(range(1, 5)),
}
expected_flat = {
"captions": [["a1"], ["a2"], ["a3"], ["b1"], ["c1"], ["c2"], []],
"idx": [1, 1, 1, 2, 3, 3, 4],
}
expected_sizes = [3, 1, 2, 0]

raw_data_flat_out, sizes_out = _flat_raw_data(raw_data, "captions")
raw_data_out = _unflat_raw_data(raw_data_flat_out, sizes_out)

self.assertDictEqual(raw_data_flat_out, expected_flat)
self.assertListEqual(sizes_out, expected_sizes)
self.assertDictEqual(raw_data_out, raw_data)


if __name__ == "__main__":
unittest.main()

0 comments on commit f2140d6

Please sign in to comment.