TensorFlow Requirement: 1.x (TensorFlow 2 is not supported)

Train Wide-ResNet, Shake-Shake and ShakeDrop models on the CIFAR-10 and CIFAR-100 datasets with AutoAugment.

The CIFAR-10 and CIFAR-100 data can be downloaded from https://www.cs.toronto.edu/~kriz/cifar.html. Use the Python version of the dataset, not the binary version.
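In the Python version, each batch file is a pickled dictionary. The sketch below shows one way to read it; `load_cifar_batch` and `row_to_hwc` are hypothetical helpers for illustration, not part of this repository. It assumes the documented CIFAR layout, where each row of `data` holds 3072 values: the 1024 red values, then the 1024 green values, then the 1024 blue values, each plane stored row by row for a 32x32 image. CIFAR-10 batches keep labels under `labels`; CIFAR-100 uses `fine_labels` and `coarse_labels`.

```python
import pickle
import sys


def load_cifar_batch(path):
    """Unpickle one batch file from a CIFAR "python version" archive.

    Hypothetical helper for illustration; the repository has its own loader.
    """
    with open(path, 'rb') as f:
        if sys.version_info[0] >= 3:
            # The files were pickled under Python 2; 'bytes' keeps raw byte keys.
            batch = pickle.load(f, encoding='bytes')
        else:
            batch = pickle.load(f)
    # Normalize keys to text so callers can write batch['data'] on either version.
    return dict((k.decode('ascii') if isinstance(k, bytes) else k, v)
                for k, v in batch.items())


def row_to_hwc(row, size=32):
    """Convert one flat, channel-major row into a size x size x 3 nested list.

    Assumes the value for pixel (i, j) in channel c (0=R, 1=G, 2=B) is stored
    at index c * size * size + i * size + j.
    """
    plane = size * size
    return [[[row[c * plane + i * size + j] for c in range(3)]
             for j in range(size)]
            for i in range(size)]
```

For example, `row_to_hwc(batch['data'][0])` returns the first image of a loaded batch as 32 rows of 32 `[r, g, b]` triples.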

The code replicates the CIFAR-10/100 results from Tables 1 and 2 of the AutoAugment paper with the following models: Wide-ResNet-28-10, Shake-Shake (26 2x32d), Shake-Shake (26 2x96d) and PyramidNet+ShakeDrop.

Related papers:

AutoAugment: Learning Augmentation Policies from Data

https://arxiv.org/abs/1805.09501

Wide Residual Networks

https://arxiv.org/abs/1605.07146

Shake-Shake regularization

https://arxiv.org/abs/1705.07485

ShakeDrop regularization

https://arxiv.org/abs/1802.02375

Settings:

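CIFAR-10 Model         | Learning Rate | Weight Decay | Num. Epochs | Batch Size
---------------------- | ------------- | ------------ | ----------- | ----------
Wide-ResNet-28-10      | 0.1           | 5e-4         | 200         | 128
Shake-Shake (26 2x32d) | 0.01          | 1e-3         | 1800        | 128
Shake-Shake (26 2x96d) | 0.01          | 1e-3         | 1800        | 128
PyramidNet + ShakeDrop | 0.05          | 5e-5         | 1800        | 64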
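As a rough sense of training length, the settings above can be converted into total optimizer steps. This is an illustrative sketch only: it assumes the full 50,000-image CIFAR training set and counts a partial final batch as one step, and the repository's own step counting may round differently. The `SETTINGS` dictionary and `total_steps` function are hypothetical, and the dictionary keys are descriptive labels, not `--model_name` values.

```python
import math

# Values copied from the CIFAR-10 settings table.
# Keys are descriptive labels, not train_cifar.py flag values.
SETTINGS = {
    'Wide-ResNet-28-10':      {'lr': 0.1,  'weight_decay': 5e-4, 'epochs': 200,  'batch_size': 128},
    'Shake-Shake (26 2x32d)': {'lr': 0.01, 'weight_decay': 1e-3, 'epochs': 1800, 'batch_size': 128},
    'Shake-Shake (26 2x96d)': {'lr': 0.01, 'weight_decay': 1e-3, 'epochs': 1800, 'batch_size': 128},
    'PyramidNet + ShakeDrop': {'lr': 0.05, 'weight_decay': 5e-5, 'epochs': 1800, 'batch_size': 64},
}


def total_steps(epochs, batch_size, train_size=50000):
    """Steps for a full run, counting a partial last batch as a step (assumption)."""
    steps_per_epoch = int(math.ceil(train_size / float(batch_size)))
    return epochs * steps_per_epoch


for name, cfg in sorted(SETTINGS.items()):
    print('%-24s %8d steps' % (name, total_steps(cfg['epochs'], cfg['batch_size'])))
```

Under these assumptions, Wide-ResNet-28-10 runs 200 × 391 = 78,200 steps, while PyramidNet + ShakeDrop, with half the batch size and 1800 epochs, runs 1800 × 782 = 1,407,600 steps.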

Prerequisites:

  1. Install TensorFlow 1.x. Be sure to run the code with Python 2, not Python 3.

  2. Download the CIFAR-10/CIFAR-100 datasets (Python version).

curl -o cifar-10-python.tar.gz https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
curl -o cifar-100-python.tar.gz https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
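The downloaded files are gzipped tar archives: the CIFAR-10 archive unpacks into `cifar-10-batches-py/` and the CIFAR-100 archive into `cifar-100-python/`. The sketch below extracts an archive and checks for the expected batch files. `extract_cifar` and `EXPECTED` are hypothetical names, and the sketch assumes `--data_path` should point at the directory that directly contains the batch files; verify that against the repository's data loading code before relying on it.

```python
import os
import tarfile

# Files expected inside each extracted archive (standard CIFAR python layout).
EXPECTED = {
    'cifar10': ('cifar-10-batches-py',
                ['data_batch_%d' % i for i in range(1, 6)] +
                ['test_batch', 'batches.meta']),
    'cifar100': ('cifar-100-python', ['train', 'test', 'meta']),
}


def extract_cifar(archive_path, dest_dir, dataset):
    """Extract a CIFAR archive into dest_dir and return the data directory.

    Raises IOError if any expected batch file is missing after extraction.
    """
    subdir, files = EXPECTED[dataset]
    with tarfile.open(archive_path, 'r:gz') as tar:
        tar.extractall(dest_dir)
    data_dir = os.path.join(dest_dir, subdir)
    missing = [f for f in files
               if not os.path.isfile(os.path.join(data_dir, f))]
    if missing:
        raise IOError('missing files in %s: %s' % (data_dir, ', '.join(missing)))
    return data_dir
```

Called on the CIFAR-10 archive with `dest_dir='/tmp/data'`, it would return `/tmp/data/cifar-10-batches-py`, which under the assumption above is the directory to pass as `--data_path`.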

How to run:

# cd to your workspace.
# Specify the directory where the dataset is located using the --data_path flag.
# Note: you can hold out samples from the training set as an eval set by changing train_size and validation_size.

# For example, to train the Wide-ResNet-28-10 model on a GPU.
python train_cifar.py --model_name=wrn \
                      --checkpoint_dir=/tmp/training \
                      --data_path=/tmp/data \
                      --dataset='cifar10' \
                      --use_cpu=0
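Changing `train_size` and `validation_size` amounts to carving a validation set out of the training images. A minimal sketch of one such split follows, assuming the first `train_size` examples are used for training and the next `validation_size` for validation; `split_train_val` is a hypothetical function, and the repository's own slicing may differ.

```python
def split_train_val(images, labels, train_size, validation_size):
    """Split training examples into contiguous train and validation parts.

    Hypothetical illustration: requires train_size + validation_size to be
    no larger than the number of available training examples.
    """
    if len(images) != len(labels):
        raise ValueError('images and labels must have the same length')
    if train_size + validation_size > len(images):
        raise ValueError('train_size + validation_size exceeds %d examples'
                         % len(images))
    end = train_size + validation_size
    return ((images[:train_size], labels[:train_size]),
            (images[train_size:end], labels[train_size:end]))
```

With the 50,000 CIFAR training images, for example, `train_size=45000` and `validation_size=5000` would hold out the last 5,000 images for validation.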

Contact for Issues