From 2ae1f47cfd2201d0d53c6d30e2e3e35cc4f6f2a6 Mon Sep 17 00:00:00 2001 From: Emanuele Ballarin Date: Fri, 16 Aug 2024 23:46:15 +0200 Subject: [PATCH] Add standard augmentations for CIFAR datasets Signed-off-by: Emanuele Ballarin --- ebtorch/data/datasets.py | 24 ++++++++++++++++++++---- setup.py | 2 +- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/ebtorch/data/datasets.py b/ebtorch/data/datasets.py index e8fba5c..2e814bb 100644 --- a/ebtorch/data/datasets.py +++ b/ebtorch/data/datasets.py @@ -56,7 +56,7 @@ def _determine_train_test_args_common(dataset_name: str, is_train: bool) -> dict return {"train": True} if is_train else {"train": False} -def _dataloader_dispatcher( +def _dataloader_dispatcher( # NOSONAR dataset: str, data_root: str = data_root_literal, batch_size_train: Optional[int] = None, @@ -66,6 +66,7 @@ def _dataloader_dispatcher( shuffle_test: bool = False, dataset_kwargs: Optional[dict] = None, dataloader_kwargs: Optional[dict] = None, + augment_train: bool = False, ) -> Tuple[DataLoader, DataLoader, DataLoader]: if dataset == "mnist": dataset_fx = MNIST @@ -112,7 +113,18 @@ def _dataloader_dispatcher( os.makedirs(name=data_root, exist_ok=True) - transforms = Compose([ToTensor()]) + eval_transforms = Compose([ToTensor()]) + train_transforms = ( + Compose( + [ + RandomCrop(32, padding=4), + RandomHorizontalFlip(), + ToTensor(), + ] + ) + if (augment_train and ("cifar" in dataset)) + else Compose([ToTensor()]) + ) # Address dictionary mutability as default argument dataset_kwargs: dict = {} if dataset_kwargs is None else dataset_kwargs @@ -121,14 +133,14 @@ def _dataloader_dispatcher( trainset = dataset_fx( root=data_root, **_determine_train_test_args_common(dataset, is_train=True), - transform=transforms, + transform=train_transforms, download=True, **dataset_kwargs, ) testset = dataset_fx( root=data_root, **_determine_train_test_args_common(dataset, is_train=False), - transform=transforms, + transform=eval_transforms, download=True, **dataset_kwargs, ) @@ -238,6 +250,7 @@ def cifarten_dataloader_dispatcher( shuffle_test: bool = False, dataset_kwargs: Optional[dict] = None, dataloader_kwargs: Optional[dict] = None, + augment_train: bool = False, ) -> Tuple[DataLoader, DataLoader, DataLoader]: return _dataloader_dispatcher( dataset="cifar10", @@ -249,6 +262,7 @@ def cifarten_dataloader_dispatcher( shuffle_test=shuffle_test, dataset_kwargs=dataset_kwargs, dataloader_kwargs=dataloader_kwargs, + augment_train=augment_train, ) @@ -261,6 +275,7 @@ def cifarhundred_dataloader_dispatcher( shuffle_test: bool = False, dataset_kwargs: Optional[dict] = None, dataloader_kwargs: Optional[dict] = None, + augment_train: bool = False, ) -> Tuple[DataLoader, DataLoader, DataLoader]: return _dataloader_dispatcher( dataset="cifar100", @@ -272,6 +287,7 @@ def cifarhundred_dataloader_dispatcher( shuffle_test=shuffle_test, dataset_kwargs=dataset_kwargs, dataloader_kwargs=dataloader_kwargs, + augment_train=augment_train, ) diff --git a/setup.py b/setup.py index ab9ed93..053bcb0 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ def read(fname): setup( name=PACKAGENAME, - version="0.26.3", + version="0.26.4", author="Emanuele Ballarin", author_email="emanuele@ballarin.cc", url="https://github.com/emaballarin/ebtorch",