Skip to content

Commit

Permalink
Add standard augmentations for CIFAR datasets
Browse files Browse the repository at this point in the history
Signed-off-by: Emanuele Ballarin <[email protected]>
  • Loading branch information
emaballarin committed Aug 16, 2024
1 parent 581f6f7 commit 2ae1f47
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
24 changes: 20 additions & 4 deletions ebtorch/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _determine_train_test_args_common(dataset_name: str, is_train: bool) -> dict
return {"train": True} if is_train else {"train": False}


def _dataloader_dispatcher(
def _dataloader_dispatcher( # NOSONAR
dataset: str,
data_root: str = data_root_literal,
batch_size_train: Optional[int] = None,
Expand All @@ -66,6 +66,7 @@ def _dataloader_dispatcher(
shuffle_test: bool = False,
dataset_kwargs: Optional[dict] = None,
dataloader_kwargs: Optional[dict] = None,
augment_train: bool = False,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
if dataset == "mnist":
dataset_fx = MNIST
Expand Down Expand Up @@ -112,7 +113,18 @@ def _dataloader_dispatcher(

os.makedirs(name=data_root, exist_ok=True)

transforms = Compose([ToTensor()])
eval_transforms = Compose([ToTensor()])
train_transforms = (
Compose(
[
RandomCrop(32, padding=4),
RandomHorizontalFlip(),
ToTensor(),
]
)
if (augment_train and ("cifar" in dataset))
else Compose([ToTensor()])
)

# Address dictionary mutability as default argument
dataset_kwargs: dict = {} if dataset_kwargs is None else dataset_kwargs
Expand All @@ -121,14 +133,14 @@ def _dataloader_dispatcher(
trainset = dataset_fx(
root=data_root,
**_determine_train_test_args_common(dataset, is_train=True),
transform=transforms,
transform=train_transforms,
download=True,
**dataset_kwargs,
)
testset = dataset_fx(
root=data_root,
**_determine_train_test_args_common(dataset, is_train=False),
transform=transforms,
transform=eval_transforms,
download=True,
**dataset_kwargs,
)
Expand Down Expand Up @@ -238,6 +250,7 @@ def cifarten_dataloader_dispatcher(
shuffle_test: bool = False,
dataset_kwargs: Optional[dict] = None,
dataloader_kwargs: Optional[dict] = None,
augment_train: bool = False,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
return _dataloader_dispatcher(
dataset="cifar10",
Expand All @@ -249,6 +262,7 @@ def cifarten_dataloader_dispatcher(
shuffle_test=shuffle_test,
dataset_kwargs=dataset_kwargs,
dataloader_kwargs=dataloader_kwargs,
augment_train=augment_train,
)


Expand All @@ -261,6 +275,7 @@ def cifarhundred_dataloader_dispatcher(
shuffle_test: bool = False,
dataset_kwargs: Optional[dict] = None,
dataloader_kwargs: Optional[dict] = None,
augment_train: bool = False,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
return _dataloader_dispatcher(
dataset="cifar100",
Expand All @@ -272,6 +287,7 @@ def cifarhundred_dataloader_dispatcher(
shuffle_test=shuffle_test,
dataset_kwargs=dataset_kwargs,
dataloader_kwargs=dataloader_kwargs,
augment_train=augment_train,
)


Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def read(fname):

setup(
name=PACKAGENAME,
version="0.26.3",
version="0.26.4",
author="Emanuele Ballarin",
author_email="[email protected]",
url="https://github.com/emaballarin/ebtorch",
Expand Down

0 comments on commit 2ae1f47

Please sign in to comment.