-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmnist_m.py
128 lines (96 loc) · 3.95 KB
/
mnist_m.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
import warnings
import torch
from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import download_and_extract_archive
class MNISTM(VisionDataset):
"""MNIST-M Dataset.
"""
resources = [
('https://github.com/liyxi/mnist-m/releases/download/data/mnist_m_train.pt.tar.gz',
'191ed53db9933bd85cc9700558847391'),
('https://github.com/liyxi/mnist-m/releases/download/data/mnist_m_test.pt.tar.gz',
'e11cb4d7fff76d7ec588b1134907db59')
]
training_file = "mnist_m_train.pt"
test_file = "mnist_m_test.pt"
classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
'5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
@property
def train_labels(self):
warnings.warn("train_labels has been renamed targets")
return self.targets
@property
def test_labels(self):
warnings.warn("test_labels has been renamed targets")
return self.targets
@property
def train_data(self):
warnings.warn("train_data has been renamed data")
return self.data
@property
def test_data(self):
warnings.warn("test_data has been renamed data")
return self.data
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
"""Init MNIST-M dataset."""
super(MNISTM, self).__init__(root, transform=transform, target_transform=target_transform)
self.train = train
if download:
self.download()
if not self._check_exists():
raise RuntimeError("Dataset not found." +
" You can use download=True to download it")
if self.train:
data_file = self.training_file
else:
data_file = self.test_file
print(os.path.join(self.processed_folder, data_file))
self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
def __getitem__(self, index):
"""Get images and target for data loader.
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], int(self.targets[index])
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.squeeze().numpy(), mode="RGB")
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
"""Return size of dataset."""
return len(self.data)
@property
def raw_folder(self):
return os.path.join(self.root, self.__class__.__name__, 'raw')
@property
def processed_folder(self):
return os.path.join(self.root, self.__class__.__name__, 'processed')
@property
def class_to_idx(self):
return {_class: i for i, _class in enumerate(self.classes)}
def _check_exists(self):
return (os.path.exists(os.path.join(self.processed_folder, self.training_file)) and
os.path.exists(os.path.join(self.processed_folder, self.test_file)))
def download(self):
"""Download the MNIST-M data."""
if self._check_exists():
return
os.makedirs(self.raw_folder, exist_ok=True)
os.makedirs(self.processed_folder, exist_ok=True)
# download files
for url, md5 in self.resources:
filename = url.rpartition('/')[2]
download_and_extract_archive(url, download_root=self.raw_folder,
extract_root=self.processed_folder,
filename=filename, md5=md5)
print('Done!')
def extra_repr(self):
return "Split: {}".format("Train" if self.train is True else "Test")