diff --git a/data_loader.py b/data_loader.py index 1c8d078..b9ca636 100644 --- a/data_loader.py +++ b/data_loader.py @@ -67,7 +67,7 @@ def train_generator(batchsize): train_dataset = train_dataset.batch(batchsize) train_dataset = train_dataset.repeat() - train_dataset = train_dataset.shuffle(buffer_size=6) + train_dataset = train_dataset.shuffle(buffer_size=4) train_iterator = train_dataset.make_initializable_iterator() train_batch = train_iterator.get_next() @@ -93,7 +93,7 @@ def val_generator(batchsize): val_dataset = val_dataset.batch(batchsize) val_dataset = val_dataset.repeat() - val_dataset = val_dataset.shuffle(buffer_size=6) + val_dataset = val_dataset.shuffle(buffer_size=4) val_iterator = val_dataset.make_initializable_iterator() val_batch = val_iterator.get_next()