-
Notifications
You must be signed in to change notification settings - Fork 0
/
image_batch_loader.py
103 lines (85 loc) · 3.59 KB
/
image_batch_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import numpy as np
import cv2
class ImageBatchLoader:
    """Batch loader that returns randomly sampled hr/lr image batches.

    The loader assumes every image in the hr directory has an identically
    named counterpart in the lr directory, and that image file names are
    zero-padded integers followed by ``extension`` (e.g. ``000042.png``).
    Images are sampled without replacement: once an image has been returned
    it is not returned again until :meth:`reset` is called.

    Attributes:
        directory (str): Path of the top level data directory, resolved
            relative to the directory containing this module.
        hr_directory (str): Path of the hr image directory.
        lr_directory (str): Path of the lr image directory.
        batch_size (int): Number of images in each returned batch.
        extension (str): File extension (including the dot) of the images.
        name_width (int): Number of digits in the zero-padded file names.
        init_indices (list[int]): Integer file names of every image found in
            the hr directory, sorted ascending.
        indices (np.ndarray): Integer file names not yet sampled since the
            last reset.
        set_size (int): Total number of images in the data set.
    """

    def __init__(self, batch_size=16, directory='data', extension='.png',
                 hr='/hr', lr='/lr', name_width=6):
        # os.path.join also lets ``directory`` be an absolute path.
        self.directory = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), directory)
        # hr/lr are appended verbatim, so they carry their own leading '/'.
        self.hr_directory = f"{self.directory}{hr}"
        self.lr_directory = f"{self.directory}{lr}"
        self.batch_size = batch_size
        self.extension = extension
        self.name_width = name_width
        self.init_indices = self._get_indices()
        # Independent copy: sampling replaces self.indices, never init_indices.
        self.indices = np.array(self.init_indices, dtype=np.int64)
        self.set_size = len(self.init_indices)

    def load_images(self, paths):
        """Return a float32 ndarray of normalized images.

        Arguments:
            paths -- iterable of image paths relative to the data directory

        Raises:
            FileNotFoundError -- if an image cannot be read
        """
        # Stacking assumes every image has the same size — confirm for new data.
        return np.array(
            [self._read_image(f"{self.directory}/{path}") for path in paths],
            dtype=np.float32)

    @staticmethod
    def normalize(input_data):
        """Return ``input_data`` mapped from [0, 255] to [-1, 1] as float32.

        Arguments:
            input_data -- array-like of image pixel values
        """
        return (np.asarray(input_data, dtype=np.float32) - 127.5) / 127.5

    @staticmethod
    def denormalize(input_data):
        """Return ``input_data`` mapped from [-1, 1] back to uint8 [0, 255].

        Values are rounded to the nearest integer, so a normalize/denormalize
        round trip is exact despite float error. They are also clipped, so
        out-of-range inputs saturate instead of wrapping around in uint8.

        Arguments:
            input_data -- array-like of normalized image values
        """
        scaled = (np.asarray(input_data) + 1) * 127.5
        return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)

    def next_batch(self):
        """Return the next random (hr_batch, lr_batch) pair.

        Raises:
            ValueError -- if fewer than ``batch_size`` unsampled images remain
        """
        indices = self._get_random_batch()
        hr_batch = self._load_images(indices, self.hr_directory)
        lr_batch = self._load_images(indices, self.lr_directory)
        return hr_batch, lr_batch

    def reset(self):
        """Refill the pool of unsampled indices with every image index."""
        self.indices = np.array(self.init_indices, dtype=np.int64)

    def _get_indices(self):
        # Return the integer file names of all images in the hr directory.
        # Sorting makes the initial pool independent of os.listdir order.
        # Slicing by len(name) - len(ext) also works for an empty extension.
        ext_len = len(self.extension)
        return sorted(
            int(name[:len(name) - ext_len])
            for name in os.listdir(self.hr_directory)
            if name.endswith(self.extension))

    def _get_random_batch(self):
        # Sample batch_size distinct indices and remove them from the pool.
        pool = np.asarray(self.indices)
        if len(pool) < self.batch_size:
            raise ValueError(
                f"Only {len(pool)} unsampled images remain but batch_size is "
                f"{self.batch_size}; call reset() to start a new pass")
        # Sample positions rather than values: np.delete removes by
        # position, so this removes exactly the entries that were drawn.
        positions = np.random.choice(
            len(pool), size=self.batch_size, replace=False)
        batch = pool[positions].tolist()
        self.indices = np.delete(pool, positions)
        return batch

    def _load_images(self, indices, directory):
        # Return a float32 ndarray of the normalized images named by indices.
        return np.array(
            [self._read_image(
                f"{directory}/{str(index).zfill(self.name_width)}"
                f"{self.extension}")
             for index in indices],
            dtype=np.float32)

    def _read_image(self, full_path):
        # Read and normalize one image. cv2.imread signals failure by
        # returning None rather than raising, so check for it explicitly.
        image = cv2.imread(full_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {full_path}")
        return self.normalize(image)