-
Notifications
You must be signed in to change notification settings - Fork 0
/
custom_dataset.py
executable file
·100 lines (86 loc) · 4.14 KB
/
custom_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from PIL import ImageOps, Image
import torchvision.transforms as transforms
import albumentations as A
class Dataset_(Dataset):
    """Image-folder dataset returning square images scaled to ``[-1, 1]``.

    Images are indexed with ``torchvision.datasets.ImageFolder`` from
    ``os.path.join(data_dir, split)`` and resized to
    ``resized_size x resized_size``.

    * ``is_train=True``: ``__getitem__`` returns
      ``(image, geometry_change, appearance_change)``. These are the resized,
      randomly h-flipped image plus two randomly augmented versions of it.
      The class label is discarded.
    * ``is_train=False``: ``__getitem__`` returns ``(image, label)``.

    Args:
        data_dir: Root directory that contains the ``split`` sub-folder.
        resized_size: Side length, in pixels, of the square output images.
        is_train: Selects the augmented training output described above.
        split: Sub-folder of ``data_dir`` to read. Defaults to ``"train"``
            (the historical, hard-coded behaviour), which is also what
            evaluation mode reads unless another split is passed.
    """

    def __init__(self, data_dir, resized_size, is_train, split="train"):
        super().__init__()
        self.data_dir = data_dir
        self.resized_size = resized_size
        self.is_train = is_train
        self.split = split
        self.rescaler = transforms.Resize(
            (resized_size, resized_size),
            interpolation=transforms.InterpolationMode.LANCZOS,
        )
        self.to_tensor = transforms.ToTensor()
        self.h_flip = transforms.RandomHorizontalFlip()
        # p=1.0 makes each albumentations transform fire on every call. It
        # replaces the deprecated ``always_apply=True`` with the same effect.
        self.color_jitter = A.Compose([
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=1.0),
        ])
        # The two perspective variants differ only in ``fit_output``; see the
        # albumentations Perspective docs for how each treats the frame.
        self.perspective_transform1 = A.Perspective(scale=(0.05, 0.1), keep_size=True, fit_output=True, p=1.0)
        self.perspective_transform2 = A.Perspective(scale=(0.05, 0.1), keep_size=True, fit_output=False, p=1.0)
        # NOTE(review): max_holes/max_height/... are the albumentations 1.x
        # argument names; newer releases renamed them. Confirm against the
        # installed version.
        self.drop_out = A.CoarseDropout(max_holes=1, max_height=0.5, max_width=0.5, min_holes=1, min_height=0.3, min_width=0.3, p=1.0)
        self.load_dataset()

    def random_geometry_transform(self, image):
        """Apply one of the two perspective warps, chosen with equal odds.

        Returns the albumentations result dict; the warped array is under
        ``"image"``.
        """
        if torch.rand(1).item() < 0.5:
            return self.perspective_transform1(image=np.array(image))
        return self.perspective_transform2(image=np.array(image))

    def random_appearance_transform(self, image):
        """Apply either coarse dropout or colour jitter, chosen with equal odds.

        Returns the albumentations result dict; the array is under ``"image"``.
        """
        if torch.rand(1).item() < 0.5:
            return self.random_drop_out(image)
        return self.random_color_transform(image)

    def random_drop_out(self, image):
        """Cut one rectangular hole out of *image*; returns the albumentations dict."""
        return self.drop_out(image=np.array(image))

    def random_color_transform(self, image):
        """Colour-jitter *image*; returns the albumentations dict."""
        return self.color_jitter(image=np.array(image))

    def load_dataset(self):
        """Index ``<data_dir>/<split>`` with ``ImageFolder`` into ``self.data``."""
        root = os.path.join(self.data_dir, self.split)
        self.data = ImageFolder(root=root)

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _is_square_of(size, side):
        """Return True when a ``(width, height)`` pair equals ``(side, side)``."""
        width, height = size
        return width == side and height == side

    def _resize(self, image):
        """Return *image* at ``resized_size x resized_size``.

        The rescaler is skipped only when BOTH dimensions already match. A
        width-only check would let non-square images through unresized.
        """
        if self._is_square_of(image.size, self.resized_size):
            return image
        return self.rescaler(image)

    def _to_model_range(self, picture):
        """Convert a PIL image to a float tensor scaled from ``[0, 1]`` to ``[-1, 1]``."""
        tensor = self.to_tensor(picture) * 2.0 - 1.0
        # ToTensor already yields [0, 1] for 8-bit images; the clamp is only
        # a safety net against out-of-range inputs.
        return torch.clamp(tensor, min=-1.0, max=1.0)

    def __getitem__(self, index):
        image, label = self.data[index]
        resized_image = self._resize(image)
        if not self.is_train:
            return self._to_model_range(resized_image), int(label)

        resized_image = self.h_flip(resized_image)
        # Each helper converts with np.array(...), which copies the pixels, so
        # the two augmentations never share a buffer with each other or with
        # resized_image. The random draws happen in the same order as before.
        geometry_change = self.random_geometry_transform(resized_image)
        appearance_change = self.random_appearance_transform(resized_image)
        return (
            self._to_model_range(resized_image),
            self._to_model_range(Image.fromarray(geometry_change["image"])),
            self._to_model_range(Image.fromarray(appearance_change["image"])),
        )