-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataloader.py
52 lines (43 loc) · 2.07 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import tensorflow as tf
import numpy as np
def Dataloader(name, home_path):
    """Dispatch to the loader for the named dataset.

    Args:
        name: Dataset identifier; supported values are ``'cifar10'`` and
            ``'cifar100'`` (matched exactly, case-sensitive).
        home_path: Forwarded unchanged to the selected loader.

    Returns:
        The selected loader's result: the 5-tuple ``(train_images,
        train_labels, val_images, val_labels, pre_processing)``.

    Raises:
        ValueError: If ``name`` is not a supported dataset.
    """
    # Plain if-chain (not a dict of callables) so each loader name is only
    # resolved when it is actually selected.
    if name == 'cifar10':
        return Cifar10(home_path)
    if name == 'cifar100':
        return Cifar100(home_path)
    # Fail loudly here instead of returning None, which would only surface
    # later as an opaque unpacking error in the caller.
    raise ValueError(
        f"Unsupported dataset {name!r}; expected 'cifar10' or 'cifar100'."
    )
def Cifar10(home_path):
    """Load CIFAR-10 and build its per-example preprocessing functions.

    Args:
        home_path: Accepted for signature parity with the other loaders; it
            is not used here (the Keras loader is called without arguments).

    Returns:
        A 5-tuple ``(train_images, train_labels, val_images, val_labels,
        pre_processing)``. The first four are exactly what
        ``tensorflow.keras.datasets.cifar10.load_data()`` returns.
        ``pre_processing(is_training=False)`` returns a function mapping
        ``(image, label) -> (image, label)``; presumably intended for
        ``tf.data.Dataset.map`` — confirm against callers.
    """
    from tensorflow.keras.datasets.cifar10 import load_data

    (train_images, train_labels), (val_images, val_labels) = load_data()

    # Per-channel mean / std on the raw pixel scale, as explicit float32 so
    # the normalised image stays float32 regardless of dtype-promotion mode.
    # NOTE(review): these equal the commonly published CIFAR-10 RGB stats
    # (R=125.3, G=123.0, B=113.9 / R=63.0, G=62.1, B=66.7) in *reversed*
    # (BGR) order, while Keras documents load_data() as returning RGB.
    # Values are kept unchanged so existing checkpoints see identical inputs;
    # confirm whether BGR ordering was intended.
    mean = np.array([113.9, 123.0, 125.3], dtype=np.float32)
    std = np.array([66.7, 62.1, 63.0], dtype=np.float32)

    def normalize(image):
        # Cast to float32, then standardise each channel (last axis).
        image = tf.cast(image, tf.float32)
        return (image - mean) / std

    def pre_processing(is_training=False):
        """Return the augmenting (training) or plain (inference) map fn."""

        def training(image, label):
            image = normalize(image)
            image = tf.image.random_flip_left_right(image)
            # Reflect-pad 4 px on each side of the first two axes (presumably
            # height/width), then randomly crop back to the original shape:
            # a random translation of up to +/-4 px.
            sz = tf.shape(image)
            image = tf.pad(image, [[4, 4], [4, 4], [0, 0]], 'REFLECT')
            image = tf.image.random_crop(image, sz)
            return image, label

        def inference(image, label):
            # No augmentation: normalisation only.
            return normalize(image), label

        return training if is_training else inference

    return train_images, train_labels, val_images, val_labels, pre_processing
def Cifar100(home_path):
    """Load CIFAR-100 and build its per-example preprocessing functions.

    Args:
        home_path: Accepted for signature parity with the other loaders; it
            is not used here (the Keras loader is called without arguments).

    Returns:
        A 5-tuple ``(train_images, train_labels, val_images, val_labels,
        pre_processing)``. The first four are exactly what
        ``tensorflow.keras.datasets.cifar100.load_data()`` returns.
        ``pre_processing(is_training=False)`` returns a function mapping
        ``(image, label) -> (image, label)``; presumably intended for
        ``tf.data.Dataset.map`` — confirm against callers.
    """
    from tensorflow.keras.datasets.cifar100 import load_data

    (train_images, train_labels), (val_images, val_labels) = load_data()

    # Per-channel mean / std on the raw pixel scale. The original used
    # *integer* numpy arrays here; explicit float32 keeps the normalised
    # image float32 regardless of dtype-promotion mode (same values).
    # NOTE(review): these resemble the commonly published CIFAR-100 RGB stats
    # (R~129, G~124, B~112 / R~68, G~65, B~70) in *reversed* (BGR) order,
    # while Keras documents load_data() as returning RGB. Values are kept
    # unchanged so existing checkpoints see identical inputs; confirm whether
    # BGR ordering was intended.
    mean = np.array([112, 124, 129], dtype=np.float32)
    std = np.array([70, 65, 68], dtype=np.float32)

    def normalize(image):
        # Cast to float32, then standardise each channel (last axis).
        image = tf.cast(image, tf.float32)
        return (image - mean) / std

    def pre_processing(is_training=False):
        """Return the augmenting (training) or plain (inference) map fn."""

        def training(image, label):
            image = normalize(image)
            image = tf.image.random_flip_left_right(image)
            # Reflect-pad 4 px on each side of the first two axes (presumably
            # height/width), then randomly crop back to the original shape:
            # a random translation of up to +/-4 px.
            sz = tf.shape(image)
            image = tf.pad(image, [[4, 4], [4, 4], [0, 0]], 'REFLECT')
            image = tf.image.random_crop(image, sz)
            return image, label

        def inference(image, label):
            # No augmentation: normalisation only.
            return normalize(image), label

        return training if is_training else inference

    return train_images, train_labels, val_images, val_labels, pre_processing