Skip to content

Commit

Permalink
Merge pull request #20 from yu4u/feature-datagen
Browse files Browse the repository at this point in the history
Feature datagen
  • Loading branch information
yu4u authored Nov 9, 2017
2 parents 9ce4e7d + 5fd28bb commit 5166a34
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 12 deletions.
54 changes: 44 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
This is a Keras implementation of a CNN for estimating age and gender from a face image [1, 2].
In training, [the IMDB-WIKI dataset](https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/) is used.


## Dependencies
- Python3.5+
- Keras2.0+
Expand Down Expand Up @@ -75,20 +74,40 @@ Trained weight files are stored as `checkpoints/weights.*.hdf5` for each epoch i
```sh
usage: train.py [-h] --input INPUT [--batch_size BATCH_SIZE]
[--nb_epochs NB_EPOCHS] [--depth DEPTH] [--width WIDTH]
[--validation_split VALIDATION_SPLIT]
[--validation_split VALIDATION_SPLIT] [--aug]

This script trains the CNN model for age and gender estimation.

optional arguments:
-h, --help show this help message and exit
--input INPUT, -i INPUT path to input database mat file (default: None)
--batch_size BATCH_SIZE batch size (default: 32)
--nb_epochs NB_EPOCHS number of epochs (default: 30)
--depth DEPTH depth of network (should be 10, 16, 22, 28, ...) (default: 16)
--width WIDTH width of network (default: 8)
--validation_split VALIDATION_SPLIT validation split ratio (default: 0.1)
-h, --help show this help message and exit
--input INPUT, -i INPUT
path to input database mat file (default: None)
--batch_size BATCH_SIZE
batch size (default: 32)
--nb_epochs NB_EPOCHS
number of epochs (default: 30)
--depth DEPTH depth of network (should be 10, 16, 22, 28, ...)
(default: 16)
--width WIDTH width of network (default: 8)
--validation_split VALIDATION_SPLIT
validation split ratio (default: 0.1)
--aug use data augmentation if set true (default: False)
```
#### Train network with recent data augmentation methods
Recent data augmentation methods, mixup [3] and Random Erasing [4],
can be used with standard data augmentation by `--aug` option in training:
```bash
python3 train.py --input data/imdb_db.mat --aug
```
Please refer to [this repository](https://github.com/yu4u/mixup-generator) for implementation details.
I confirmed that data augmentation enables us to avoid overfitting
and improves validation loss.
#### Use the trained network
```sh
Expand Down Expand Up @@ -117,11 +136,22 @@ Please use the best model among `checkpoints/weights.*.hdf5` for `WEIGHT_FILE` i
python3 plot_history.py --input models/history_16_8.h5
```
##### Results without data augmentation
<img src="https://github.com/yu4u/age-gender-estimation/wiki/images/loss.png" width="400px">
<img src="https://github.com/yu4u/age-gender-estimation/wiki/images/accuracy.png" width="400px">
##### Results with data augmentation
The best val_loss was improved from 3.969 to 3.731:
- Without data augmentation: 3.969
- With standard data augmentation: 3.799
- With mixup and random erasing: 3.731
<img src="fig/loss.png" width="480px">
We can see that, with data augmentation,
overfitting did not occur even at very small learning rates (epoch > 15).
<img src="https://github.com/yu4u/age-gender-estimation/wiki/images/accuracy.png" width="400px">
## Network architecture
In [the original paper](https://www.vision.ee.ethz.ch/en/publications/papers/articles/eth_biwi_01299.pdf) [1, 2], the pretrained VGG network is adopted.
Expand All @@ -141,3 +171,7 @@ Trained on imdb, tested on wiki.
[2] R. Rothe, R. Timofte, and L. V. Gool, "Deep expectation of real and apparent age from a single image
without facial landmarks," IJCV, 2016.
[3] H. Zhang, M. Cisse, Y. N. Dauphin, and D. Lopez-Paz, "mixup: Beyond Empirical Risk Minimization," in arXiv:1710.09412, 2017.
[4] Z. Zhong, L. Zheng, G. Kang, S. Li, and Y. Yang, "Random Erasing Data Augmentation," in arXiv:1708.04896, 2017.
Binary file added fig/loss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
60 changes: 60 additions & 0 deletions mixup_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import numpy as np


class MixupGenerator():
def __init__(self, X_train, y_train, batch_size=32, alpha=0.2, shuffle=True, datagen=None):
self.X_train = X_train
self.y_train = y_train
self.batch_size = batch_size
self.alpha = alpha
self.shuffle = shuffle
self.sample_num = len(X_train)
self.datagen = datagen

def __call__(self):
while True:
indexes = self.__get_exploration_order()
itr_num = int(len(indexes) // (self.batch_size * 2))

for i in range(itr_num):
batch_ids = indexes[i * self.batch_size * 2:(i + 1) * self.batch_size * 2]
X, y = self.__data_generation(batch_ids)

yield X, y

def __get_exploration_order(self):
indexes = np.arange(self.sample_num)

if self.shuffle:
np.random.shuffle(indexes)

return indexes

def __data_generation(self, batch_ids):
_, h, w, c = self.X_train.shape
l = np.random.beta(self.alpha, self.alpha, self.batch_size)
X_l = l.reshape(self.batch_size, 1, 1, 1)
y_l = l.reshape(self.batch_size, 1)

X1 = self.X_train[batch_ids[:self.batch_size]]
X2 = self.X_train[batch_ids[self.batch_size:]]
X = X1 * X_l + X2 * (1 - X_l)

if self.datagen:
for i in range(self.batch_size):
X[i] = self.datagen.random_transform(X[i])
X[i] = self.datagen.standardize(X[i])

if isinstance(self.y_train, list):
y = []

for y_train_ in self.y_train:
y1 = y_train_[batch_ids[:self.batch_size]]
y2 = y_train_[batch_ids[self.batch_size:]]
y.append(y1 * y_l + y2 * (1 - y_l))
else:
y1 = self.y_train[batch_ids[:self.batch_size]]
y2 = self.y_train[batch_ids[self.batch_size:]]
y = y1 * y_l + y2 * (1 - y_l)

return X, y
28 changes: 28 additions & 0 deletions random_eraser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import numpy as np


def get_random_eraser(p=0.5, s_l=0.02, s_h=0.4, r_1=0.3, r_2=1/0.3, v_l=0, v_h=255):
def eraser(input_img):
img_h, img_w, _ = input_img.shape
p_1 = np.random.rand()

if p_1 > p:
return input_img

while True:
s = np.random.uniform(s_l, s_h) * img_h * img_w
r = np.random.uniform(r_1, r_2)
w = int(np.sqrt(s / r))
h = int(np.sqrt(s * r))
left = np.random.randint(0, img_w)
top = np.random.randint(0, img_h)

if left + w <= img_w and top + h <= img_h:
break

c = np.random.uniform(v_l, v_h)
input_img[top:top + h, left:left + w, :] = c

return input_img

return eraser
40 changes: 38 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@
import logging
import argparse
import os
import numpy as np
from keras.callbacks import LearningRateScheduler, ModelCheckpoint
from keras.optimizers import SGD
from keras.utils import np_utils
from wide_resnet import WideResNet
from utils import mk_dir, load_data
from keras.preprocessing.image import ImageDataGenerator
from mixup_generator import MixupGenerator
from random_eraser import get_random_eraser

logging.basicConfig(level=logging.DEBUG)

Expand Down Expand Up @@ -40,6 +44,8 @@ def get_args():
help="width of network")
parser.add_argument("--validation_split", type=float, default=0.1,
help="validation split ratio")
parser.add_argument("--aug", action="store_true",
help="use data augmentation if set true")
args = parser.parse_args()
return args

Expand All @@ -52,6 +58,7 @@ def main():
depth = args.depth
k = args.width
validation_split = args.validation_split
use_augmentation = args.aug

logging.debug("Loading data...")
image, gender, age, _, image_size, _ = load_data(input_path)
Expand Down Expand Up @@ -83,8 +90,37 @@ def main():
]

logging.debug("Running training...")
hist = model.fit(X_data, [y_data_g, y_data_a], batch_size=batch_size, epochs=nb_epochs, callbacks=callbacks,
validation_split=validation_split)

data_num = len(X_data)
indexes = np.arange(data_num)
np.random.shuffle(indexes)
X_data = X_data[indexes]
y_data_g = y_data_g[indexes]
y_data_a = y_data_a[indexes]
train_num = int(data_num * (1 - validation_split))
X_train = X_data[:train_num]
X_test = X_data[train_num:]
y_train_g = y_data_g[:train_num]
y_test_g = y_data_g[train_num:]
y_train_a = y_data_a[:train_num]
y_test_a = y_data_a[train_num:]

if use_augmentation:
datagen = ImageDataGenerator(
width_shift_range=0.1,
height_shift_range=0.1,
horizontal_flip=True,
preprocessing_function=get_random_eraser(v_l=0, v_h=255))
training_generator = MixupGenerator(X_train, [y_train_g, y_train_a], batch_size=batch_size, alpha=0.2,
datagen=datagen)()
hist = model.fit_generator(generator=training_generator,
steps_per_epoch=train_num // batch_size,
validation_data=(X_test, [y_test_g, y_test_a]),
epochs=nb_epochs, verbose=1,
callbacks=callbacks)
else:
hist = model.fit(X_train, [y_train_g, y_train_a], batch_size=batch_size, epochs=nb_epochs, callbacks=callbacks,
validation_data=(X_test, [y_test_g, y_test_a]))

logging.debug("Saving weights...")
model.save_weights(os.path.join("models", "WRN_{}_{}.h5".format(depth, k)), overwrite=True)
Expand Down

0 comments on commit 5166a34

Please sign in to comment.