Skip to content

Commit

Permalink
Support for inputs with arbitrary numbers of channels
Browse files Browse the repository at this point in the history
  • Loading branch information
Callidior committed Dec 11, 2019
1 parent 3076f74 commit f2316f9
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 18 deletions.
14 changes: 14 additions & 0 deletions datasets/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,13 @@ def num_test(self):
""" Number of test images in the dataset. """

return len(self.test_img_files)


@property
def num_channels(self):
""" Number of channels (e.g., 3 for RGB, 1 for grayscale). """

return 3



Expand Down Expand Up @@ -775,6 +782,13 @@ def num_test(self):
""" Number of test images in the dataset. """

return len(self.X_test)


@property
def num_channels(self):
""" Number of channels (e.g., 3 for RGB, 1 for grayscale). """

return self.X_train.shape[-1]



Expand Down
4 changes: 2 additions & 2 deletions learn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@ def transform_inputs(X, y, num_classes, label_smoothing = 0):
print('Resuming from snapshot {}'.format(args.snapshot))
model = keras.models.load_model(args.snapshot, custom_objects = utils.get_custom_objects(args.architecture), compile = False)
else:
model = utils.build_network(data_generator.num_classes, args.architecture, True)
model = utils.build_network(data_generator.num_classes, args.architecture, True, input_channels=data_generator.num_channels)
par_model = model if args.gpus <= 1 else keras.utils.multi_gpu_model(model, gpus = args.gpus, cpu_merge = False)
else:
with K.tf.device('/cpu:0'):
if args.snapshot and os.path.exists(args.snapshot):
print('Resuming from snapshot {}'.format(args.snapshot))
model = keras.models.load_model(args.snapshot, custom_objects = utils.get_custom_objects(args.architecture), compile = False)
else:
model = utils.build_network(data_generator.num_classes, args.architecture, True)
model = utils.build_network(data_generator.num_classes, args.architecture, True, input_channels=data_generator.num_channels)
par_model = keras.utils.multi_gpu_model(model, gpus = args.gpus)

if not args.no_progress:
Expand Down
4 changes: 2 additions & 2 deletions learn_image_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def transform_inputs(X, y, embedding, num_classes = None):
print('Resuming from snapshot {}'.format(args.snapshot))
model = keras.models.load_model(args.snapshot, custom_objects = utils.get_custom_objects(args.architecture), compile = False)
else:
embed_model = utils.build_network(embedding.shape[1], args.architecture)
embed_model = utils.build_network(embedding.shape[1], args.architecture, input_channels=data_generator.num_channels)
model = embed_model
if args.loss == 'inv_corr':
model = keras.models.Model(model.inputs, keras.layers.Lambda(utils.l2norm, name = 'l2norm')(model.output))
Expand All @@ -137,7 +137,7 @@ def transform_inputs(X, y, embedding, num_classes = None):
print('Resuming from snapshot {}'.format(args.snapshot))
model = keras.models.load_model(args.snapshot, custom_objects = utils.get_custom_objects(args.architecture), compile = False)
else:
embed_model = utils.build_network(embedding.shape[1], args.architecture)
embed_model = utils.build_network(embedding.shape[1], args.architecture, input_channels=data_generator.num_channels)
model = embed_model
if args.loss == 'inv_corr':
model = keras.models.Model(model.inputs, keras.layers.Lambda(utils.l2norm, name = 'l2norm')(model.output))
Expand Down
39 changes: 25 additions & 14 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def l2norm(x):
return K.tf.nn.l2_normalize(x, -1)


