-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdatasets.py
38 lines (31 loc) · 1.01 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import numpy as np
import torch
from torch.utils.data import Dataset
class PyroMultiMNIST(Dataset):
    """Multi-object MNIST images loaded from a NumPy archive as a torch Dataset.

    The archive at ``path`` must expose ``'x'`` (images) and ``'y'`` (one
    label sequence per image). It is presumably the file produced by Pyro's
    multi-MNIST generation script; confirm this for your data.

    Args:
        path: file readable by ``np.load`` that supports ``data['x']`` and
            ``data['y']`` (e.g. an ``.npz`` archive).
        train: if True keep the first ``split`` samples, otherwise keep the
            remaining ones.
        split: index separating the training samples from the held-out ones.
            Defaults to 50000, the value this class previously hard-coded.
    """

    def __init__(self, path, train, *, split=50000):
        self.path = path
        self.train = train
        # NOTE(review): allow_pickle=True is required when 'y' is an object
        # array, but it means loading can execute pickled code -- only use
        # trusted files.
        data = np.load(path, allow_pickle=True)
        try:
            x = data['x']
            y = data['y']
        finally:
            # For .npz input np.load returns an NpzFile that keeps the archive
            # open until closed. The arrays above are already read into
            # memory, so release the handle now instead of leaking it.
            close = getattr(data, 'close', None)
            if close is not None:
                close()
        if train:
            self.x, self.y = x[:split], y[:split]
        else:
            self.x, self.y = x[split:], y[split:]

    def __getitem__(self, index):
        """
        Returns (x, y), where x is a float32 tensor of shape (1, *image_shape)
        -- i.e. (1, H, W) for 2-D stored images -- scaled by 1/255 (so in
        range [0, 1] for uint8 data), and y is a label dict with only a
        'n_obj' key holding a 0-d float32 tensor equal to the length of the
        stored label entry (presumably the number of objects in the image).
        """
        # x: stored image, presumably uint8 with shape (H, W); the channel
        #    axis is added below -- TODO confirm against the data file.
        # y: stored per-image label sequence; only its length is used.
        x, y = self.x[index], self.y[index]
        n_obj = torch.tensor(len(y), dtype=torch.float32)
        x = torch.from_numpy(x / 255.0).float()
        x = x[None]  # prepend a channel axis
        y = {'n_obj': n_obj}  # label dict: compatible with multiobject dataloader
        return x, y

    def __len__(self):
        return len(self.x)