-
Notifications
You must be signed in to change notification settings - Fork 0
/
preprocess.py
223 lines (190 loc) · 8.76 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
import numpy as np
import torch
from PIL import Image
import matplotlib.font_manager
import random
from PIL import ImageFont
from PIL import ImageDraw
from glob import glob
import os
from multiprocessing import Pool
from configs import seed
import argparse
# Seed the torch, stdlib `random` and numpy generators with the `seed` value
# imported from configs; the transforms below draw from all three.
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
class CutOutRectangles(object):
    """Cut random rectangles out of an image and save the results.

    Calling the instance with an image path writes three files that share the
    source file's base name:

    * ``original_{max_h_size}px/``  - the image as it was opened,
    * ``corrupted_{max_h_size}px/`` - the image with the rectangles set to 0,
    * ``mask_{max_h_size}px/``      - 255 everywhere, 0 inside the rectangles.

    Args:
        root_path (str): folder under which the three output folders are
            created.
        num_rectangles (int): Number of rectangles to cut out.
        max_h_size (int): Maximum height of the cut out rectangle. Also used
            in the names of the output folders.
        max_w_size (int): Maximum width of the cut out rectangle.
    """

    def __init__(
        self,
        root_path,
        num_rectangles: int = 1,
        max_h_size: int = 40,
        max_w_size: int = 40
    ):
        self.num_rectangles = num_rectangles
        self.max_h_size = max_h_size
        self.max_w_size = max_w_size
        self.original = os.path.join(root_path, f'original_{max_h_size}px')
        self.corrupted = os.path.join(root_path, f'corrupted_{max_h_size}px')
        self.mask = os.path.join(root_path, f'mask_{max_h_size}px')
        for folder in (self.original, self.corrupted, self.mask):
            # exist_ok avoids the check-then-create race of os.path.exists().
            os.makedirs(folder, exist_ok=True)

    def _cut_out(self, image):
        """Zero out ``num_rectangles`` random rectangles of ``image`` in place.

        Args:
            image (np.ndarray): array whose first two axes are height and
                width; any trailing axes (e.g. channels) are cut out as well.

        Returns:
            tuple: ``(image, mask)`` where ``image`` is the modified input and
            ``mask`` is a uint8 array of the same shape holding 255 outside
            the rectangles and 0 inside them.
        """
        mask = np.full(image.shape, 255, dtype=np.uint8)
        h, w = image.shape[:2]
        for _ in range(self.num_rectangles):
            # Draw a random centre, then derive the corners. Clipping to the
            # image bounds means rectangles near a border come out smaller.
            y = torch.randint(0, h, (1, )).item()
            x = torch.randint(0, w, (1, )).item()
            y1 = np.clip(y - self.max_h_size // 2, 0, h)
            y2 = np.clip(y1 + self.max_h_size, 0, h)
            x1 = np.clip(x - self.max_w_size // 2, 0, w)
            x2 = np.clip(x1 + self.max_w_size, 0, w)
            # Index only the two spatial axes so that both 2-D (grayscale)
            # and 3-D (multi-channel) arrays are supported; every channel of
            # the rectangle is zeroed in the image and in the mask.
            image[y1:y2, x1:x2] = 0
            mask[y1:y2, x1:x2] = 0
        return image, mask

    def __call__(self, original_path: str):
        """Corrupt the image at ``original_path`` and save all three outputs.

        Args:
            original_path (str): path of the image to process.
        """
        file_name = os.path.basename(original_path)
        with Image.open(original_path) as original:
            # np.array() copies the pixels, so ``original`` stays untouched.
            image, mask = self._cut_out(np.array(original))
            original.save(os.path.join(self.original, file_name))
            Image.fromarray(image.astype(np.uint8)).save(os.path.join(self.corrupted, file_name))
            Image.fromarray(mask.astype(np.uint8)).save(os.path.join(self.mask, file_name))
class RandomText(object):
    """Overlay lines of random words on an image and save the results.

    Calling the instance with an image path writes three files that share the
    source file's base name under ``root_path``: the image as opened
    (``original/``), the image with black text drawn on it (``corrupted/``)
    and a white RGB image carrying the same black text (``mask/``).

    Args:
        root_path (str): folder under which the output folders are created.
        text_size (int): Size of the text.
        font (str, optional): Font file to use for every image. When ``None``
            a random system TrueType font is picked for each image.
        words_path (str): text file with one word per line to sample from.
    """

    def __init__(self, root_path, text_size: int, font=None, words_path: str = 'mit-words.txt'):
        self.root_path = root_path
        self.text_size = text_size
        self.font = font
        # Always define ``fonts`` so later code never hits a missing
        # attribute; the system fonts are only needed when no font is given.
        self.fonts = []
        if font is None:
            self.fonts = matplotlib.font_manager.findSystemFonts(fontpaths=None, fontext='ttf')
        # mit words list file, the max word length is 22
        # https://www.mit.edu/~ecprice/wordlist.10000
        self.max_word_length = 22
        self.max_word_length_text_size_rel = self.max_word_length * (self.text_size // 2)
        with open(words_path, 'r') as f:
            self.words = f.read().splitlines()
        self.original = os.path.join(root_path, 'original')
        self.corrupted = os.path.join(root_path, 'corrupted')
        self.mask = os.path.join(root_path, 'mask')
        for folder in (self.original, self.corrupted, self.mask):
            os.makedirs(folder, exist_ok=True)

    def _load_font(self):
        """Return the font used for one image.

        Uses ``self.font`` when it was given, otherwise a random entry of
        ``self.fonts``. Falls back to ``arial.ttf`` when the chosen font
        cannot be loaded (OSError) or no system font is available
        (IndexError from choosing out of an empty list).
        """
        try:
            font_path = self.font if self.font is not None else random.choice(self.fonts)
            return ImageFont.truetype(font_path, self.text_size)
        except (OSError, IndexError):
            return ImageFont.truetype("arial.ttf", self.text_size)

    def __call__(self, original_path: str):
        """Draw random text over the image at ``original_path`` and save it.

        Args:
            original_path (str): path of the image to process.
        """
        file_name = os.path.basename(original_path)
        with Image.open(original_path) as original:
            font = self._load_font()
            image = original.copy()
            image_draw = ImageDraw.Draw(image)
            mask = Image.new(mode="RGB", size=image.size, color='white')
            mask_draw = ImageDraw.Draw(mask)
            # in PIL, size returns (width, height)
            width, height = image.size
            num_words = width // self.text_size
            words = " ".join(np.random.choice(self.words, num_words))
            # The vertical step between lines is drawn from
            # [text_size, text_size + 5).
            slack_range = np.arange(self.text_size, self.text_size + 5)
            slack = 0
            x = 0
            y = 10
            # Keep adding lines of text while the running offset is still
            # within the image height.
            # NOTE(review): each line is drawn at ``y + slack`` although ``y``
            # already includes the previous slack, so the step is counted
            # twice for the drawn position - confirm this is intended.
            while y <= height:
                image_draw.text((x, y + slack), words, (0, 0, 0), font=font)
                mask_draw.text((x, y + slack), words, (0, 0, 0), font=font)
                slack = np.random.choice(slack_range, 1)[0]
                y = y + slack
                words = " ".join(np.random.choice(self.words, num_words))
            original.save(os.path.join(self.original, file_name))
            # The corrupted image goes through a uint8 array before saving,
            # exactly as before, so the saved pixel format is unchanged.
            Image.fromarray(np.array(image).astype(np.uint8)).save(os.path.join(self.corrupted, file_name))
            Image.fromarray(np.array(mask).astype(np.uint8)).save(os.path.join(self.mask, file_name))
def get_images_paths(root_dir, extensions=('jpg',), min_size=None, transform=None, nested=False):
    """Collect the paths of the images found under ``root_dir``.

    Args:
        root_dir (string): Directory with all the images.
        extensions (list of strings or string, optional): Allowed image
            extensions, either as a sequence or as one comma-separated string.
        min_size (int, optional): Minimum size of the image. When given, every
            image is opened and dropped if its width or height is smaller.
        transform (callable, optional): Unused; kept so existing callers that
            pass it keep working.
        nested (bool, optional): if True, images are in a nested directory
            structure (only one level of nesting).

    Returns:
        list of strings: the matching paths, in the order glob returned them.
    """
    if isinstance(extensions, str):
        extensions = extensions.split(',')
    # glob is called without recursive=True, so '**' matches exactly one
    # directory level here.
    search_pattern = '**/*.{}' if nested else '*.{}'
    images = []
    for extension in extensions:
        images.extend(glob(os.path.join(root_dir, search_pattern.format(extension))))
    if min_size is not None:
        print("Filtering images based on the specified min_size...")
        # Build a new list instead of removing from ``images`` while
        # iterating over it, which would skip the entry after each removal.
        large_enough = []
        for img_p in images:
            with Image.open(img_p) as img:
                width, height = img.size
            if width >= min_size and height >= min_size:
                large_enough.append(img_p)
        images = large_enough
        print("Done.")
    return images
def transform(args):
    """Generate the corrupted datasets for the images under the origin path.

    PNG images are collected one directory level below
    ``args.origin_data_path``. For each built-in dataset definition a folder
    named after the dataset is created next to the origin folder, and the
    matching transformation is applied to every image by a pool of 5 worker
    processes.

    Args:
        args: parsed command-line arguments; reads ``args.custom`` and
            ``args.origin_data_path``.

    Raises:
        NotImplementedError: if ``args.custom`` is set, because only the
            built-in dataset definitions exist.
        ValueError: if a dataset definition names an unknown transform.
    """
    if args.custom:
        raise NotImplementedError("No transform specified")

    datasets = [
        {
            'name': "1_cutout_large_50px",
            'transform': 'cutout',
            'parameters': {
                'cutouts': 1,
                'max_size': 50
            }
        },
        {
            'name': "random_text_15px",
            'transform': 'random_text',
            'parameters': {
                'text_size': 15,
            }
        }
    ]

    original_data_path = args.origin_data_path
    # normpath drops a trailing separator; without it dirname("data/faces/")
    # is "data/faces" itself and the datasets would land inside the source
    # folder instead of next to it.
    base_directory = os.path.dirname(os.path.normpath(original_data_path))
    images = get_images_paths(original_data_path, extensions=["png"], nested=True)
    print(f"Found {len(images)} images")
    print("Generating the datasets...")
    for dataset in datasets:
        dataset_path = os.path.join(base_directory, dataset['name'])
        os.makedirs(dataset_path, exist_ok=True)
        parameters = dataset['parameters']
        if dataset['transform'] == 'cutout':
            transformation = CutOutRectangles(
                dataset_path,
                num_rectangles=parameters['cutouts'],
                max_h_size=parameters['max_size'],
                max_w_size=parameters['max_size'],
            )
        elif dataset['transform'] == 'random_text':
            transformation = RandomText(dataset_path, text_size=parameters['text_size'])
        else:
            # Fail loudly rather than reusing the previous dataset's
            # transformation (or hitting an unbound local).
            raise ValueError(f"Unknown transform: {dataset['transform']!r}")
        print(f"\t{dataset['name']}")
        with Pool(5) as p:
            p.map(transformation, images)
    print("Done")
if __name__ == "__main__":
    # Command-line entry point: parse the options and generate the datasets.
    parser = argparse.ArgumentParser()
    parser.add_argument('--origin-data-path', type=str, default=r"Filcker Faces thumbnails 128x128")
    # The built-in dataset definitions are used when --custom is NOT given;
    # with --custom set, transform() raises because no custom definitions
    # exist. The help text says so instead of claiming the opposite.
    parser.add_argument(
        '--custom',
        help="use custom dataset settings instead of the built-in defaults (not implemented yet).",
        action="store_true",
        default=False,
    )
    args = parser.parse_args()
    transform(args)