def build_network(num_outputs, architecture, classification = False, no_softmax = False, name = None):
def build_network(num_outputs, architecture, classification = False, no_softmax = False, input_channels = None, name = None):
""" Constructs a CNN.
# Arguments:
Expand All @@ -143,6 +143,8 @@ def build_network(num_outputs, architecture, classification = False, no_softmax
- no_softmax: Usually, the last layer will have a softmax activation if `classification` is True. However, if `no_softmax` is set
to True as well, the last layer will not have any activation.
- input_channels: Number of input channels.
- name: The name of the network.
# Returns:
Expand All @@ -154,68 +156,73 @@ def build_network(num_outputs, architecture, classification = False, no_softmax
architecture = architecture[:-5]
else:
activation = 'relu'

input_shape = None if input_channels is None else (None, None, input_channels)

# CIFAR-100 architectures

if architecture == 'resnet-32':

return cifar_resnet.SmallResNet(5, filters = [16, 32, 64], activation = activation,
return cifar_resnet.SmallResNet(5, filters = [16, 32, 64], activation = activation, input_shape = input_shape,
include_top = classification, top_activation = None if no_softmax else 'softmax',
classes = num_outputs, name = name)

elif architecture == 'resnet-110':

return cifar_resnet.SmallResNet(18, filters = [16, 32, 64], activation = activation,
return cifar_resnet.SmallResNet(18, filters = [16, 32, 64], activation = activation, input_shape = input_shape,
include_top = classification, top_activation = None if no_softmax else 'softmax',
classes = num_outputs, name = name)

elif architecture == 'resnet-110-fc':

return cifar_resnet.SmallResNet(18, filters = [16, 32, 64], activation = activation,
return cifar_resnet.SmallResNet(18, filters = [16, 32, 64], activation = activation, input_shape = input_shape,
include_top = True, top_activation = 'softmax' if classification and (not no_softmax) else None,
classes = num_outputs, name = name)

elif architecture == 'resnet-110-wfc':

return cifar_resnet.SmallResNet(18, filters = [32, 64, 128], activation = activation,
return cifar_resnet.SmallResNet(18, filters = [32, 64, 128], activation = activation, input_shape = input_shape,
include_top = True, top_activation = 'softmax' if classification and (not no_softmax) else None,
classes = num_outputs, name = name)

elif architecture == 'wrn-28-10':

return wrn.create_wide_residual_network((32, 32, 3), nb_classes = num_outputs, N = 4, k = 10, verbose = 0,
if input_channels is None:
input_channels = 3
return wrn.create_wide_residual_network((32, 32, input_channels), nb_classes = num_outputs, N = 4, k = 10, verbose = 0,
final_activation = 'softmax' if classification and (not no_softmax) else None, name = name)

elif architecture == 'densenet-100-12':

return densenet.DenseNet(growth_rate = 12, depth = 100, nb_dense_block = 3, bottleneck = False, nb_filter = 16, reduction = 0.0,
classes = num_outputs, activation = 'softmax' if classification and (not no_softmax) else None, name = name)
classes = num_outputs, activation = 'softmax' if classification and (not no_softmax) else None, input_shape = input_shape, name = name)

elif architecture == 'densenet-100-24':

return densenet.DenseNet(growth_rate = 24, depth = 100, nb_dense_block = 3, bottleneck = False, nb_filter = 16, reduction = 0.0,
classes = num_outputs, activation = 'softmax' if classification and (not no_softmax) else None, name = name)
classes = num_outputs, activation = 'softmax' if classification and (not no_softmax) else None, input_shape = input_shape, name = name)

elif architecture == 'densenet-bc-190-40':

return densenet.DenseNet(growth_rate = 40, depth = 190, nb_dense_block = 3, bottleneck = True, nb_filter = -1, reduction = 0.5,
classes = num_outputs, activation = 'softmax' if classification and (not no_softmax) else None, name = name)
classes = num_outputs, activation = 'softmax' if classification and (not no_softmax) else None, input_shape = input_shape, name = name)

elif architecture == 'pyramidnet-272-200':

return cifar_pyramidnet.PyramidNet(272, 200, bottleneck = True, activation = activation,
return cifar_pyramidnet.PyramidNet(272, 200, bottleneck = True, activation = activation, input_shape = input_shape,
classes = num_outputs, top_activation = 'softmax' if classification and (not no_softmax) else None, name = name)

elif architecture == 'pyramidnet-110-270':

return cifar_pyramidnet.PyramidNet(110, 270, bottleneck = False, activation = activation,
return cifar_pyramidnet.PyramidNet(110, 270, bottleneck = False, activation = activation, input_shape = input_shape,
classes = num_outputs, top_activation = 'softmax' if classification and (not no_softmax) else None, name = name)

elif architecture == 'simple':

return plainnet.PlainNet(num_outputs,
activation = activation,
final_activation = 'softmax' if classification and (not no_softmax) else None,
input_shape=input_shape if input_shape is not None else (None, None, 3),
name = name)

# ImageNet architectures
Expand All @@ -230,7 +237,7 @@ def build_network(num_outputs, architecture, classification = False, no_softmax
# ResNet50 has been available from the beginning, while the other two were added in keras-applications 1.0.7.
# Thus, we use the initial implementation of ResNet50 for compatibility's sake.
factory = keras.applications.ResNet50
rn = factory(include_top=False, weights=None)
rn = factory(include_top=False, weights=None, input_shape=input_shape)
# Depending on the Keras version, the ResNet50 model may or may not contain a final average pooling layer.
rn_out = rn.layers[-2].output if isinstance(rn.layers[-1], keras.layers.AveragePooling2D) else rn.layers[-1].output
x = keras.layers.GlobalAvgPool2D(name='avg_pool')(rn_out)
Expand All @@ -248,7 +255,9 @@ def build_network(num_outputs, architecture, classification = False, no_softmax
'rn152' : keras_resnet.models.ResNet152,
'rn200' : keras_resnet.models.ResNet200
}
input_ = keras.layers.Input((3, None, None)) if K.image_data_format() == 'channels_first' else keras.layers.Input((None, None, 3))
if input_channels is None:
input_channels = 3
input_ = keras.layers.Input((input_channels, None, None)) if K.image_data_format() == 'channels_first' else keras.layers.Input((None, None, input_channels))
rn = factories[architecture](input_, include_top = classification and (not no_softmax), classes = num_outputs, freeze_bn = False, name = name)
if (not classification) or no_softmax:
x = keras.layers.GlobalAvgPool2D(name = 'avg_pool')(rn.outputs[-1])
Expand All @@ -258,7 +267,9 @@ def build_network(num_outputs, architecture, classification = False, no_softmax

elif architecture == 'nasnet-a':

nasnet = keras.applications.NASNetLarge(include_top=False, input_shape=(224,224,3), weights=None, pooling='avg')
if input_channels is None:
input_channels = 3
nasnet = keras.applications.NASNetLarge(include_top=False, input_shape=(224, 224, input_channels), weights=None, pooling='avg')
x = keras.layers.Dense(num_outputs, activation = 'softmax' if classification and (not no_softmax) else None, name = 'prob' if classification else 'embedding')(nasnet.output)
return keras.models.Model(nasnet.inputs, x, name=name)

Expand Down

0 comments on commit f2316f9

Please sign in to comment.