-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathloader.py
60 lines (55 loc) · 1.97 KB
/
loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import numpy as np
import torchvision.transforms as transforms
import torch
from torch.utils.data import Dataset, DataLoader
from glob import glob
from PIL import Image, ImageSequence
class MovingMNIST(Dataset):
    """Dataset over a Moving MNIST ``.npy`` archive, one short clip per item.

    The stored array is indexed as (frame, sequence, height, width); the
    original author's note gives the stock file as 20 x 10000 x 64 x 64.
    """

    def __init__(self, path='mnist_test_seq.npy'):
        """Load the array stored at ``path`` and keep it on the instance."""
        self.data = np.load(path)

    def __len__(self):
        """Return the number of sequences, i.e. the size of axis 1."""
        return self.data.shape[1]

    def __getitem__(self, index):
        """Return sequence ``index`` as a float32 array with a channel axis.

        Only the first 16 frames are kept. Each value ``x`` is rescaled to
        ``x / 255 * 2 - 1`` (so 0..255 becomes -1..1) and a leading axis of
        length one is prepended, giving C x D x H x W.
        """
        clip = self.data[:16, index, :, :].astype(np.float32)  # D x H x W
        clip = clip / 255
        clip = clip * 2 - 1
        return clip[np.newaxis, ...]
class GIF(Dataset):
    """Dataset of GIF files, each decoded into a fixed-length RGB clip.

    Every item is a float tensor of shape
    ``(3, NUM_FRAMES, FRAME_SIZE, FRAME_SIZE)`` (channels first, then time).
    Successfully decoded clips are rescaled from ToTensor's [0, 1] range to
    [-1, 1]. A file that cannot be decoded yields an all-zero tensor of the
    same shape, so one bad file does not abort a whole pass over the data.
    """

    NUM_FRAMES = 16  # frames sampled from every GIF
    FRAME_SIZE = 64  # output height and width in pixels

    def __init__(self, path='gifs'):
        """Collect every ``*.gif`` file directly inside directory ``path``."""
        # glob yields files in an OS-dependent order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        self.paths = sorted(glob(f'{path}/*.gif'))

    def __len__(self):
        """Return the number of GIF files found."""
        return len(self.paths)

    @staticmethod
    def _sample_indices(count, total):
        """Return ``count`` evenly spread frame indices for ``total`` frames.

        Index ``i`` is ``i * total // count + total // (2 * count)``: the
        start of the i-th of ``count`` equal segments, shifted by half a
        segment. Indices repeat when ``total < count``.
        """
        return [i * total // count + total // (2 * count)
                for i in range(count)]

    def __getitem__(self, index):
        """Decode the GIF at position ``index`` into a (3, T, H, W) tensor.

        Each sampled frame is converted to RGB, zero-padded towards a square
        so the resize does not distort its aspect ratio, resized to
        ``FRAME_SIZE`` and converted to a tensor.

        Raises:
            IndexError: if ``index`` is outside ``self.paths``.
        """
        path = self.paths[index]
        num_frames, side = self.NUM_FRAMES, self.FRAME_SIZE
        clip = torch.zeros(num_frames, 3, side, side)
        try:
            # transforms.Scale was removed from torchvision; Resize is its
            # replacement. Built once per item instead of once per frame.
            resize_to_tensor = transforms.Compose([
                transforms.Resize([side, side]),
                transforms.ToTensor(),
            ])
            # Context manager closes the file handle; the RGB conversions are
            # independent copies that stay valid afterwards.
            with Image.open(path) as gif:
                decoded = [frame.convert('RGB')
                           for frame in ImageSequence.Iterator(gif)]
            picks = self._sample_indices(num_frames, len(decoded))
            for slot, frame_index in enumerate(picks):
                frame = decoded[frame_index]
                width, height = frame.size  # PIL reports (width, height)
                longest = max(width, height)
                xpad = (longest - width) // 2
                ypad = (longest - height) // 2
                # Pad takes (left, top, right, bottom).
                squared = transforms.Pad((xpad, ypad, xpad, ypad))(frame)
                clip[slot, :, :, :] = resize_to_tensor(squared)
            clip = clip * 2 - 1
        except Exception as err:
            # Best effort: report which file failed and why, then fall back
            # to a clean all-zero clip rather than a half-filled one.
            print(f'Error occurred while opening image {path!r}: {err!r}')
            clip = torch.zeros(num_frames, 3, side, side)
        # (T, C, H, W) -> (C, T, H, W)
        return torch.transpose(clip, 0, 1)
if __name__ == '__main__':
    # Smoke test: build the GIF dataset from its default directory and drain
    # one full pass of two-item batches, discarding each batch.
    gif_batches = DataLoader(GIF(), batch_size=2)
    for _batch in gif_batches:
        pass