From 9ae3704ff666c0e4ef98151d6c60cdabf8b67bb6 Mon Sep 17 00:00:00 2001 From: Hao Zhang <1339754545@qq.com> Date: Thu, 19 Sep 2019 14:03:41 +0800 Subject: [PATCH 1/4] Create efficientnet.py --- gluoncv/model_zoo/efficientnet.py | 766 ++++++++++++++++++++++++++++++ 1 file changed, 766 insertions(+) create mode 100644 gluoncv/model_zoo/efficientnet.py diff --git a/gluoncv/model_zoo/efficientnet.py b/gluoncv/model_zoo/efficientnet.py new file mode 100644 index 0000000000..b274003308 --- /dev/null +++ b/gluoncv/model_zoo/efficientnet.py @@ -0,0 +1,766 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable= arguments-differ,unused-argument,missing-docstring + +import math +import collections +import re +import mxnet as mx +from mxnet.gluon.block import Block +from mxnet.gluon import nn +# Parameters for the entire model (stem, all blocks, and head) + +__all__ = ['efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', + 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', + 'efficientnet_b6', 'efficientnet_b7'] + + +# Parameters for an individual model block +BlockArgs = collections.namedtuple('BlockArgs', [ + 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', + 'expand_ratio', 'id_skip', 'stride', 'se_ratio']) + + +def round_repeats(repeats, depth_coefficient=None): + """ Round number of filters based on depth multiplier. """ + multiplier = depth_coefficient + if not multiplier: + return repeats + return int(math.ceil(multiplier * repeats)) + + +def round_filters(filters, width_coefficient=None, depth_divisor=None, min_depth=None): + """ Calculate and round number of filters based on depth multiplier. """ + multiplier = width_coefficient + if not multiplier: + return filters + divisor = depth_divisor + min_depth = min_depth + filters *= multiplier + min_depth = min_depth or divisor + new_filters = max( + min_depth, int( + filters + divisor / 2) // divisor * divisor) + if new_filters < 0.9 * filters: # prevent rounding by more than 10% + new_filters += divisor + return int(new_filters) + + +class BlockDecoder(object): + """ Block Decoder for readability, straight from the official TensorFlow repository """ + + @staticmethod + def _decode_block_string(block_string): + """ Gets a block through a string notation of arguments. """ + assert isinstance(block_string, str) + + ops = block_string.split('_') + options = {} + for op in ops: + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # Check stride + assert (('s' in options and len(options['s']) == 1) or + (len(options['s']) == 2 and + options['s'][0] == options['s'][1])) + + return BlockArgs( + kernel_size=int(options['k']), + num_repeat=int(options['r']), + input_filters=int(options['i']), + output_filters=int(options['o']), + expand_ratio=int(options['e']), + id_skip=('noskip' not in block_string), + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s'][0])) + + @staticmethod + def _encode_block_string(block): + """Encodes a block to a string.""" + args = [ + 'r%d' % block.num_repeat, + 'k%d' % block.kernel_size, + 's%d%d' % (block.strides[0], block.strides[1]), + 'e%s' % block.expand_ratio, + 'i%d' % block.input_filters, + 'o%d' % block.output_filters + ] + if 0 < block.se_ratio <= 1: + args.append('se%s' % block.se_ratio) + if block.id_skip is False: + args.append('noskip') + return '_'.join(args) + + @staticmethod + def decode(string_list): + """ + Decodes a list of string notations to specify blocks inside the network. + :param string_list: a list of strings, each string is a notation of block + :return: a list of BlockArgs namedtuples of block args + """ + assert isinstance(string_list, list) + blocks_args = [] + for block_string in string_list: + blocks_args.append(BlockDecoder._decode_block_string(block_string)) + return blocks_args + + @staticmethod + def encode(blocks_args): + """ + Encodes a list of BlockArgs to a list of strings. + :param blocks_args: a list of BlockArgs namedtuples of block args + :return: a list of strings, each string is a notation of block + """ + block_strings = [] + for block in blocks_args: + block_strings.append(BlockDecoder._encode_block_string(block)) + return block_strings + + +def efficientnet_param(): + """ Creates a efficientnet model. """ + blocks_args = [ + 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25', + 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25', + 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25', + 'r1_k3_s11_e6_i192_o320_se0.25', + ] + blocks_args = BlockDecoder.decode(blocks_args) + return blocks_args + + +class SamePadding(Block): + def __init__(self, kernel_size, stride, dilation, **kwargs): + super(SamePadding, self).__init__(**kwargs) + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 2 + if isinstance(stride, int): + stride = (stride,) * 2 + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + + def forward(self, F, x): + ih, iw = x.shape[-2:] + kh, kw = self.kernel_size + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + pad_h = max((oh - 1) * self.stride[0] + + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, mode='constant', pad_width=(0, 0, 0, 0, pad_w//2, pad_w -pad_w//2, + pad_h//2, pad_h - pad_h//2)) + return x + return x + + +def _add_conv(out, channels=1, kernel=1, stride=1, pad=0, + num_group=1, active=True, batchnorm=True): + out.add(SamePadding(kernel, stride, dilation=(1, 1))) + out.add(nn.Conv2D(channels, kernel, stride, pad, groups=num_group, use_bias=False)) + if batchnorm: + out.add(nn.BatchNorm(scale=True, momentum=0.99, epsilon=1e-3)) + if active: + out.add(nn.Swish()) + + +class MBConv(nn.Block): + def __init__(self, in_channels, channels, t, kernel, stride, se_ratio=0, + drop_connect_rate=0, **kwargs): + + r""" + Parameters + ---------- + int_channels: int, input channels. + channels: int, output channels. + t: int, the expand ratio used for increasing channels. + kernel: int, filter size. + stride: int, stride of the convolution. + se_ratio:int, ratio of the squeeze layer and excitation layer. + drop_connect_rate: int, drop rate of drop out. + """ + super(MBConv, self).__init__(**kwargs) + self.use_shortcut = stride == 1 and in_channels == channels + self.se_ratio = se_ratio + self.drop_connect_rate = drop_connect_rate + with self.name_scope(): + self.out = nn.Sequential(prefix="out_") + with self.out.name_scope(): + if t != 1: + _add_conv( + self.out, + in_channels * t, + active=True, + batchnorm=True) + _add_conv( + self.out, + in_channels * t, + kernel=kernel, + stride=stride, + num_group=in_channels * t, + active=True, + batchnorm=True) + if se_ratio: + num_squeezed_channels = max(1, int(in_channels * se_ratio)) + self._se_reduce = nn.Sequential(prefix="se_reduce_") + self._se_expand = nn.Sequential(prefix="se_expand_") + with self._se_reduce.name_scope(): + _add_conv( + self._se_reduce, + num_squeezed_channels, + active=False, + batchnorm=False) + with self._se_expand.name_scope(): + _add_conv( + self._se_expand, + in_channels * t, + active=False, + batchnorm=False) + self.project_layer = nn.Sequential(prefix="project_layer_") + with self.project_layer.name_scope(): + _add_conv( + self.project_layer, + channels, + active=False, + batchnorm=True) + if drop_connect_rate: + self.drop_out = nn.Dropout(drop_connect_rate) + + def forward(self, F, inputs): + x = inputs + x = self.out(x) + if self.se_ratio: + out = mx.nd.contrib.AdaptiveAvgPooling2D(x, 1) + out = self._se_expand(self._se_reduce(out)) + out = mx.ndarray.sigmoid(out) * x + out = self.project_layer(out) + if self.use_shortcut: + if self.drop_connect_rate: + out = self.drop_out(out) + out = F.elemwise_add(out, inputs) + return out + + +class EfficientNet(nn.Block): + + def __init__(self, blocks_args=None, + dropout_rate=None, + num_classes=None, + width_coefficient=None, + depth_cofficient=None, + depth_divisor=None, + min_depth=None, + drop_connect_rate=None, + **kwargs): + + r"""EfficientNet model from the + `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" + `_ paper. + Parameters + ---------- + blocks_args: nametuple, it concludes the hyperparameters of the MBConv block. + dropout_rate: float, rate of hidden units to drop. + num_classes: int, number of output classes. + width_coefficient:float, coefficient of the filters used for + expanding or reducing the channels. + depth_coefficient:float, it is used for repeat the EfficientNet Blocks. + depth_divisor:int , it is used for reducing the number of filters. + min_depth: int, used for deciding the minimum depth of the filters. + drop_connect_rate: used for dropout. + """ + super(EfficientNet, self).__init__(**kwargs) + assert isinstance(blocks_args, list), 'blocks_args should be a list' + assert len(blocks_args) > 0, 'block args must be greater than 0' + self._blocks_args = blocks_args + self.input_size = None + with self.name_scope(): + self.features = nn.HybridSequential(prefix='features_') + with self.features.name_scope(): + # stem conv + out_channels = round_filters(32, + width_coefficient, + depth_divisor, + min_depth) + _add_conv( + self.features, + out_channels, + kernel=3, + stride=2, + active=True, + batchnorm=True) + self._blocks = nn.HybridSequential(prefix='blocks_') + with self._blocks.name_scope(): + for block_arg in self._blocks_args: + # Update block input and output filters based on depth + # multiplier. + block_arg = block_arg._replace( + input_filters=round_filters( + block_arg.input_filters, + width_coefficient, + depth_divisor, + min_depth), + output_filters=round_filters( + block_arg.output_filters, + width_coefficient, + depth_divisor, + min_depth), + num_repeat=round_repeats( + block_arg.num_repeat, depth_cofficient)) + self._blocks.add(MBConv(block_arg.input_filters, + block_arg.output_filters, + block_arg.expand_ratio, + block_arg.kernel_size, + block_arg.stride, + block_arg.se_ratio, + drop_connect_rate) + ) + if block_arg.num_repeat > 1: + block_arg = block_arg._replace( + input_filters=block_arg.output_filters, stride=1) + for _ in range(block_arg.num_repeat - 1): + self._blocks.add( + MBConv( + block_arg.input_filters, + block_arg.output_filters, + block_arg.expand_ratio, + block_arg.kernel_size, + block_arg.stride, + block_arg.se_ratio, + drop_connect_rate)) + + + + # Head + out_channels = round_filters(1280, width_coefficient, + depth_divisor, min_depth) + self._conv_head = nn.HybridSequential(prefix='conv_head_') + with self._conv_head.name_scope(): + _add_conv( + self._conv_head, + out_channels, + active=True, + batchnorm=True) + # Final linear layer + self._dropout = dropout_rate + self._fc = nn.Dense(num_classes, use_bias=False) + + def hybrid_forward(self, F, x): + x = self.features(x) + for block in self._blocks: + x = block(x) + x = self._conv_head(x) + x = F.squeeze(F.squeeze(mx.nd.contrib.AdaptiveAvgPooling2D(x, 1), axis=-1), axis=-1) + if self._dropout: + x = F.Dropout(x, self._dropout) + x = self._fc(x) + return x + + +def efficientnet(dropout_rate=None, + num_classes=None, + width_coefficient=None, + depth_coefficient=None, + depth_divisor=None, + min_depth=None, + drop_connect_rate=None): + + blocks_args = efficientnet_param() + model = EfficientNet(blocks_args, + dropout_rate, + num_classes, + width_coefficient, + depth_coefficient, + depth_divisor, min_depth, + drop_connect_rate) + return model + + +def efficientnet_b0(pretrained=False, + dropout_rate=0.2, + classes=1000, + width_coefficient=1.0, + depth_coefficient=1.0, + depth_divisor=8, + min_depth=None, + drop_connect_rate=0.2, + ctx=mx.cpu() + ): + r"""EfficientNet model from the + `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" + `_ paper. + Parameters + ---------- + pretrained : bool or str + Boolean value controls whether to load the default pretrained weights for model. + String value represents the hashtag for a certain version of pretrained weights. + dropout_rate : float + Rate of hidden units to drop. + classes : int, number of output classes. + width_coefficient : float + Coefficient of the filters. + Used for expanding or reducing the channels. + depth_coefficient : float + It is used for repeat the EfficientNet Blocks. + depth_divisor:int + It is used for reducing the number of filters. + min_depth : int + Used for deciding the minimum depth of the filters. + drop_connect_rate : float + Used for dropout. + """ + if pretrained: + pass + model = efficientnet(dropout_rate, + classes, + width_coefficient, + depth_coefficient, + depth_divisor, + min_depth, + drop_connect_rate) + model.collect_params().initialize(ctx=ctx) + return model + + +def efficientnet_b1(pretrained=False, + dropout_rate=0.2, + classes=1000, + width_coefficient=1.0, + depth_coefficient=1.1, + depth_divisor=8, + min_depth=None, + drop_connect_rate=0.2, + ctx=mx.cpu(), + ): + r"""EfficientNet model from the + `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" + `_ paper. + Parameters + ---------- + pretrained : bool or str + Boolean value controls whether to load the default pretrained weights for model. + String value represents the hashtag for a certain version of pretrained weights. + dropout_rate : float + Rate of hidden units to drop. + classes : int, number of output classes. + width_coefficient : float + Coefficient of the filters. + Used for expanding or reducing the channels. + depth_coefficient : float + It is used for repeat the EfficientNet Blocks. + depth_divisor:int + It is used for reducing the number of filters. + min_depth : int + Used for deciding the minimum depth of the filters. + drop_connect_rate : float + Used for dropout. + """ + if pretrained: + pass + model = efficientnet(dropout_rate, + classes, + width_coefficient, + depth_coefficient, + depth_divisor, + min_depth, + drop_connect_rate) + model.collect_params().initialize(ctx=ctx) + return model + + +def efficientnet_b2(pretrained=False, + dropout_rate=0.3, + classes=1000, + width_coefficient=1.1, + depth_coefficient=1.2, + depth_divisor=8, + min_depth=None, + drop_connect_rate=0.2, + ctx=mx.cpu() + ): + r"""EfficientNet model from the + `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" + `_ paper. + Parameters + ---------- + pretrained : bool or str + Boolean value controls whether to load the default pretrained weights for model. + String value represents the hashtag for a certain version of pretrained weights. + dropout_rate : float + Rate of hidden units to drop. + classes : int, number of output classes. + width_coefficient : float + Coefficient of the filters. + Used for expanding or reducing the channels. + depth_coefficient : float + It is used for repeat the EfficientNet Blocks. + depth_divisor:int + It is used for reducing the number of filters. + min_depth : int + Used for deciding the minimum depth of the filters. + drop_connect_rate : float + Used for dropout. + """ + if pretrained: + pass + model = efficientnet(dropout_rate, + classes, + width_coefficient, + depth_coefficient, + depth_divisor, + min_depth, + drop_connect_rate) + model.collect_params().initialize(ctx=ctx) + return model + + +def efficientnet_b3(pretrained=False, + dropout_rate=0.3, + classes=1000, + width_coefficient=1.2, + depth_coefficient=1.4, + depth_divisor=8, + min_depth=None, + drop_connect_rate=0.2, + ctx=mx.cpu() + ): + r"""EfficientNet model from the + `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" + `_ paper. + Parameters + ---------- + pretrained : bool or str + Boolean value controls whether to load the default pretrained weights for model. + String value represents the hashtag for a certain version of pretrained weights. + dropout_rate : float + Rate of hidden units to drop. + classes : int, number of output classes. + width_coefficient : float + Coefficient of the filters. + Used for expanding or reducing the channels. + depth_coefficient : float + It is used for repeat the EfficientNet Blocks. + depth_divisor:int + It is used for reducing the number of filters. + min_depth : int + Used for deciding the minimum depth of the filters. + drop_connect_rate : float + Used for dropout. + """ + if pretrained: + pass + model = efficientnet(dropout_rate, + classes, + width_coefficient, + depth_coefficient, + depth_divisor, + min_depth, + drop_connect_rate) + model.collect_params().initialize(ctx=ctx) + return model + + +def efficientnet_b4(pretrained=False, + dropout_rate=0.4, + classes=1000, + width_coefficient=1.4, + depth_coefficient=1.8, + depth_divisor=8, + min_depth=None, + drop_connect_rate=0.2, + ctx=mx.cpu() + ): + r"""EfficientNet model from the + `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" + `_ paper. + Parameters + ---------- + pretrained : bool or str + Boolean value controls whether to load the default pretrained weights for model. + String value represents the hashtag for a certain version of pretrained weights. + dropout_rate : float + Rate of hidden units to drop. + classes : int, number of output classes. + width_coefficient : float + Coefficient of the filters. + Used for expanding or reducing the channels. + depth_coefficient : float + It is used for repeat the EfficientNet Blocks. + depth_divisor:int + It is used for reducing the number of filters. + min_depth : int + Used for deciding the minimum depth of the filters. + drop_connect_rate : float + Used for dropout. + """ + if pretrained: + pass + model = efficientnet(dropout_rate, + classes, + width_coefficient, + depth_coefficient, + depth_divisor, + min_depth, + drop_connect_rate, + ) + model.collect_params().initialize(ctx=ctx) + return model + + +def efficientnet_b5(pretrained=False, + dropout_rate=0.4, + classes=1000, + width_coefficient=1.6, + depth_coefficient=2.2, + depth_divisor=8, + min_depth=None, + drop_connect_rate=0.2, + ctx=mx.cpu(), + ): + r"""EfficientNet model from the + `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" + `_ paper. + Parameters + ---------- + pretrained : bool or str + Boolean value controls whether to load the default pretrained weights for model. + String value represents the hashtag for a certain version of pretrained weights. + dropout_rate : float + Rate of hidden units to drop. + classes : int, number of output classes. + width_coefficient : float + Coefficient of the filters. + Used for expanding or reducing the channels. + depth_coefficient : float + It is used for repeat the EfficientNet Blocks. + depth_divisor:int + It is used for reducing the number of filters. + min_depth : int + Used for deciding the minimum depth of the filters. + drop_connect_rate : float + Used for dropout. + """ + if pretrained: + pass + model = efficientnet(dropout_rate, + classes, + width_coefficient, + depth_coefficient, + depth_divisor, + min_depth, + drop_connect_rate) + model.collect_params().initialize(ctx=ctx) + return model + + +def efficientnet_b6(pretrained=False, + dropout_rate=0.5, + classes=1000, + width_coefficient=1.8, + depth_coefficient=2.6, + depth_divisor=8, + min_depth=None, + drop_connect_rate=0.2, + ctx=mx.cpu() + ): + r"""EfficientNet model from the + `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" + `_ paper. + Parameters + ---------- + pretrained : bool or str + Boolean value controls whether to load the default pretrained weights for model. + String value represents the hashtag for a certain version of pretrained weights. + dropout_rate : float + Rate of hidden units to drop. + classes : int, number of output classes. + width_coefficient : float + Coefficient of the filters. + Used for expanding or reducing the channels. + depth_coefficient : float + It is used for repeat the EfficientNet Blocks. + depth_divisor:int + It is used for reducing the number of filters. + min_depth : int + Used for deciding the minimum depth of the filters. + drop_connect_rate : float + Used for dropout. + """ + if pretrained: + pass + model = efficientnet(dropout_rate, + classes, + width_coefficient, + depth_coefficient, + depth_divisor, + min_depth, + drop_connect_rate) + model.collect_params().initialize(ctx=ctx) + return model + + +def efficientnet_b7(pretrained=False, + dropout_rate=0.5, + classes=1000, + width_coefficient=2.0, + depth_coefficient=3.1, + depth_divisor=8, + min_depth=None, + drop_connect_rate=0.2, + ctx=mx.cpu() + ): + r"""EfficientNet model from the + `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" + `_ paper. + Parameters + ---------- + pretrained : bool or str + Boolean value controls whether to load the default pretrained weights for model. + String value represents the hashtag for a certain version of pretrained weights. + dropout_rate : float + Rate of hidden units to drop. + classes : int, number of output classes. + width_coefficient : float + Coefficient of the filters. + Used for expanding or reducing the channels. + depth_coefficient : float + It is used for repeat the EfficientNet Blocks. + depth_divisor:int + It is used for reducing the number of filters. + min_depth : int + Used for deciding the minimum depth of the filters. + drop_connect_rate : float + Used for dropout. + """ + if pretrained: + pass + model = efficientnet(dropout_rate, + classes, + width_coefficient, + depth_coefficient, + depth_divisor, + min_depth, + drop_connect_rate, + ) + model.collect_params().initialize(ctx=ctx) + return model From 5394cca0897447f4ab6a5ba8225afc790dfc3e60 Mon Sep 17 00:00:00 2001 From: Hao Zhang <1339754545@qq.com> Date: Thu, 19 Sep 2019 14:05:03 +0800 Subject: [PATCH 2/4] Update model_zoo.py --- gluoncv/model_zoo/model_zoo.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/gluoncv/model_zoo/model_zoo.py b/gluoncv/model_zoo/model_zoo.py index dbd7882fe9..dc8129b416 100644 --- a/gluoncv/model_zoo/model_zoo.py +++ b/gluoncv/model_zoo/model_zoo.py @@ -32,6 +32,7 @@ from .yolo import * from .alpha_pose import * from .action_recognition import * +from .efficientnet import * __all__ = ['get_model', 'get_model_list'] @@ -230,7 +231,15 @@ 'psp_resnet101_voc_int8': psp_resnet101_voc_int8, 'psp_resnet101_coco_int8': psp_resnet101_coco_int8, 'deeplab_resnet101_voc_int8': deeplab_resnet101_voc_int8, - 'deeplab_resnet101_coco_int8': deeplab_resnet101_coco_int8 + 'deeplab_resnet101_coco_int8': deeplab_resnet101_coco_int8, + 'efficientnet_b0': efficientnet_b0, + 'efficientnet_b1': efficientnet_b1, + 'efficientnet_b2': efficientnet_b2, + 'efficientnet_b3': efficientnet_b3, + 'efficientnet_b4': efficientnet_b4, + 'efficientnet_b5': efficientnet_b5, + 'efficientnet_b6': efficientnet_b6, + 'efficientnet_b7': efficientnet_b7, } From 64a55a21ffccd99b60c4436b79e3e60a8379f0f3 Mon Sep 17 00:00:00 2001 From: Hao Zhang <1339754545@qq.com> Date: Thu, 19 Sep 2019 14:05:45 +0800 Subject: [PATCH 3/4] Update __init__.py --- gluoncv/model_zoo/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gluoncv/model_zoo/__init__.py b/gluoncv/model_zoo/__init__.py index ffe5a2ff57..b6052567c0 100644 --- a/gluoncv/model_zoo/__init__.py +++ b/gluoncv/model_zoo/__init__.py @@ -30,3 +30,4 @@ from .vgg import * from .mobilenet import * from .residual_attentionnet import * +from .efficientnet import * From 6d7b0c68727d7d3477d1f1bc89c0f8414c91098f Mon Sep 17 00:00:00 2001 From: Hao Zhang <1339754545@qq.com> Date: Thu, 19 Sep 2019 14:28:51 +0800 Subject: [PATCH 4/4] Update efficientnet.py --- gluoncv/model_zoo/efficientnet.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/gluoncv/model_zoo/efficientnet.py b/gluoncv/model_zoo/efficientnet.py index b274003308..dcf5958915 100644 --- a/gluoncv/model_zoo/efficientnet.py +++ b/gluoncv/model_zoo/efficientnet.py @@ -277,6 +277,7 @@ def __init__(self, blocks_args=None, r"""EfficientNet model from the `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_ paper. + Parameters ---------- blocks_args: nametuple, it concludes the hyperparameters of the MBConv block. @@ -288,6 +289,7 @@ def __init__(self, blocks_args=None, depth_divisor:int , it is used for reducing the number of filters. min_depth: int, used for deciding the minimum depth of the filters. drop_connect_rate: used for dropout. + """ super(EfficientNet, self).__init__(**kwargs) assert isinstance(blocks_args, list), 'blocks_args should be a list' @@ -295,7 +297,7 @@ def __init__(self, blocks_args=None, self._blocks_args = blocks_args self.input_size = None with self.name_scope(): - self.features = nn.HybridSequential(prefix='features_') + self.features = nn.Sequential(prefix='features_') with self.features.name_scope(): # stem conv out_channels = round_filters(32, @@ -309,7 +311,7 @@ def __init__(self, blocks_args=None, stride=2, active=True, batchnorm=True) - self._blocks = nn.HybridSequential(prefix='blocks_') + self._blocks = nn.Sequential(prefix='blocks_') with self._blocks.name_scope(): for block_arg in self._blocks_args: # Update block input and output filters based on depth @@ -354,7 +356,7 @@ def __init__(self, blocks_args=None, # Head out_channels = round_filters(1280, width_coefficient, depth_divisor, min_depth) - self._conv_head = nn.HybridSequential(prefix='conv_head_') + self._conv_head = nn.Sequential(prefix='conv_head_') with self._conv_head.name_scope(): _add_conv( self._conv_head, @@ -365,7 +367,7 @@ def __init__(self, blocks_args=None, self._dropout = dropout_rate self._fc = nn.Dense(num_classes, use_bias=False) - def hybrid_forward(self, F, x): + def forward(self, F, x): x = self.features(x) for block in self._blocks: x = block(x) @@ -409,6 +411,7 @@ def efficientnet_b0(pretrained=False, r"""EfficientNet model from the `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_ paper. + Parameters ---------- pretrained : bool or str @@ -428,6 +431,7 @@ def efficientnet_b0(pretrained=False, Used for deciding the minimum depth of the filters. drop_connect_rate : float Used for dropout. + """ if pretrained: pass @@ -455,6 +459,7 @@ def efficientnet_b1(pretrained=False, r"""EfficientNet model from the `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_ paper. + Parameters ---------- pretrained : bool or str @@ -474,6 +479,7 @@ def efficientnet_b1(pretrained=False, Used for deciding the minimum depth of the filters. drop_connect_rate : float Used for dropout. + """ if pretrained: pass @@ -501,6 +507,7 @@ def efficientnet_b2(pretrained=False, r"""EfficientNet model from the `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_ paper. + Parameters ---------- pretrained : bool or str @@ -520,6 +527,7 @@ def efficientnet_b2(pretrained=False, Used for deciding the minimum depth of the filters. drop_connect_rate : float Used for dropout. + """ if pretrained: pass @@ -547,6 +555,7 @@ def efficientnet_b3(pretrained=False, r"""EfficientNet model from the `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_ paper. + Parameters ---------- pretrained : bool or str @@ -566,6 +575,7 @@ def efficientnet_b3(pretrained=False, Used for deciding the minimum depth of the filters. drop_connect_rate : float Used for dropout. + """ if pretrained: pass @@ -593,6 +603,7 @@ def efficientnet_b4(pretrained=False, r"""EfficientNet model from the `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_ paper. + Parameters ---------- pretrained : bool or str @@ -612,6 +623,7 @@ def efficientnet_b4(pretrained=False, Used for deciding the minimum depth of the filters. drop_connect_rate : float Used for dropout. + """ if pretrained: pass @@ -640,6 +652,7 @@ def efficientnet_b5(pretrained=False, r"""EfficientNet model from the `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_ paper. + Parameters ---------- pretrained : bool or str @@ -659,6 +672,7 @@ def efficientnet_b5(pretrained=False, Used for deciding the minimum depth of the filters. drop_connect_rate : float Used for dropout. + """ if pretrained: pass @@ -686,6 +700,7 @@ def efficientnet_b6(pretrained=False, r"""EfficientNet model from the `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_ paper. + Parameters ---------- pretrained : bool or str @@ -705,6 +720,7 @@ def efficientnet_b6(pretrained=False, Used for deciding the minimum depth of the filters. drop_connect_rate : float Used for dropout. + """ if pretrained: pass @@ -732,6 +748,7 @@ def efficientnet_b7(pretrained=False, r"""EfficientNet model from the `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_ paper. + Parameters ---------- pretrained : bool or str @@ -751,6 +768,7 @@ def efficientnet_b7(pretrained=False, Used for deciding the minimum depth of the filters. drop_connect_rate : float Used for dropout. + """ if pretrained: pass