TinyImageNet.py
import os
import glob
from torch.utils.data import Dataset
from PIL import Image
EXTENSION = 'JPEG'
NUM_IMAGES_PER_CLASS = 500
# These files live inside the tiny-imagenet-200 directory; keeping the paths
# relative lets `root` fully determine where the data set is read from.
CLASS_LIST_FILE = 'wnids.txt'
VAL_ANNOTATION_FILE = 'val_annotations.txt'


class TinyImageNet(Dataset):
"""Tiny ImageNet data set available from `http://cs231n.stanford.edu/tiny-imagenet-200.zip`.
Parameters
----------
root: string
Root directory including `train`, `test` and `val` subdirectories.
split: string
Indicating which split to return as a data set.
Valid option: [`train`, `test`, `val`]
transform: torchvision.transforms
A (series) of valid transformation(s).
in_memory: bool
Set to True if there is enough memory (about 5G) and want to minimize disk IO overhead.
"""
    def __init__(self, root, split='train', transform=None, target_transform=None, in_memory=False):
        self.root = os.path.expanduser(root)
        self.split = split
        self.transform = transform
        self.target_transform = target_transform
        self.in_memory = in_memory
        self.split_dir = os.path.join(self.root, self.split)
        self.image_paths = sorted(glob.iglob(os.path.join(self.split_dir, '**', '*.%s' % EXTENSION), recursive=True))
        self.labels = {}  # file name - label number mapping
        self.images = []  # used for in-memory processing

        # build class label - number mapping
        with open(os.path.join(self.root, CLASS_LIST_FILE), 'r') as fp:
            self.label_texts = sorted([text.strip() for text in fp.readlines()])
        self.label_text_to_number = {text: i for i, text in enumerate(self.label_texts)}

        if self.split == 'train':
            # train images are named `<wnid>_<index>.JPEG` inside per-class directories
            for label_text, i in self.label_text_to_number.items():
                for cnt in range(NUM_IMAGES_PER_CLASS):
                    self.labels['%s_%d.%s' % (label_text, cnt, EXTENSION)] = i
        elif self.split == 'val':
            # val labels come from a tab-separated annotation file
            with open(os.path.join(self.split_dir, VAL_ANNOTATION_FILE), 'r') as fp:
                for line in fp.readlines():
                    terms = line.split('\t')
                    file_name, label_text = terms[0], terms[1]
                    self.labels[file_name] = self.label_text_to_number[label_text]

        # read all images into memory up front to minimize disk I/O overhead
        if self.in_memory:
            self.images = [self.read_image(path) for path in self.image_paths]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        file_path = self.image_paths[index]

        if self.in_memory:
            img = self.images[index]
        else:
            img = self.read_image(file_path)

        if self.split == 'test':
            # the test split ships without labels
            return img

        label = self.labels[os.path.basename(file_path)]
        if self.target_transform:
            label = self.target_transform(label)
        return img, label

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Split: {}\n'.format(self.split)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

    def read_image(self, path):
        # some Tiny ImageNet images are grayscale; convert so every sample
        # has three channels before any transform is applied
        img = Image.open(path).convert('RGB')
        return self.transform(img) if self.transform else img
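

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original class: it assumes the
    # standard tiny-imagenet-200 archive has been extracted under ./data
    # (that path and the batch size are illustrative choices, not fixed by
    # this file). It shows how the docstring's parameters fit together with
    # a torchvision transform and a DataLoader.
    from torch.utils.data import DataLoader
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_set = TinyImageNet('./data/tiny-imagenet-200', split='train',
                             transform=transform, in_memory=False)
    print(train_set)

    loader = DataLoader(train_set, batch_size=64, shuffle=True)
    images, labels = next(iter(loader))
    # Tiny ImageNet images are 64x64 RGB, so this should print
    # torch.Size([64, 3, 64, 64]) torch.Size([64])
    print(images.shape, labels.shape)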