forked from MichalDanielDobrzanski/DeepLearningPython
-
Notifications
You must be signed in to change notification settings - Fork 0
/
expand_mnist.py
60 lines (51 loc) · 1.94 KB
/
expand_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""expand_mnist.py
~~~~~~~~~~~~~~~~~~

Take the 50,000 MNIST training images, and create an expanded set of
250,000 images (each original plus four displaced copies), by displacing
each training image up, down, left and right, by one pixel.  Save the
resulting file to ../data/mnist_expanded.pkl.gz.

Note that this program is memory intensive, and may not run on small
systems.

Run it as a script (``python expand_mnist.py``); importing the module
only defines the helpers and does not start the expansion.
"""

from __future__ import print_function

#### Libraries
# Standard library
import gzip
import os.path
import random
import sys

try:  # Python 2 ships the C-accelerated pickler as cPickle ...
    import cPickle as pickle
except ImportError:  # ... Python 3 merged it into the pickle module.
    import pickle

# Third-party libraries
import numpy as np

# Width and height, in pixels, of one MNIST image (28 * 28 = 784 pixels).
IMAGE_SIDE = 28
SOURCE_PATH = "../data/mnist.pkl.gz"
EXPANDED_PATH = "../data/mnist_expanded.pkl.gz"

# (shift, axis) pairs, in the order the displaced copies are generated:
# down, up, right, left -- each by one pixel.
_DISPLACEMENTS = ((1, 0), (-1, 0), (1, 1), (-1, 1))


def displace(image, shift, axis):
    """Return a copy of 2-D ``image`` translated by ``shift`` pixels along ``axis``.

    ``np.roll`` wraps pixels that fall off one edge back in at the other;
    the wrapped-in rows/columns are zeroed, so the result is a translation
    onto a blank background.  A positive ``shift`` moves content towards
    higher indices (down for axis 0, right for axis 1).  The input array is
    not modified.
    """
    shifted = np.roll(image, shift, axis)
    # Select the rows/columns that np.roll wrapped around and blank them.
    edge = [slice(None)] * shifted.ndim
    if shift >= 0:
        edge[axis] = slice(0, shift)
    else:
        edge[axis] = slice(shift, None)
    shifted[tuple(edge)] = 0
    return shifted


def expand_image(x, side=IMAGE_SIDE):
    """Return the four one-pixel displacements of the flattened image ``x``.

    ``x`` holds ``side * side`` pixels in row-major order.  The result is a
    list of four flat vectors of the same length, shifted down, up, right
    and left respectively.
    """
    image = np.reshape(x, (-1, side))
    return [np.reshape(displace(image, shift, axis), side * side)
            for shift, axis in _DISPLACEMENTS]


def expand_training_data(training_data, side=IMAGE_SIDE,
                         progress_every=1000, rng=random):
    """Return ``[images, labels]`` with every image plus its four displacements.

    ``training_data`` is a pair ``(images, labels)`` as stored in
    mnist.pkl.gz -- presumably an (N, 784) array and a length-N label
    array (only iteration is required).  The result holds two parallel
    lists of length 5 * N, shuffled together with ``rng.shuffle``.
    A progress line is printed every ``progress_every`` source images
    (pass 0 to silence it).
    """
    pairs = []
    for count, (x, y) in enumerate(zip(training_data[0], training_data[1]),
                                   start=1):
        pairs.append((x, y))
        pairs.extend((shifted, y) for shifted in expand_image(x, side))
        if progress_every and count % progress_every == 0:
            print("Expanding image number", count)
    if not pairs:
        # zip(*[]) would collapse to [], losing the (images, labels) shape.
        return [[], []]
    rng.shuffle(pairs)
    return [list(column) for column in zip(*pairs)]


def _load_pickle(f):
    """Unpickle ``f``, decoding Python 2 byte strings as latin-1 on Python 3.

    NOTE(review): mnist.pkl.gz is presumably a Python 2 pickle (this script
    originally read it with cPickle); ``encoding="latin1"`` is what lets
    Python 3 load NumPy arrays pickled by Python 2.  The file is a trusted
    local dataset -- never unpickle untrusted data.
    """
    if sys.version_info[0] >= 3:
        return pickle.load(f, encoding="latin1")
    return pickle.load(f)


def main(source_path=SOURCE_PATH, expanded_path=EXPANDED_PATH):
    """Build the expanded training set unless ``expanded_path`` already exists.

    Validation and test data are copied through unchanged.
    """
    print("Expanding the MNIST training set")
    if os.path.exists(expanded_path):
        print("The expanded training set already exists. Exiting.")
        return
    with gzip.open(source_path, "rb") as f:
        training_data, validation_data, test_data = _load_pickle(f)
    expanded_training_data = expand_training_data(training_data)
    print("Saving expanded data. This may take a few minutes.")
    with gzip.open(expanded_path, "wb") as f:
        pickle.dump((expanded_training_data, validation_data, test_data), f)


if __name__ == "__main__":
    main()