-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathload_cifar_100.py
67 lines (60 loc) · 2.8 KB
/
load_cifar_100.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import numpy as np
import matplotlib.pyplot as plt
import pickle
import torchvision.transforms as transforms
def unpickle(file):
    """Read one pickled object from the file at path ``file`` and return it.

    The file is opened in binary mode and loaded with ``encoding='bytes'``,
    so 8-bit strings stored by a Python 2 pickler come back as ``bytes``
    (including dictionary keys).

    NOTE: ``pickle`` can execute arbitrary code while loading; only use this
    on dataset files obtained from a trusted source.
    """
    with open(file, mode='rb') as stream:
        return pickle.load(stream, encoding='bytes')
def load_cifar_100_data(data_dir, negatives=False):
    """Load the pickled CIFAR-100 train and test splits found in ``data_dir``.

    Three files are read, in this order: ``data_dir + "/meta"``,
    ``data_dir + "/train"`` and ``data_dir + "/test"``. Each is expected to
    unpickle to a dict with ``bytes`` keys (``b'fine_label_names'`` for the
    meta file; ``b'data'``, ``b'filenames'`` and ``b'fine_labels'`` for the
    two splits).

    NOTE(review): the two splits are returned in different forms.
      * Train images are only reshaped to ``(N, 3, 32, 32)`` (channels
        first); train file names and labels are returned exactly as
        unpickled, without conversion to NumPy arrays.
      * Test images are reshaped and then moved to ``(N, 32, 32, 3)``
        (channels last); test file names and labels are wrapped in
        ``np.array``.
    This asymmetry is kept as-is so existing callers see the same values;
    confirm whether it is intentional before relying on it.

    Args:
        data_dir: Path (string) of the directory holding the ``meta``,
            ``train`` and ``test`` files.
        negatives: When true, the test images are additionally cast to
            ``float32``. It has no effect on the train split.

    Returns:
        A 7-tuple ``(train_data, train_filenames, train_labels, test_data,
        test_filenames, test_labels, label_names)``, where ``label_names``
        is a NumPy array built from the meta file's fine label names.
    """
    # Names of the fine-grained classes, taken from the meta file.
    meta_data_dict = unpickle(data_dir + "/meta")
    cifar_label_names = np.array(meta_data_dict[b'fine_label_names'])

    # Train split: images become (N, 3, 32, 32); everything else is untouched.
    cifar_train_data_dict = unpickle(data_dir + "/train")
    cifar_train_data = cifar_train_data_dict[b'data']
    cifar_train_filenames = cifar_train_data_dict[b'filenames']
    cifar_train_labels = cifar_train_data_dict[b'fine_labels']
    cifar_train_data = cifar_train_data.reshape(
        (len(cifar_train_data), 3, 32, 32))

    # Test split: images become (N, 32, 32, 3). Moving axis 1 to the end is
    # the same operation in both modes; ``negatives`` only adds the cast.
    cifar_test_data_dict = unpickle(data_dir + "/test")
    cifar_test_data = cifar_test_data_dict[b'data']
    cifar_test_filenames = np.array(cifar_test_data_dict[b'filenames'])
    cifar_test_labels = np.array(cifar_test_data_dict[b'fine_labels'])
    cifar_test_data = cifar_test_data.reshape(
        (len(cifar_test_data), 3, 32, 32)).transpose(0, 2, 3, 1)
    if negatives:
        cifar_test_data = cifar_test_data.astype(np.float32)

    return (cifar_train_data, cifar_train_filenames, cifar_train_labels,
            cifar_test_data, cifar_test_filenames, cifar_test_labels,
            cifar_label_names)
if __name__ == "__main__":
    # Demo: load CIFAR-100 from a fixed directory, print the shape of every
    # returned item, then show a grid of randomly chosen training images.
    cifar_100_dir = "./data/cifar100"

    (train_data, train_filenames, train_labels,
     test_data, test_filenames, test_labels,
     label_names) = load_cifar_100_data(cifar_100_dir)

    # The loader returns the training file names and labels exactly as they
    # were unpickled (unlike the test split, they are not converted to NumPy
    # arrays), so they may have no ``.shape``. np.asarray makes the shape
    # available and leaves real arrays untouched.
    print("Train data: ", train_data.shape)
    print("Train filenames: ", np.asarray(train_filenames).shape)
    print("Train labels: ", np.asarray(train_labels).shape)
    print("Test data: ", test_data.shape)
    print("Test filenames: ", test_filenames.shape)
    print("Test labels: ", test_labels.shape)
    print("Label names: ", label_names.shape)

    # Don't forget that the label_names and filenames are in binary and need
    # conversion if used.

    # Display some random training images in a 5x5 grid.
    num_plot = 5
    f, ax = plt.subplots(num_plot, num_plot)
    for m in range(num_plot):
        for n in range(num_plot):
            idx = np.random.randint(0, train_data.shape[0])
            # Training images come back as (3, 32, 32); imshow needs the
            # colour channel last, i.e. (32, 32, 3).
            ax[m, n].imshow(train_data[idx].transpose(1, 2, 0))
            ax[m, n].get_xaxis().set_visible(False)
            ax[m, n].get_yaxis().set_visible(False)
    f.subplots_adjust(hspace=0.1)
    f.subplots_adjust(wspace=0)
    plt.show()