-
Notifications
You must be signed in to change notification settings - Fork 0
/
notmnist_setup.py
102 lines (83 loc) · 3.66 KB
/
notmnist_setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#modified from https://github.com/sahagobinda/GPM/blob/main/dataloader/five_datasets.py
import os,sys
import os.path
import numpy as np
import torch
import torch.utils.data
from torchvision import datasets,transforms
import urllib.request
from PIL import Image
import pickle
########################################################################################################################
class notMNIST(torch.utils.data.Dataset):
    """notMNIST: an image-recognition dataset of font glyphs for the letters A through J.

    It is quite similar to the classic MNIST dataset of handwritten digits 0 through 9
    and is useful with simple neural networks.

    Args:
        root (string): Directory that holds (or will receive) ``notmnist.zip`` and the
            extracted ``notmnist_train.pkl`` / ``notmnist_test.pkl`` files. A leading
            ``~`` is expanded to the user's home directory.
        train (bool, optional): If true, load the training split; otherwise load the
            test split.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If the archive is already present, it is not
            downloaded again.

    Raises:
        RuntimeError: If ``notmnist.zip`` is not found under ``root`` and ``download``
            is false.
    """

    def __init__(self, root, train=True,transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.filename = "notmnist.zip"
        self.url = "https://github.com/nkundiushuti/notmnist_convert/blob/master/notmnist.zip?raw=true"

        # Always work from the expanded root so the existence check, the download
        # target, and the pickle paths all refer to the same directory.
        archive_path = os.path.join(self.root, self.filename)
        if not os.path.isfile(archive_path):
            if not download:
                raise RuntimeError('Dataset not found. You can use download=True to download it')
            print('Downloading from '+self.url)
            self.download()

        split_file = 'notmnist_train.pkl' if train else 'notmnist_test.pkl'
        self.data, self.labels = self._load_split(os.path.join(self.root, split_file))

    @staticmethod
    def _load_split(path):
        """Read one pickled split and return ``(features, labels)`` as uint8 arrays."""
        # NOTE(security): pickle executes arbitrary code while loading. The file comes
        # from a downloaded archive, so only use this with an archive you trust.
        with open(path, 'rb') as f:
            split = pickle.load(f)
        return split['features'].astype(np.uint8), split['labels'].astype(np.uint8)

    def __getitem__(self, index):
        """
        Args: index (int): Index
        Returns: tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.labels[index]

        # Return a PIL Image so that this dataset is consistent with all other datasets.
        # Only the first slice along the sample's leading axis is used
        # (presumably a single channel -- TODO confirm the stored layout).
        img = Image.fromarray(img[0])

        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.data)

    def download(self):
        """Download ``notmnist.zip`` from ``self.url`` into the root directory and extract it there."""
        import zipfile

        root = os.path.expanduser(self.root)
        fpath = os.path.join(root, self.filename)

        # exist_ok tolerates an already-existing directory; other OSErrors still propagate.
        os.makedirs(root, exist_ok=True)
        urllib.request.urlretrieve(self.url, fpath)

        # The context manager closes the archive even if extraction fails.
        with zipfile.ZipFile(fpath, 'r') as zip_ref:
            zip_ref.extractall(root)