-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathdata.py
executable file
·86 lines (66 loc) · 2.62 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from __future__ import print_function, division
import os
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
class FaceScrub(Dataset):
    """FaceScrub samples loaded from ``<root>/facescrub.npz``.

    The archive must contain the arrays ``actor_images``, ``actor_labels``,
    ``actress_images`` and ``actress_labels``. Actor and actress samples are
    concatenated and min-max normalised over axis 0. They are then shuffled
    with the fixed seed 666 and split 80/20 into train/test partitions.

    Args:
        root: Directory containing ``facescrub.npz``; ``~`` is expanded.
        transform: Optional callable applied to each PIL image.
        target_transform: Optional callable applied to each label.
        train: If True keep the first 80% of the shuffled samples,
            otherwise keep the remaining 20%.
    """

    def __init__(self, root, transform=None, target_transform=None, train=True):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform

        # Context manager closes the NpzFile (and its open file handle).
        with np.load(os.path.join(self.root, 'facescrub.npz')) as archive:
            data = np.concatenate(
                [archive['actor_images'], archive['actress_images']], axis=0)
            labels = np.concatenate(
                [archive['actor_labels'], archive['actress_labels']], axis=0)

        data = self._minmax_normalize(data)

        # Same permutation as `np.random.seed(666); np.random.shuffle(perm)`,
        # but drawn from a private generator so the caller's global NumPy
        # random state is left untouched.
        perm = np.arange(len(data))
        np.random.RandomState(666).shuffle(perm)
        data = data[perm]
        labels = labels[perm]

        split = int(0.8 * len(data))
        if train:
            self.data = data[:split]
            self.labels = labels[:split]
        else:
            self.data = data[split:]
            self.labels = labels[split:]

    @staticmethod
    def _minmax_normalize(data):
        """Scale *data* into [0, 1] using the min/max taken over axis 0.

        Each position along the remaining axes is scaled independently.
        A position whose value is constant along axis 0 has a zero range;
        it is mapped to 0 instead of producing NaN from 0/0.
        """
        v_min = data.min(axis=0)
        v_max = data.max(axis=0)
        span = v_max - v_min
        # ones_like keeps span's dtype, so the result dtype matches plain division.
        span = np.where(span == 0, np.ones_like(span), span)
        return (data - v_min) / span

    def __len__(self):
        """Return the number of samples in this partition."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, label)`` for *index*, with optional transforms applied."""
        img, target = self.data[index], self.labels[index]
        # NOTE(review): self.data is floating point after normalisation; PIL's
        # fromarray accepts 2-D float arrays but not multi-channel float
        # arrays -- confirm the stored image shape.
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target
class CelebA(Dataset):
    """CelebA samples loaded from ``<root>/celebA_64_1.npy`` .. ``celebA_64_<num_parts>.npy``.

    All parts are concatenated along axis 0 and min-max normalised over
    axis 0. Every sample receives the dummy label ``0``.

    Args:
        root: Directory containing the ``.npy`` parts; ``~`` is expanded.
        transform: Optional callable applied to each PIL image.
        target_transform: Optional callable applied to each label.
        num_parts: Number of numbered part files to load (default 10).

    Raises:
        ValueError: If ``num_parts`` is less than 1.
    """

    def __init__(self, root, transform=None, target_transform=None, num_parts=10):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform

        if num_parts < 1:
            raise ValueError("num_parts must be >= 1")

        # Format only the file name: formatting the joined path breaks
        # whenever root itself contains '{' or '}'.
        parts = [
            np.load(os.path.join(self.root, 'celebA_64_{}.npy'.format(i + 1)))
            for i in range(num_parts)
        ]
        data = self._minmax_normalize(np.concatenate(parts, axis=0))

        self.data = data
        # Dummy labels: this dataset carries no targets.
        self.labels = np.zeros(len(data), dtype=int)

    @staticmethod
    def _minmax_normalize(data):
        """Scale *data* into [0, 1] using the min/max taken over axis 0.

        Each position along the remaining axes is scaled independently.
        A position whose value is constant along axis 0 has a zero range;
        it is mapped to 0 instead of producing NaN from 0/0.
        """
        v_min = data.min(axis=0)
        v_max = data.max(axis=0)
        span = v_max - v_min
        # ones_like keeps span's dtype, so the result dtype matches plain division.
        span = np.where(span == 0, np.ones_like(span), span)
        return (data - v_min) / span

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, label)`` for *index*, with optional transforms applied."""
        img, target = self.data[index], self.labels[index]
        # NOTE(review): self.data is floating point after normalisation; PIL's
        # fromarray accepts 2-D float arrays but not multi-channel float
        # arrays -- confirm the stored image shape.
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target