From 5921740d9f69d853fe543c52811b590a398031f5 Mon Sep 17 00:00:00 2001
From: Hang Zhang <8041160+zhanghang1989@users.noreply.github.com>
Date: Mon, 25 Jun 2018 22:42:29 -0600
Subject: [PATCH] Add PSPNet (#179)

* psp
---
 docs/api/model_zoo.rst                    |  18 +-
 docs/model_zoo/index.rst                  |  17 +-
 docs/tutorials/segmentation/demo_fcn.py   |   2 +-
 docs/tutorials/segmentation/demo_psp.py   |  65 +++++
 docs/tutorials/segmentation/train_fcn.py  |   6 +-
 docs/tutorials/segmentation/train_psp.py  | 231 +++++++++++++++++
 gluoncv/model_zoo/fcn.py                  |   3 +-
 gluoncv/model_zoo/model_store.py          |   1 +
 gluoncv/model_zoo/model_zoo.py            |   2 +
 gluoncv/model_zoo/pspnet.py               |  10 +-
 gluoncv/model_zoo/resnetv1b.py            |  20 +-
 gluoncv/model_zoo/syncbn.py               | 294 +++++++++++-----------
 gluoncv/utils/metrics/voc_segmentation.py |  10 +-
 gluoncv/utils/parallel.py                 |  18 +-
 scripts/segmentation/test.py              |   2 +-
 15 files changed, 495 insertions(+), 204 deletions(-)
 create mode 100644 docs/tutorials/segmentation/demo_psp.py
 create mode 100644 docs/tutorials/segmentation/train_psp.py

diff --git a/docs/api/model_zoo.rst b/docs/api/model_zoo.rst
index 74987dce40..87a721ccc0 100644
--- a/docs/api/model_zoo.rst
+++ b/docs/api/model_zoo.rst
@@ -119,18 +119,10 @@ Object Detection
     faster_rcnn_resnet50_v2a_coco

-.. currentmodule:: gluoncv.model_zoo
-
 Semantic Segmentation
 ^^^^^^^^^^^^^^^^^^^^^

-.. :hidden:`BaseModel`
-.. ~~~~~~~~~~~~~~~~~~~
-
-.. .. autosummary::
-..     :nosignatures:
-
-..     segbase.SegBaseModel
+.. currentmodule:: gluoncv.model_zoo

 :hidden:`FCN`
 ~~~~~~~~~~~~~
@@ -148,11 +140,17 @@ Semantic Segmentation

     get_fcn_ade_resnet50

+:hidden:`PSPNet`
+~~~~~~~~~~~~~~~~

 .. autosummary::
     :nosignatures:

     PSPNet

     get_psp
-
+    get_psp_ade_resnet50

 API Reference

diff --git a/docs/model_zoo/index.rst b/docs/model_zoo/index.rst
index 96fd96b96b..4023bf66c3 100644
--- a/docs/model_zoo/index.rst
+++ b/docs/model_zoo/index.rst
@@ -231,25 +231,12 @@ Table of pre-trained models for semantic segmentation and their performance.
 +-------------------+--------------+-----------+-----------+-------------------+----------+
 | fcn_resnet50_ade  | FCN [6]_     | 78.6      | 38.7      | `shell script `_  | `log `_  |
 +-------------------+--------------+-----------+-----------+-------------------+----------+
+| psp_resnet50_ade  | PSP [9]_     | 78.4      | 41.1      | `shell script `_  | `log `_  |
++-------------------+--------------+-----------+-----------+-------------------+----------+

 .. _69.4: http://host.robots.ox.ac.uk:8080/anonymous/TC12D2.html
 .. _70.9: http://host.robots.ox.ac.uk:8080/anonymous/FTIQXJ.html

-.. raw:: html
-
-
-
 .. [1] He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. \
    "Deep residual learning for image recognition." \

diff --git a/docs/tutorials/segmentation/demo_fcn.py b/docs/tutorials/segmentation/demo_fcn.py
index 063cea9a90..2ce990e813 100644
--- a/docs/tutorials/segmentation/demo_fcn.py
+++ b/docs/tutorials/segmentation/demo_fcn.py
@@ -1,7 +1,7 @@
 """1. Getting Started with FCN Pre-trained Models
 ==============================================
-This is a quick demo of using GluonCV FCN model.
+This is a quick demo of using a GluonCV FCN model on the PASCAL VOC dataset.
 Please follow the `installation guide <../index.html>`_ to install MXNet and GluonCV if not yet.
 """
 import mxnet as mx
diff --git a/docs/tutorials/segmentation/demo_psp.py b/docs/tutorials/segmentation/demo_psp.py
new file mode 100644
index 0000000000..849137df74
--- /dev/null
+++ b/docs/tutorials/segmentation/demo_psp.py
@@ -0,0 +1,65 @@
+"""2. Test with PSPNet Pre-trained Models
+======================================
+
+This is a quick demo of using a GluonCV PSPNet model on the ADE20K dataset.
+Please follow the `installation guide <../index.html>`_ to install MXNet and GluonCV if not yet.
+"""
+import mxnet as mx
+from mxnet import image
+from mxnet.gluon.data.vision import transforms
+import gluoncv
+# using cpu
+ctx = mx.cpu(0)
+
+
+##############################################################################
+# Prepare the image
+# -----------------
+#
+# download the example image
+url = 'https://github.com/zhanghang1989/image-data/blob/master/encoding/' + \
+    'segmentation/ade20k/ADE_val_00001142.jpg?raw=true'
+filename = 'ade20k_example.jpg'
+gluoncv.utils.download(url, filename)
+
+##############################################################################
+# load the image
+img = image.imread(filename)
+
+from matplotlib import pyplot as plt
+plt.imshow(img.asnumpy())
+plt.show()
+
+##############################################################################
+# normalize the image using dataset mean
+transform_fn = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize([.485, .456, .406], [.229, .224, .225])
+])
+img = transform_fn(img)
+img = img.expand_dims(0).as_in_context(ctx)
+
+##############################################################################
+# Load the pre-trained model and make prediction
+# ----------------------------------------------
+#
+# get pre-trained model
+model = gluoncv.model_zoo.get_model('psp_resnet50_ade', pretrained=True)
+
+##############################################################################
+# make prediction using a single scale
+output = model.demo(img)
+predict = mx.nd.squeeze(mx.nd.argmax(output, 1)).asnumpy()
+
+##############################################################################
+# Add color palette for visualization
+from gluoncv.utils.viz import get_color_pallete
+import matplotlib.image as mpimg
+mask = get_color_pallete(predict, 'ade20k')
+mask.save('output.png')
+
+##############################################################################
+# show the predicted mask
+mmask = mpimg.imread('output.png')
+plt.imshow(mmask)
+plt.show()
diff --git a/docs/tutorials/segmentation/train_fcn.py b/docs/tutorials/segmentation/train_fcn.py
index 1a6747a44f..00ca2140bf 100644
--- a/docs/tutorials/segmentation/train_fcn.py
+++ b/docs/tutorials/segmentation/train_fcn.py
@@ -1,4 +1,4 @@
-"""2. Train FCN on Pascal VOC Dataset
+"""3. Train FCN on Pascal VOC Dataset
 =====================================

 This is a step-by-step semantic segmentation tutorial using Gluon Vision.
@@ -133,8 +133,8 @@
 # For example, we can easily get the Pascal VOC 2012 dataset:
 trainset = gluoncv.data.VOCSegmentation(split='train', transform=input_transform)
 print('Training images:', len(trainset))
-# set batch_size = 4 for toy example
-batch_size = 4
+# set batch_size = 2 for toy example
+batch_size = 2
 # Create Training Loader
 train_data = gluon.data.DataLoader(
     trainset, batch_size, shuffle=True, last_batch='rollover',
diff --git a/docs/tutorials/segmentation/train_psp.py b/docs/tutorials/segmentation/train_psp.py
new file mode 100644
index 0000000000..b6eb4db2d0
--- /dev/null
+++ b/docs/tutorials/segmentation/train_psp.py
@@ -0,0 +1,231 @@
+"""3. Train PSPNet on ADE20K Dataset
+=================================
+
+This is a tutorial on training PSPNet on the ADE20K dataset using Gluon Vision.
+Readers should have a basic knowledge of deep learning and be familiar with the Gluon API.
+New users may first go through `A 60-minute Gluon Crash Course `_.
+You can `Start Training Now`_ or `Dive into Deep`_.
+
+Start Training Now
+~~~~~~~~~~~~~~~~~~
+
+.. note::
+
+    Training PSPNet relies on Synchronized Batch Normalization, which will be available shortly.
+
+.. hint::
+
+    Feel free to skip the tutorial because the training script is self-contained and ready to launch.
+
+    :download:`Download Full Python Script: train.py<../../../scripts/segmentation/train.py>`
+
+    Example training command::
+
+        CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset ade20k --model psp --backbone resnet50 --lr 0.001 --checkname mycheckpoint
+
+    For more training command options, please run ``python train.py -h``
+    Please check out the `model_zoo <../model_zoo/index.html#semantic-segmentation>`_ for the training commands used to reproduce the pretrained models.
+
+Dive into Deep
+~~~~~~~~~~~~~~
+"""
+import numpy as np
+import mxnet as mx
+from mxnet import gluon, autograd
+import gluoncv
+
+##############################################################################
+# Pyramid Scene Parsing Network
+# -----------------------------
+#
+# .. image:: https://hszhao.github.io/projects/pspnet/figures/pspnet.png
+#     :width: 80%
+#     :align: center
+#
+# (figure credit to `Zhao et al. `_ )
+#
+# Pyramid Scene Parsing Network (PSPNet) [Zhao17]_ exploits the
+# capability of global context information by different-region-based
+# context aggregation through the pyramid pooling module.
+#
+
+
+##############################################################################
+# PSPNet Model
+# ------------
+#
+# A Pyramid Pooling Module is built on top of FCN to combine multi-scale
+# features with different receptive-field sizes. It pools the feature maps
+# into different sizes and then concatenates them together after upsampling.
+#
+# The Pyramid Pooling Module is defined as::
+#
+#     class _PyramidPooling(HybridBlock):
+#         def __init__(self, in_channels, **kwargs):
+#             super(_PyramidPooling, self).__init__()
+#             out_channels = int(in_channels/4)
+#             with self.name_scope():
+#                 self.conv1 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
+#                 self.conv2 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
+#                 self.conv3 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
+#                 self.conv4 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
+#
+#         def pool(self, F, x, size):
+#             return F.contrib.AdaptiveAvgPooling2D(x, output_size=size)
+#
+#         def upsample(self, F, x, h, w):
+#             return F.contrib.BilinearResize2D(x, height=h, width=w)
+#
+#         def hybrid_forward(self, F, x):
+#             _, _, h, w = x.shape
+#             feat1 = self.upsample(F, self.conv1(self.pool(F, x, 1)), h, w)
+#             feat2 = self.upsample(F, self.conv2(self.pool(F, x, 2)), h, w)
+#             feat3 = self.upsample(F, self.conv3(self.pool(F, x, 3)), h, w)
+#             feat4 = self.upsample(F, self.conv4(self.pool(F, x, 4)), h, w)
+#             return F.concat(x, feat1, feat2, feat3, feat4, dim=1)
+#
+# The PSPNet model is provided in :class:`gluoncv.model_zoo.PSPNet`. To get
+# a PSP model using a ResNet50 base network for the ADE20K dataset:
+model = gluoncv.model_zoo.get_psp(dataset='ade20k', backbone='resnet50', pretrained=False)
+print(model)
+
+##############################################################################
+# Dataset and Data Augmentation
+# -----------------------------
+#
+# image transform for color normalization
+from mxnet.gluon.data.vision import transforms
+input_transform = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
+])
+
+##############################################################################
+# We provide semantic segmentation datasets in :class:`gluoncv.data`.
+# For example, we can easily get the ADE20K dataset:
+trainset = gluoncv.data.ADE20KSegmentation(split='train', transform=input_transform)
+print('Training images:', len(trainset))
+# set batch_size = 2 for toy example
+batch_size = 2
+# Create Training Loader
+train_data = gluon.data.DataLoader(
+    trainset, batch_size, shuffle=True, last_batch='rollover',
+    num_workers=batch_size)
+
+##############################################################################
+# For data augmentation,
+# we follow the standard data augmentation routine to transform the input image
+# and the ground truth label map synchronously. (*Note that "nearest"-mode
+# upsampling is applied to the label maps to avoid messing up the boundaries.*)
+# We first randomly scale the input image from 0.5 to 2.0 times, then rotate
+# the image from -10 to 10 degrees, and crop the image with padding if needed.
+# Finally, a random Gaussian blur is applied.
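+#
+# A minimal sketch of such a synchronized transform (assuming PIL images;
+# ``sync_transform`` is a hypothetical helper for illustration, not the exact
+# routine implemented by the dataset class)::
+#
+#     import random
+#     from PIL import Image, ImageFilter, ImageOps
+#
+#     def sync_transform(img, mask, crop_size=480):
+#         # random scale from 0.5 to 2.0 times
+#         w, h = img.size
+#         scale = random.uniform(0.5, 2.0)
+#         ow, oh = int(w * scale), int(h * scale)
+#         img = img.resize((ow, oh), Image.BILINEAR)
+#         # "nearest" mode keeps label values intact
+#         mask = mask.resize((ow, oh), Image.NEAREST)
+#         # random rotation from -10 to 10 degrees
+#         deg = random.uniform(-10, 10)
+#         img = img.rotate(deg, resample=Image.BILINEAR)
+#         mask = mask.rotate(deg, resample=Image.NEAREST)
+#         # pad if the scaled image is smaller than the crop window
+#         padw = max(0, crop_size - ow)
+#         padh = max(0, crop_size - oh)
+#         if padw > 0 or padh > 0:
+#             img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
+#             mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
+#         # crop the same random window from image and mask
+#         w, h = img.size
+#         x = random.randint(0, w - crop_size)
+#         y = random.randint(0, h - crop_size)
+#         img = img.crop((x, y, x + crop_size, y + crop_size))
+#         mask = mask.crop((x, y, x + crop_size, y + crop_size))
+#         # random Gaussian blur on the image only
+#         if random.random() < 0.5:
+#             img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
+#         return img, mask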
+#
+# Randomly pick one example for visualization:
+import random
+from datetime import datetime
+random.seed(datetime.now())
+idx = random.randint(0, len(trainset) - 1)
+img, mask = trainset[idx]
+from gluoncv.utils.viz import get_color_pallete, DeNormalize
+# get the color palette for visualizing the mask
+mask = get_color_pallete(mask.asnumpy(), dataset='ade20k')
+mask.save('mask.png')
+# denormalize the image
+img = DeNormalize([.485, .456, .406], [.229, .224, .225])(img)
+img = np.transpose((img.asnumpy()*255).astype(np.uint8), (1, 2, 0))
+
+##############################################################################
+# Plot the image and mask
+from matplotlib import pyplot as plt
+import matplotlib.image as mpimg
+# subplot 1 for img
+fig = plt.figure()
+fig.add_subplot(1,2,1)
+
+plt.imshow(img)
+# subplot 2 for the mask
+mmask = mpimg.imread('mask.png')
+fig.add_subplot(1,2,2)
+plt.imshow(mmask)
+# display
+plt.show()
+
+##############################################################################
+# Training Details
+# ----------------
+#
+# - Training Losses:
+#
+#     We apply a standard per-pixel Softmax Cross Entropy Loss to train PSPNet.
+#     Additionally, an auxiliary loss at Stage 3, as in PSPNet [Zhao17]_, can be enabled
+#     with the command-line option ``--aux``. This creates an additional FCN "head" after Stage 3.
+#
+from gluoncv.model_zoo.segbase import SoftmaxCrossEntropyLossWithAux
+criterion = SoftmaxCrossEntropyLossWithAux(aux=True)
+
+##############################################################################
+# - Learning Rate and Scheduling:
+#
+#     We use different learning rates for the PSP "head" and the base network. For the PSP "head",
+#     we use :math:`10\times` the base learning rate, because those layers are learned from scratch.
+#     We use a poly-like learning rate scheduler for PSPNet training, provided in :class:`gluoncv.utils.LRScheduler`.
+#     The learning rate is given by :math:`lr = baselr \times (1 - \frac{iter}{total\_iter})^{power}`
+#
+lr_scheduler = gluoncv.utils.LRScheduler(mode='poly', baselr=0.001, niters=len(train_data),
+                                         nepochs=50)
+
+##############################################################################
+# - DataParallel for multi-GPU training, using CPU for this demo only
+from gluoncv.utils.parallel import *
+ctx_list = [mx.cpu(0)]
+model = DataParallelModel(model, ctx_list)
+criterion = DataParallelCriterion(criterion, ctx_list)
+
+##############################################################################
+# - Create SGD solver
+kv = mx.kv.create('local')
+optimizer = gluon.Trainer(model.module.collect_params(), 'sgd',
+                          {'lr_scheduler': lr_scheduler,
+                           'wd':0.0001,
+                           'momentum': 0.9,
+                           'multi_precision': True},
+                          kvstore = kv)
+
+##############################################################################
+# The training loop
+# -----------------
+#
+train_loss = 0.0
+epoch = 0
+for i, (data, target) in enumerate(train_data):
+    lr_scheduler.update(i, epoch)
+    with autograd.record(True):
+        outputs = model(data)
+        losses = criterion(outputs, target)
+        mx.nd.waitall()
+        autograd.backward(losses)
+    optimizer.step(batch_size)
+    for loss in losses:
+        train_loss += loss.asnumpy()[0] / len(losses)
+    print('Epoch %d, batch %d, training loss %.3f'%(epoch, i, train_loss/(i+1)))
+    # just demo for 2 iters
+    if i > 1:
+        print('Terminated for this demo...')
+        break
+
+
+##############################################################################
+# You can `Start Training Now`_.
+#
+# References
+# ----------
+#
+# .. [Long15] Long, Jonathan, Evan Shelhamer, and Trevor Darrell. \
+#     "Fully convolutional networks for semantic segmentation." \
+#     Proceedings of the IEEE conference on computer vision and pattern recognition. 2015.
+#
+# .. [Zhao17] Zhao, Hengshuang, Jianping Shi, Xiaojuan Qi, Xiaogang Wang, and Jiaya Jia. \
+#     "Pyramid scene parsing network." IEEE Conf. on Computer Vision and Pattern Recognition (CVPR). 2017.
+#
+
diff --git a/gluoncv/model_zoo/fcn.py b/gluoncv/model_zoo/fcn.py
index b5a153811b..06ebcc1dcd 100644
--- a/gluoncv/model_zoo/fcn.py
+++ b/gluoncv/model_zoo/fcn.py
@@ -36,7 +36,8 @@ class FCN(SegBaseModel):
     # pylint: disable=arguments-differ
     def __init__(self, nclass, backbone='resnet50', norm_layer=nn.BatchNorm,
                  aux=True, ctx=cpu(), **kwargs):
-        super(FCN, self).__init__(nclass, aux, backbone, ctx=ctx, norm_layer=norm_layer, **kwargs)
+        super(FCN, self).__init__(nclass, aux, backbone, ctx=ctx,
+                                  norm_layer=norm_layer, **kwargs)
         with self.name_scope():
             self.head = _FCNHead(2048, nclass, norm_layer=norm_layer, **kwargs)
             self.head.initialize(ctx=ctx)
diff --git a/gluoncv/model_zoo/model_store.py b/gluoncv/model_zoo/model_store.py
index cb3e2ac0e7..df84b027ae 100644
--- a/gluoncv/model_zoo/model_store.py
+++ b/gluoncv/model_zoo/model_store.py
@@ -37,6 +37,7 @@
     ('953657f235cc52dbc60f3874f9d437c380045cd0', 'fcn_resnet50_voc'),
     ('70a6f22a1a0b6ddd1f680de587d67b5c2c0acc0b', 'fcn_resnet101_voc'),
     ('b1b11976bf753ed1e05a526065e1666950fcf0a2', 'fcn_resnet50_ade'),
+    ('3133bd42540ffee84a54e95468de74560a4a27b9', 'psp_resnet50_ade'),
     ('b8f9c5f193d84b828fa78d7bd7646ca7f11a29bc', 'resnet50_v2a'),
     ]}
diff --git a/gluoncv/model_zoo/model_zoo.py b/gluoncv/model_zoo/model_zoo.py
index a8b47271e8..0221a630ea 100644
--- a/gluoncv/model_zoo/model_zoo.py
+++ b/gluoncv/model_zoo/model_zoo.py
@@ -6,6 +6,7 @@
 from .ssd import *
 from .faster_rcnn import *
 from .fcn import *
+from .pspnet import *
 from .cifarresnet import *
 from .cifarresnext import *
 from .cifarwideresnet import *
@@ -65,6 +66,7 @@ def get_model(name, **kwargs):
         'fcn_resnet50_voc' : get_fcn_voc_resnet50,
         'fcn_resnet101_voc' : get_fcn_voc_resnet101,
         'fcn_resnet50_ade' : get_fcn_ade_resnet50,
+        'psp_resnet50_ade' : get_psp_ade_resnet50,
         'resnet18_v1b' : resnet18_v1b,
         'resnet34_v1b' : resnet34_v1b,
         'resnet50_v1b' : resnet50_v1b,
diff --git a/gluoncv/model_zoo/pspnet.py b/gluoncv/model_zoo/pspnet.py
index 60bc6990ea..6213c40b46 100644
--- a/gluoncv/model_zoo/pspnet.py
+++ b/gluoncv/model_zoo/pspnet.py
@@ -36,11 +36,11 @@ def __init__(self, nclass, backbone='resnet50', norm_layer=nn.BatchNorm,
                          norm_layer=norm_layer, **kwargs)
         with self.name_scope():
             self.head = _PSPHead(nclass, norm_layer=norm_layer, **kwargs)
-            self.head.initialize()
+            self.head.initialize(ctx=ctx)
             self.head.collect_params().setattr('lr_mult', 10)
             if self.aux:
                 self.auxlayer = _FCNHead(1024, nclass, norm_layer=norm_layer, **kwargs)
-                self.auxlayer.initialize()
+                self.auxlayer.initialize(ctx=ctx)
                 self.auxlayer.collect_params().setattr('lr_mult', 10)

     def hybrid_forward(self, F, x):
@@ -62,10 +62,10 @@ def hybrid_forward(self, F, x):
 def _PSP1x1Conv(in_channels, out_channels, norm_layer=None, **kwargs):
     block = nn.HybridSequential(prefix='')
     with block.name_scope():
-        block.add(norm_layer(in_channels=in_channels))
-        block.add(nn.Activation('relu'))
         block.add(nn.Conv2D(in_channels=in_channels, channels=out_channels,
                             kernel_size=1))
+        block.add(norm_layer(in_channels=out_channels))
+        block.add(nn.Activation('relu'))
     return block

@@ -100,8 +100,6 @@ def __init__(self, nclass, norm_layer=None, **kwargs):
         self.psp = _PyramidPooling(2048,
                                   norm_layer=norm_layer, **kwargs)
         with self.name_scope():
             self.block = nn.HybridSequential(prefix='')
-            self.block.add(norm_layer(in_channels=4096))
-            self.block.add(nn.Activation('relu'))
             self.block.add(nn.Conv2D(in_channels=4096, channels=512,
                                      kernel_size=3, padding=1))
             self.block.add(norm_layer(in_channels=512))
diff --git a/gluoncv/model_zoo/resnetv1b.py b/gluoncv/model_zoo/resnetv1b.py
index d70094e309..1f95883dc4 100644
--- a/gluoncv/model_zoo/resnetv1b.py
+++ b/gluoncv/model_zoo/resnetv1b.py
@@ -22,12 +22,12 @@ def __init__(self, inplanes, planes, strides=1, dilation=1, downsample=None,
         self.conv1 = nn.Conv2D(in_channels=inplanes, channels=planes, kernel_size=3,
                                strides=strides, padding=dilation,
                                dilation=dilation, use_bias=False)
-        self.bn1 = nn.BatchNorm(in_channels=planes)
+        self.bn1 = norm_layer(in_channels=planes)
         self.relu = nn.Activation('relu')
         self.conv2 = nn.Conv2D(in_channels=planes, channels=planes, kernel_size=3, strides=1,
                                padding=previous_dilation, dilation=previous_dilation,
                                use_bias=False)
-        self.bn2 = nn.BatchNorm(in_channels=planes)
+        self.bn2 = norm_layer(in_channels=planes)
         self.downsample = downsample
         self.strides = strides

@@ -60,11 +60,11 @@ def __init__(self, inplanes, planes, strides=1, dilation=1,
                  last_gamma=False, **kwargs):
         super(BottleneckV1b, self).__init__()
         self.conv1 = nn.Conv2D(in_channels=inplanes, channels=planes, kernel_size=1, use_bias=False)
-        self.bn1 = nn.BatchNorm(in_channels=planes)
+        self.bn1 = norm_layer(in_channels=planes)
         self.conv2 = nn.Conv2D(
             in_channels=planes, channels=planes, kernel_size=3, strides=strides,
             padding=dilation, dilation=dilation, use_bias=False)
-        self.bn2 = nn.BatchNorm(in_channels=planes)
+        self.bn2 = norm_layer(in_channels=planes)
         self.conv3 = nn.Conv2D(
             in_channels=planes, channels=planes * 4, kernel_size=1, use_bias=False)
         if not last_gamma:
@@ -115,7 +115,7 @@ class ResNetV1b(HybridBlock):
         Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
         typically used in Semantic Segmentation.
     norm_layer : object
-        Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
+        Normalization layer used in backbone network (default: :class:`mxnet.gluon.norm_layer`;
         for Synchronized Cross-GPU BatchNormalization).


@@ -219,7 +219,7 @@ def resnet18_v1b(pretrained=False, root='~/.mxnet/models', ctx=cpu(0), **kwargs)
     dilated: bool, default False
         Whether to apply the dilation strategy to ResNetV1b, yielding a stride-8 model.
     norm_layer : object
-        Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
+        Normalization layer used in backbone network (default: :class:`mxnet.gluon.norm_layer`;
         for Synchronized Cross-GPU BatchNormalization).
     """
     model = ResNetV1b(BasicBlockV1b, [2, 2, 2, 2], **kwargs)
@@ -244,7 +244,7 @@ def resnet34_v1b(pretrained=False, root='~/.mxnet/models', ctx=cpu(0), **kwargs)
     dilated: bool, default False
         Whether to apply the dilation strategy to ResNetV1b, yielding a stride-8 model.
     norm_layer : object
-        Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
+        Normalization layer used in backbone network (default: :class:`mxnet.gluon.norm_layer`;
     """
     model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs)
     if pretrained:
@@ -268,7 +268,7 @@ def resnet50_v1b(pretrained=False, root='~/.mxnet/models', ctx=cpu(0), **kwargs)
     dilated: bool, default False
         Whether to apply the dilation strategy to ResNetV1b, yielding a stride-8 model.
    norm_layer : object
-        Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
+        Normalization layer used in backbone network (default: :class:`mxnet.gluon.norm_layer`;
     """
     model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], **kwargs)
     if pretrained:
@@ -292,7 +292,7 @@ def resnet101_v1b(pretrained=False, root='~/.mxnet/models', ctx=cpu(0), **kwargs
     dilated: bool, default False
         Whether to apply the dilation strategy to ResNetV1b, yielding a stride-8 model.
     norm_layer : object
-        Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
+        Normalization layer used in backbone network (default: :class:`mxnet.gluon.norm_layer`;
     """
     model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], **kwargs)
     if pretrained:
@@ -316,7 +316,7 @@ def resnet152_v1b(pretrained=False, root='~/.mxnet/models', ctx=cpu(0), **kwargs
     dilated: bool, default False
         Whether to apply the dilation strategy to ResNetV1b, yielding a stride-8 model.
     norm_layer : object
-        Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
+        Normalization layer used in backbone network (default: :class:`mxnet.gluon.norm_layer`;
     """
     model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], **kwargs)
     if pretrained:
diff --git a/gluoncv/model_zoo/syncbn.py b/gluoncv/model_zoo/syncbn.py
index 5fab645ecb..c8a7d3f208 100644
--- a/gluoncv/model_zoo/syncbn.py
+++ b/gluoncv/model_zoo/syncbn.py
@@ -1,13 +1,111 @@
 """Synchronized Cross GPU Batch Normalization"""
 import threading
-from mxnet import autograd, test_utils
-from mxnet.gluon import HybridBlock
+import mxnet as mx
+from mxnet import gluon, autograd, test_utils
+
+class SharedTensor(object):
+    """Shared Tensor for Syncing"""
+    def __init__(self, key, nchannels, num_devices):
+        self._mutex = threading.Lock()
+        self._all_tasks_done = threading.Condition(self._mutex)
+        self._key = key
+        self.num_devices = int(num_devices)
+        self.out = mx.nd.zeros(nchannels)
+        self._clear()
+
+    def _clear(self):
+        self.list = []
+        self.push_tasks = self.num_devices
+        self.reduce_tasks = self.num_devices
+
+    def push(self, t):
+        """push value to SharedTensor"""
+        with self._mutex:
+            if self.push_tasks == 0:
+                self._clear()
+            #t.wait_to_read()
+            self.list.append(t)
+            self.push_tasks -= 1
+        with self._all_tasks_done:
+            if self.push_tasks == 0:
+                self._all_tasks_done.notify_all()
+            while self.push_tasks:
+                self._all_tasks_done.wait()
+
+    def _reduce(self, kv):
+        with self._mutex:
+            if self.reduce_tasks == 1:
+                assert(len(self.list) == self.num_devices)
+                kv.push(self._key, self.list)
+                self.reduce_tasks -= 1
+            else:
+                self.reduce_tasks -= 1
+        with self._all_tasks_done:
+            if self.reduce_tasks == 0:
+                self._all_tasks_done.notify_all()
+            while self.reduce_tasks:
+                self._all_tasks_done.wait()
+
+    def pull(self, kv):
+        """Get value from SharedTensor"""
+        self._reduce(kv)
+        kv.pull(self._key, out=self.out)
+        return self.out

-import numpy as np
+    def __len__(self):
+        return len(self.list)

-class BatchNorm(HybridBlock):
+class SharedTDict(object):
+    """Shared Dict for Syncing"""
+    def __init__(self):
+        self.stdict = {}
+        self.keys = []
+        self._mutex = threading.Lock()
+        self.kv = mx.kv.create('local')
+
+    def register(self, key, nchannels, num_devices):
+        with self._mutex:
+            if key in self.keys:
+                return
+            print('registering {}'.format(key))
+            self.stdict[key] = SharedTensor(key, nchannels, num_devices)
+            self.kv.init(key, mx.nd.zeros(nchannels))
+            self.keys.append(key)
+
+    def push(self, key, value):
+        self.stdict[key].push(value)
+
+    def pull(self, key):
+        out = self.stdict[key].pull(self.kv)
+        return out
+
+sharedTensorDict = SharedTDict()
+
+class AllReduce(autograd.Function):
+    """All Reduce Operation"""
+    def __init__(self, key):
+        super(AllReduce, self).__init__()
+        self.xsumkey = key + 'sum'
+        self.xsqukey = key + 'squ'
+
+    def forward(self, isum, isqu):
+        sharedTensorDict.push(self.xsumkey, isum)
+        sharedTensorDict.push(self.xsqukey, isqu)
+        osum = sharedTensorDict.pull(self.xsumkey).as_in_context(isum.context)
+        osqu = sharedTensorDict.pull(self.xsqukey).as_in_context(isqu.context)
+        return osum, osqu
+
+    def backward(self, dsum, dsqu):
+        sharedTensorDict.push(self.xsumkey, dsum)
+        sharedTensorDict.push(self.xsqukey, dsqu)
+        disum = sharedTensorDict.pull(self.xsumkey).as_in_context(dsum.context)
+        disqu = sharedTensorDict.pull(self.xsqukey).as_in_context(dsqu.context)
+        return disum, disqu
+
+
+class BatchNorm(gluon.nn.BatchNorm):
     """Cross-GPU Synchronized Batch normalization (SyncBN)
     Standard BN [1]_ implementation only normalizes the data within each device.
     SyncBN normalizes the input within the whole mini-batch.
@@ -47,7 +145,7 @@ class BatchNorm(HybridBlock):
         Number of channels (feature maps) in input data. If not specified,
         initialization will be deferred to the first time `forward` is called
         and `in_channels` will be inferred from the shape of input data.
-    nGPUs : int, default number of visible GPUs
+    num_devices : int, default number of visible GPUs
     Inputs:
         - **data**: input tensor with arbitrary shape.
     Outputs:
        - **out**: output tensor with the same shape as `data`.
     Reference:
        .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating
           deep network training by reducing internal covariate shift." *ICML 2015*
        .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang,
           Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*
     """
     # pylint: disable=arguments-differ
-    def __init__(self, momentum=0.9, epsilon=1e-5, center=True, scale=True,
-                 beta_initializer='zeros', gamma_initializer='ones',
-                 running_mean_initializer='zeros', running_variance_initializer='ones',
-                 in_channels=0, nGPUs=None, **kwargs):
-        super(BatchNorm, self).__init__(**kwargs)
-        self._kwargs = {'eps': epsilon, 'momentum': momentum,
-                        'fix_gamma': not scale}
-        if in_channels != 0:
-            self.in_channels = in_channels
+    def __init__(self, in_channels, axis=1, momentum=0.9, epsilon=1e-5, ndevices=None, **kwargs):
+        super(BatchNorm, self).__init__(axis, momentum, epsilon, in_channels=in_channels, **kwargs)
+        self.eps = epsilon
         self.momentum = momentum
-
-        self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
-                                     shape=(in_channels,), init=gamma_initializer,
-                                     allow_deferred_init=True,
-                                     differentiable=scale)
-        self.beta = self.params.get('beta', grad_req='write' if center else 'null',
-                                    shape=(in_channels,), init=beta_initializer,
-                                    allow_deferred_init=True,
-                                    differentiable=center)
-        self.running_mean = self.params.get('running_mean', grad_req='null',
-                                            shape=(in_channels,),
-                                            init=running_mean_initializer,
-                                            allow_deferred_init=True,
-                                            differentiable=False)
-        self.running_var = self.params.get('running_var', grad_req='null',
-                                           shape=(in_channels,),
-                                           init=running_variance_initializer,
-                                           allow_deferred_init=True,
-                                           differentiable=False)
-        if nGPUs is None:
-            nGPUs = self._get_nGPUs()
-        self.xsum = _SharedTensor(nGPUs)
-        self.xsqu = _SharedTensor(nGPUs)
-        self.updater = _SharedUpdater(nGPUs)
-
-    def _get_nGPUs(self):
-        # caution: if not using all the GPUs, please manually set nGPUs
-        nGPUs = len(test_utils.list_gpus())
+        self.in_channels = in_channels
+        self.ndevices = self._get_num_devices() if ndevices is None else ndevices
+        self.updater = _SharedUpdater(self.ndevices)
+        sharedTensorDict.register(self._prefix +
+                                  'sum', in_channels, self.ndevices)
+        sharedTensorDict.register(self._prefix + 'squ', in_channels, self.ndevices)
+
+    def _get_num_devices(self):
+        # caution: if not using all the GPUs, please manually set num_devices
+        num_devices = len(test_utils.list_gpus())
         # for CPU
-        nGPUs = nGPUs if nGPUs > 0 else 1
-        return nGPUs
-
-    def cast(self, dtype):
-        if np.dtype(dtype).name == 'float16':
-            dtype = 'float32'
-        super(BatchNorm, self).cast(dtype)
+        num_devices = num_devices if num_devices > 0 else 1
+        return num_devices

     def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
         """Hybrid forward"""
-        if autograd.is_training():
-            isum, isqu = F.SumSquare(x)
-            # reduce sum for E(x) and E(x^2)
-            idsum = self.xsum.push(isum)
-            idsqu = self.xsqu.push(isqu)
-            osum = self.xsum.get(F, idsum)
-            osqu = self.xsqu.get(F, idsqu)
-            assert(len(self.xsum) == len(self.xsqu))
-            N = len(self.xsum)*x.shape[0]*x.shape[2]*x.shape[3]
-            # calc mean and std
-            mean = osum / N
-            sumvar = osqu - osum * osum / N
-            bias_var = sumvar / N
-            std = F.sqrt(F.clip(bias_var, a_min=self.eps, a_max=bias_var.max().asscalar()))
-            # update running mean and var
-            with autograd.pause():
-                unbias_var = sumvar / (N - 1)
-                ctx = x.context
-                self.updater(self.running_mean, self.running_var, mean, unbias_var,
-                             self.momentum, ctx)
-            return F.DecoupleBatchNorm(x, gamma, beta, mean, std,
-                                       name='fwd', **self._kwargs)
-        else:
-            ctx = x.context
+        if not autograd.is_training():
             return F.BatchNorm(x, gamma, beta, running_mean, running_var,
                                name='fwd', **self._kwargs)
-
-    def __repr__(self):
-        s = '{name}({content}'
-        in_channels = self.gamma.shape[0]
-        s += ', in_channels={0}'.format(in_channels if in_channels else None)
-        s += ')'
-
-        return s.format(name=self.__class__.__name__,
-                        content=', '.join(['='.join([k, v.__repr__()])
-                                           for k, v in self._kwargs.items()]))
+        isum, isqu = F.SumSquare(x)
+        #isum = x.sum(axis=1, exclude=True)
+        #isqu = (x**2).sum(axis=1, exclude=True)
+        N = self.ndevices * x.shape[0] * x.shape[2] * x.shape[3]
+        allreduce = AllReduce(self._prefix)
+        osum, osqu = allreduce(isum, isqu)
+        # calc mean and std
+        mean = osum / N
+        sumvar = osqu - osum * osum / N
+        bias_var = sumvar / N
+        std = F.sqrt(F.maximum(bias_var, self.eps))
+        # update running mean and var
+        with autograd.pause():
+            unbias_var = sumvar / (N - 1)
+            self.updater(self.running_mean, self.running_var, mean, unbias_var,
+                         self.momentum, x.context)
+        # normalize the input with the synchronized batch statistics
+        output = F.DecoupleBatchNorm(x, gamma, beta, mean, std)
+        return output

 class _SharedUpdater(object):
     # update only once
-    def __init__(self, nGPUs):
-        self.mutex = threading.Lock()
-        self.nGPUs = nGPUs
+    def __init__(self, num_devices):
+        self._mutex = threading.Lock()
+        self.num_devices = num_devices
         self._clear()

     def _clear(self):
-        self.tasks = self.nGPUs
+        self.tasks = self.num_devices

     def __call__(self, running_mean, running_var, mean, unbias_var, momentum, ctx):
-        with self.mutex:
-            if self.tasks == self.nGPUs:
+        with self._mutex:
+            if self.tasks == self.num_devices:
                 running_mean.set_data(momentum * running_mean.data(ctx) + \
                     (1.0 - momentum) * mean)
                 running_var.set_data(momentum * running_var.data(ctx) + \
@@ -171,65 +225,3 @@ def __call__(self, running_mean, running_var, mean, unbias_var, momentum, ctx):
             self.tasks -= 1
             if self.tasks == 0:
                 self._clear()
-
-
-class _SharedTensor(object):
-    def __init__(self, nGPUs):
-        self.mutex = threading.Lock()
-        self.all_tasks_done = threading.Condition(self.mutex)
-        self.nGPUs = nGPUs
-        self._clear()
-
-    def _clear(self):
-        self.list = []
-        self.push_tasks = self.nGPUs
-        self.reduce_tasks = self.nGPUs
-
-    def push(self, t):
-        """push to _SharedTensor"""
-        with self.mutex:
-            if self.push_tasks == 0:
-                self._clear()
-            self.list.append(t)
-            idx = len(self.list) - 1
-            self.push_tasks -= 1
-
-        with self.all_tasks_done:
-            if self.push_tasks == 0:
-                self.all_tasks_done.notify_all()
-            while self.push_tasks:
-                self.all_tasks_done.wait()
-        return idx
-
-    def _reduce(self, F):
-        with self.mutex:
-            if self.reduce_tasks == 1:
-                assert(len(self.list) == self.nGPUs)
-                self.list = F.AllReduce(*self.list)
-                for xi in self.list:
-                    # manually attach grad to avoid wrong allocation
-                    xi.attach_grad()
-                    xi.wait_to_read()
-                self.reduce_tasks -= 1
-            else:
-                self.reduce_tasks -= 1
-
-        with self.all_tasks_done:
-            if self.reduce_tasks == 0:
-                self.all_tasks_done.notify_all()
-            while self.reduce_tasks:
-                self.all_tasks_done.wait()
-
-    def get(self, F, idx):
-        """Get from _SharedTensor"""
-        self._reduce(F)
-        return self.list[idx]
-
-    def test(self):
-        print('self.list', self.list)
-
-    def __len__(self):
-        return len(self.list)
-
-    def __repr__(self):
-        return '_SharedTensor'
diff --git a/gluoncv/utils/metrics/voc_segmentation.py b/gluoncv/utils/metrics/voc_segmentation.py
index 15a7bc2cf8..2e5bf969ee 100644
--- a/gluoncv/utils/metrics/voc_segmentation.py
+++ b/gluoncv/utils/metrics/voc_segmentation.py
@@ -8,10 +8,12 @@
 def batch_pix_accuracy(output, target):
     """PixAcc"""
     # inputs are NDarray, output 4D, target 3D
-    predict = F.argmax(output, 1) + 1
-    target = target.astype(predict.dtype) + 1
-    pixel_labeled = (target > 0).sum().asscalar()
-    pixel_correct = (F.equal(predict, target)*(target > 0)).sum().asscalar()
+    predict = F.argmax(output, 1)
+    predict = predict.asnumpy() + 1
+    target = target.asnumpy().astype(predict.dtype) + 1
+    pixel_labeled = np.sum(target > 0)
+    pixel_correct = np.sum((predict == target)*(target > 0))
+
     assert pixel_correct <= pixel_labeled, "Correct area should be no larger than labeled area"
     return pixel_correct, pixel_labeled
diff --git a/gluoncv/utils/parallel.py b/gluoncv/utils/parallel.py
index ae682f82fa..33b7241be1 100644
--- a/gluoncv/utils/parallel.py
+++ b/gluoncv/utils/parallel.py
@@ -1,12 +1,12 @@
-# pylint: disable=consider-using-enumerate,redefined-builtin,broad-except
 """Utils for Semantic Segmentation"""
+# pylint: disable=consider-using-enumerate,redefined-builtin,broad-except
 import threading
 from mxnet import autograd
 from mxnet.ndarray import NDArray
 from mxnet.gluon.utils import split_and_load

-__all__ = ['DataParallelModel', 'DataParallelCriterion']
+__all__ = ['DataParallelModel', 'DataParallelCriterion', 'parallel_backward']

 class DataParallelModel(object):
     """Data parallelism
@@ -238,3 +238,17 @@ def _worker(i, module, input, target, kwargs, results, is_recording, is_training
     outputs = [module(*(input + target), **kwargs) \
                for (input, target, kwargs) in zip(inputs, targets, kwargs_tup)]
     return tuple(outputs)
+
+def parallel_backward(losses, sync=True):
+    """Parallel Backward for CustomOp"""
+    def _worker(loss):
+        autograd.backward(loss)
+    threads = [threading.Thread(target=_worker, args=(loss,)) for loss in losses]
+    if sync:
+        for thread in threads:
+            thread.start()
+        for thread in threads:
+            thread.join()
+    else:
+        for loss in losses:
+            loss.backward()
diff --git a/scripts/segmentation/test.py b/scripts/segmentation/test.py
index d6c2aa7899..4deea3e668 100644
--- a/scripts/segmentation/test.py
+++ b/scripts/segmentation/test.py
@@ -75,7 +75,7 @@ def test(args):
             im_paths = dsts
             predicts = evaluator.parallel_forward(data)
             for predict, impath in zip(predicts, im_paths):
-                predict = mx.nd.squeeze(mx.nd.argmax(predict[0], 1)).asnumpy()
+                predict = mx.nd.squeeze(mx.nd.argmax(predict[0], 1)).asnumpy() + testset.pred_offset
                 mask = get_color_pallete(predict, args.dataset)
                 outname = os.path.splitext(impath)[0] + '.png'
                 mask.save(os.path.join(outdir, outname))
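
A quick usage sketch for the synchronized BatchNorm introduced in this patch (an illustrative example inferred from the signatures above, not code taken from the repository): since ``gluoncv.model_zoo.syncbn.BatchNorm`` subclasses ``mxnet.gluon.nn.BatchNorm``, it can be passed as the ``norm_layer`` of the segmentation models added here::

    import gluoncv
    from gluoncv.model_zoo.syncbn import BatchNorm

    # ndevices defaults to the number of visible GPUs; set it manually
    # when training on a subset of the GPUs (see _get_num_devices above).
    model = gluoncv.model_zoo.get_psp(dataset='ade20k', backbone='resnet50',
                                      norm_layer=BatchNorm)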