-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_generator.py
executable file
·66 lines (51 loc) · 2.06 KB
/
data_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
from torch.utils.data import Dataset
import albumentations as A
from skimage import io, transform
class CatsDataset(Dataset):
    """Dataset of paired sketch/photo images with style embeddings.

    Each source image file stores a sketch and a photo side by side
    (sketch in the left half, photo in the right half).  ``__getitem__``
    splits the pair, normalizes both halves to [-1, 1], optionally
    augments them jointly, and returns ``(sketch, image, embedding)``
    as float tensors in [channel, height, width] layout.
    Subclasses Dataset from torch.utils.data.
    """
    def __init__(self, datapath, dataframe, augment=None):
        """
        Initialize CatsDataset.
        :datapath: sequence of str, one image file path per sample
                   (indexed by __getitem__, its length is the dataset size)
        :dataframe: DataFrame with a 'path' column matching the entries of
                    datapath and an 'embedding' column holding the
                    Style-Pre-Net embedding for each path
        :augment: bool, if truthy, randomly augment data in __getitem__
        """
        self.datapath = datapath
        self.dataframe = dataframe
        self.augment = augment
    def __len__(self):
        # The dataset size is the number of image paths.
        return len(self.datapath)
    def __getitem__(self, idx):
        image_path = self.datapath[idx]
        # Look up this sample's embedding row by exact path match.
        row = self.dataframe[self.dataframe.path == image_path]
        embedding = row['embedding'].values[0]
        image = io.imread(image_path)
        # split the side-by-side pair: left half is the sketch,
        # right half is the photo
        w = image.shape[1] // 2
        sketch = image[:, :w, :]
        image = image[:, w:, :]
        # normalize data: uint8 [0, 255] -> float [-1, 1]
        sketch = (sketch / 127.5) - 1
        image = (image / 127.5) - 1
        # augment data: with 50% probability apply one of flip / crop.
        # The sketch is passed as the mask so both halves receive the
        # identical spatial transform and stay aligned.
        if self.augment:
            aug = A.OneOf(
                [A.HorizontalFlip(p=1),
                 A.RandomSizedCrop(min_max_height=(230, 230),
                                   height=256, width=256, p=1)],
                p=0.5)
            augmented = aug(image=image, mask=sketch)
            image = augmented['image']
            sketch = augmented['mask']
        image = torch.FloatTensor(image)
        sketch = torch.FloatTensor(sketch)
        embedding = torch.FloatTensor(embedding)
        # Check image dimension. If [height, width], replicate to 3 channels;
        # otherwise permute [height, width, channel] -> [channel, height, width].
        if len(image.shape) == 2:
            image = torch.stack([image] * 3)
        else:
            image = image.permute(2, 0, 1)
        if len(sketch.shape) == 2:
            sketch = torch.stack([sketch] * 3)
        else:
            sketch = sketch.permute(2, 0, 1)
        return sketch, image, embedding