-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy path dataloader.py
89 lines (72 loc) · 2.97 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import numpy as np
import tensorflow as tf
seed: int = 13371337

# Seed TensorFlow and NumPy with one shared value so runs are reproducible.
tf.set_random_seed(seed)
np.random.seed(seed)
class ImageDataLoader:
    """Build aligned low-/high-resolution training patches from PNG file pairs.

    The low-resolution (LR) crop is ``patch_shape`` pixels.  The
    high-resolution (HR) crop is ``scale`` times larger along each spatial
    axis, where ``scale = int(sqrt(n_patches))``.

    Args:
        patch_shape: (height, width) of the LR crop.
        channels: number of channels each PNG is decoded to.
        n_patches: its integer square root is used as the LR -> HR scale
            factor (non-square values are truncated, e.g. 10 -> 3).
    """

    def __init__(self,
                 patch_shape: tuple = (128, 128),
                 channels: int = 3,
                 n_patches: int = 16):
        self.patch_shape = patch_shape
        self.channels = channels
        self.n_patches = n_patches
        self.scale = int(np.sqrt(self.n_patches))
        self.lr_patch_shape = (
            self.patch_shape[0],
            self.patch_shape[1],
        )
        self.hr_patch_shape = (
            self.patch_shape[0] * self.scale,
            self.patch_shape[1] * self.scale,
        )

    @staticmethod
    def normalize(x):
        """Map pixel values from [0, 255] to [-1, 1]."""
        return (x / 127.5) - 1.

    def random_crop(self, x_lr, x_hr):
        """Crop spatially aligned random windows from an LR/HR image pair.

        Offsets are sampled with TensorFlow ops rather than NumPy.  A fresh
        crop is therefore drawn every time the op executes (e.g. once per
        dataset element instead of once at trace time), and images whose
        static height/width are unknown are supported.

        NOTE(review): assumes ``x_lr`` is ``x_hr`` downscaled by exactly
        ``self.scale`` -- confirm against the dataset.

        Args:
            x_lr: rank-3 (height, width, channels) LR image tensor.
            x_hr: rank-3 (height, width, channels) HR image tensor.

        Returns:
            ``(lr_crop, hr_crop)`` of spatial size ``lr_patch_shape`` and
            ``hr_patch_shape`` respectively.

        Raises:
            tf.errors.InvalidArgumentError (at run time) if ``x_hr`` is
            smaller than ``hr_patch_shape`` or the LR window falls outside
            ``x_lr``.
        """
        hr_shape = tf.shape(x_hr)
        # Largest LR offset whose matching HR window still fits inside x_hr.
        max_off_y = (hr_shape[0] - self.hr_patch_shape[0]) // self.scale
        max_off_x = (hr_shape[1] - self.hr_patch_shape[1]) // self.scale
        # maxval is exclusive: +1 keeps the last valid offset reachable and
        # allows an HR image exactly the size of the HR patch.
        off_y = tf.random_uniform([], 0, max_off_y + 1, dtype=tf.int32)
        off_x = tf.random_uniform([], 0, max_off_x + 1, dtype=tf.int32)

        # crop_to_bounding_box bounds-checks the window and restores a static
        # (patch_h, patch_w, channels) shape for the ops downstream.
        lr_crop = tf.image.crop_to_bounding_box(
            x_lr, off_y, off_x,
            self.lr_patch_shape[0], self.lr_patch_shape[1])
        hr_crop = tf.image.crop_to_bounding_box(
            x_hr, off_y * self.scale, off_x * self.scale,
            self.hr_patch_shape[0], self.hr_patch_shape[1])
        return lr_crop, hr_crop

    def pre_processing(self, fn, use_augmentation: bool = True):
        """Load, crop, optionally augment and split one LR/HR PNG pair.

        Args:
            fn: indexable pair ``(lr_path, hr_path)`` of PNG file names
                (presumably a string tensor from a tf.data pipeline -- confirm).
            use_augmentation: when True, apply a random up-down flip and (for
                square patches only) a random 90-degree rotation.  Each is
                applied with probability 1/2, identically to both images.

        Returns:
            ``(lr_patches, hr_patches)`` tensors of shape
            ``(-1,) + lr_patch_shape + (channels,)`` and
            ``(-1,) + hr_patch_shape + (channels,)``.
        """
        lr = self._load_png(fn[0])
        hr = self._load_png(fn[1])

        lr, hr = self.random_crop(lr, hr)

        if use_augmentation:
            lr, hr = self._apply_randomly(tf.image.flip_up_down, lr, hr)
            # rot90 swaps height and width; only square patches still match
            # the configured patch shapes afterwards.
            if self.lr_patch_shape[0] == self.lr_patch_shape[1]:
                lr, hr = self._apply_randomly(tf.image.rot90, lr, hr)

        return (self._to_patches(lr, self.lr_patch_shape),
                self._to_patches(hr, self.hr_patch_shape))

    def _load_png(self, path):
        """Read a PNG, decode it to ``self.channels`` channels and normalize it."""
        image = tf.image.decode_png(tf.read_file(path), channels=self.channels)
        return self.normalize(tf.cast(image, dtype=tf.float32))

    @staticmethod
    def _apply_randomly(op, lr, hr):
        """Apply ``op`` to both images with probability 1/2 (one shared draw)."""
        coin = tf.random_uniform([], 0, 2, dtype=tf.int32)
        return tf.cond(tf.equal(coin, 0),
                       lambda: (op(lr), op(hr)),
                       lambda: (lr, hr))

    def _to_patches(self, image, patch_shape):
        """Split a rank-3 image into non-overlapping ``patch_shape`` tiles."""
        patches = tf.image.extract_image_patches(
            images=tf.expand_dims(image, axis=0),
            ksizes=(1,) + patch_shape + (1,),
            strides=(1,) + patch_shape + (1,),
            rates=[1, 1, 1, 1],
            padding='VALID',
        )
        return tf.reshape(patches, (-1,) + patch_shape + (self.channels,))