diff --git a/BENCHMARK.md b/BENCHMARK.md index 5af7fc56..d0bbd63e 100644 --- a/BENCHMARK.md +++ b/BENCHMARK.md @@ -29,6 +29,72 @@ ARl AR for large objects: area > 962 +## RetinaNet +### retinanet-R-50-FPN_1x + +- Training command: + + ``` + python tools/train_net_step.py \ + --dataset coco2017 --cfg configs/baselines/retinanet_R-50-FPN_1x.yaml \ + --bs 8 --iter_size 1 --use_tfboard + ``` + on four V100 GPUs.
+
+<table><tbody>
+<tr><th colspan="13" align="center">Box</th></tr>
+<tr>
+<th>source</th><th>AP50:95</th><th>AP50</th><th>AP75</th><th>APs</th><th>APm</th><th>APl</th><th>AR1</th><th>AR10</th><th>AR100</th><th>ARs</th><th>ARm</th><th>ARl</th>
+</tr>
+<tr>
+<td>PyTorch</td><td>35.3</td><td>54.6</td><td>37.9</td><td>19.4</td><td>39.1</td><td>47.5</td><td>30.7</td><td>48.9</td><td>51.8</td><td>32.4</td><td>56.3</td><td>67.4</td>
+</tr>
+<tr>
+<td>Detectron</td><td>35.7</td><td>54.7</td><td>38.5</td><td>19.5</td><td>39.9</td><td>47.5</td><td>30.7</td><td>49.1</td><td>52.0</td><td>32.0</td><td>56.9</td><td>68.0</td>
+</tr>
+</tbody></table>
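+
+- Evaluation command (a typical invocation for reproducing the PyTorch numbers above; this assumes the repo's standard `tools/test_net.py` interface, and the checkpoint path is a placeholder):
+
+  ```
+  python tools/test_net.py \
+    --dataset coco2017 --cfg configs/baselines/retinanet_R-50-FPN_1x.yaml \
+    --load_ckpt {path/to/trained/checkpoint}.pth
+  ```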
+ +- Total loss comparison: + + ![img](demo/loss_retinanet_R-50-FPN_1x.jpg) + + ## Faster-RCNN ### e2e_faster_rcnn-R-50-FPN_1x diff --git a/configs/baselines/retinanet_R-50-FPN_1x.yaml b/configs/baselines/retinanet_R-50-FPN_1x.yaml new file mode 100644 index 00000000..989807ac --- /dev/null +++ b/configs/baselines/retinanet_R-50-FPN_1x.yaml @@ -0,0 +1,41 @@ +DEBUG: False +MODEL: + TYPE: retinanet + CONV_BODY: FPN.fpn_ResNet50_conv5_body + NUM_CLASSES: 81 +RESNETS: + IMAGENET_PRETRAINED_WEIGHTS: 'data/pretrained_networks/resnet50_caffe.pth' +NUM_GPUS: 8 +SOLVER: + WEIGHT_DECAY: 0.0001 + LR_POLICY: steps_with_decay + BASE_LR: 0.01 + GAMMA: 0.1 + MAX_ITER: 90000 + STEPS: [0, 60000, 80000] +FPN: + FPN_ON: True + MULTILEVEL_RPN: True + RPN_MAX_LEVEL: 7 + RPN_MIN_LEVEL: 3 + COARSEST_STRIDE: 128 + EXTRA_CONV_LEVELS: True +RETINANET: + RETINANET_ON: True + NUM_CONVS: 4 + ASPECT_RATIOS: (1.0, 2.0, 0.5) + SCALES_PER_OCTAVE: 3 + ANCHOR_SCALE: 4 + LOSS_GAMMA: 2.0 + LOSS_ALPHA: 0.25 +TRAIN: + SCALES: (800,) + MAX_SIZE: 1333 + RPN_STRADDLE_THRESH: -1 # default 0 +TEST: + SCALE: 800 + MAX_SIZE: 1333 + NMS: 0.5 + RPN_PRE_NMS_TOP_N: 10000 # Per FPN level + RPN_POST_NMS_TOP_N: 2000 +OUTPUT_DIR: . \ No newline at end of file diff --git a/demo/loss_retinanet_R-50-FPN_1x.jpg b/demo/loss_retinanet_R-50-FPN_1x.jpg new file mode 100644 index 00000000..921df780 Binary files /dev/null and b/demo/loss_retinanet_R-50-FPN_1x.jpg differ diff --git a/lib/core/test.py b/lib/core/test.py index eac67355..0dff62e3 100644 --- a/lib/core/test.py +++ b/lib/core/test.py @@ -45,6 +45,7 @@ import utils.fpn as fpn_utils import utils.image as image_utils import utils.keypoints as keypoint_utils +import core.test_retinanet as test_retinanet def im_detect_all(model, im, box_proposals=None, timers=None): @@ -62,6 +63,13 @@ def im_detect_all(model, im, box_proposals=None, timers=None): if timers is None: timers = defaultdict(Timer) + # Handle RetinaNet testing separately for now + if cfg.RETINANET.RETINANET_ON: + timers['im_detect_bbox'].tic() + cls_boxes = test_retinanet.im_detect_bbox(model, im, timers) + timers['im_detect_bbox'].toc() + return cls_boxes, None, None + timers['im_detect_bbox'].tic() if cfg.TEST.BBOX_AUG.ENABLED: scores, boxes, im_scale, blob_conv = im_detect_bbox_aug( diff --git a/lib/core/test_retinanet.py b/lib/core/test_retinanet.py new file mode 100644 index 00000000..6d2f1dc9 --- /dev/null +++ b/lib/core/test_retinanet.py @@ -0,0 +1,196 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+############################################################################## + +"""Test a RetinaNet network on an image database""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import numpy as np +import logging +from collections import defaultdict + +from torch.autograd import Variable +import torch + +from core.config import cfg +from modeling.generate_anchors import generate_anchors +from utils.timer import Timer +import utils.blob as blob_utils +import utils.boxes as box_utils +import roi_data.data_utils as data_utils + +logger = logging.getLogger(__name__) + + +def _create_cell_anchors(): + """ + Generate all types of anchors for all fpn levels/scales/aspect ratios. + This function is called only once at the beginning of inference. + """ + k_max, k_min = cfg.FPN.RPN_MAX_LEVEL, cfg.FPN.RPN_MIN_LEVEL + scales_per_octave = cfg.RETINANET.SCALES_PER_OCTAVE + aspect_ratios = cfg.RETINANET.ASPECT_RATIOS + anchor_scale = cfg.RETINANET.ANCHOR_SCALE + A = scales_per_octave * len(aspect_ratios) + + anchors = {} + for lvl in range(k_min, k_max + 1): + # create cell anchors array + stride = 2. ** lvl + cell_anchors = np.zeros((A, 4)) + a = 0 + for octave in range(scales_per_octave): + octave_scale = 2 ** (octave / float(scales_per_octave)) + for aspect in aspect_ratios: + anchor_sizes = (stride * octave_scale * anchor_scale, ) + anchor_aspect_ratios = (aspect, ) + cell_anchors[a, :] = generate_anchors( + stride=stride, sizes=anchor_sizes, + aspect_ratios=anchor_aspect_ratios) + a += 1 + anchors[lvl] = cell_anchors + return anchors + + +def im_detect_bbox(model, im, timers=None): + """Generate RetinaNet detections on a single image.""" + if timers is None: + timers = defaultdict(Timer) + # Although anchors are input independent and could be precomputed, + # recomputing them per image only brings a small overhead + anchors = _create_cell_anchors() + timers['im_detect_bbox'].tic() + k_max, k_min = cfg.FPN.RPN_MAX_LEVEL, cfg.FPN.RPN_MIN_LEVEL + A = cfg.RETINANET.SCALES_PER_OCTAVE * len(cfg.RETINANET.ASPECT_RATIOS) + inputs = {} + inputs['data'], im_scale, inputs['im_info'] = \ + blob_utils.get_image_blob(im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE) + + if cfg.PYTORCH_VERSION_LESS_THAN_040: + inputs['data'] = [ + Variable(torch.from_numpy(inputs['data']), volatile=True)] + inputs['im_info'] = [ + Variable(torch.from_numpy(inputs['im_info']), volatile=True)] + else: + inputs['data'] = [torch.from_numpy(inputs['data'])] + inputs['im_info'] = [torch.from_numpy(inputs['im_info'])] + + return_dict = model(**inputs) + cls_probs = return_dict['cls_score'] + box_preds = return_dict['bbox_pred'] + + # here the boxes_all are [x0, y0, x1, y1, score] + boxes_all = defaultdict(list) + + cnt = 0 + for lvl in range(k_min, k_max + 1): + # create cell anchors array + stride = 2. 
** lvl + cell_anchors = anchors[lvl] + + # fetch per level probability + cls_prob = cls_probs[cnt].data.cpu().numpy() + box_pred = box_preds[cnt].data.cpu().numpy() + cls_prob = cls_prob.reshape(( + cls_prob.shape[0], A, int(cls_prob.shape[1] / A), + cls_prob.shape[2], cls_prob.shape[3])) + box_pred = box_pred.reshape(( + box_pred.shape[0], A, 4, box_pred.shape[2], box_pred.shape[3])) + cnt += 1 + + if cfg.RETINANET.SOFTMAX: + cls_prob = cls_prob[:, :, 1::, :, :] + + cls_prob_ravel = cls_prob.ravel() + # In some cases [especially for very small img sizes], it's possible that + # candidate_ind is empty if we impose threshold 0.05 at all levels. This + # will lead to errors since no detections are found for this image. Hence, + # for lvl 7 which has small spatial resolution, we take the threshold 0.0 + th = cfg.RETINANET.INFERENCE_TH if lvl < k_max else 0.0 + candidate_inds = np.where(cls_prob_ravel > th)[0] + if (len(candidate_inds) == 0): + continue + + pre_nms_topn = min(cfg.RETINANET.PRE_NMS_TOP_N, len(candidate_inds)) + inds = np.argpartition( + cls_prob_ravel[candidate_inds], -pre_nms_topn)[-pre_nms_topn:] + inds = candidate_inds[inds] + + inds_5d = np.array(np.unravel_index(inds, cls_prob.shape)).transpose() + classes = inds_5d[:, 2] + anchor_ids, y, x = inds_5d[:, 1], inds_5d[:, 3], inds_5d[:, 4] + scores = cls_prob[:, anchor_ids, classes, y, x] + + boxes = np.column_stack((x, y, x, y)).astype(dtype=np.float32) + boxes *= stride + boxes += cell_anchors[anchor_ids, :] + + if not cfg.RETINANET.CLASS_SPECIFIC_BBOX: + box_deltas = box_pred[0, anchor_ids, :, y, x] + else: + box_cls_inds = classes * 4 + box_deltas = np.vstack( + [box_pred[0, ind:ind + 4, yi, xi] + for ind, yi, xi in zip(box_cls_inds, y, x)] + ) + pred_boxes = ( + box_utils.bbox_transform(boxes, box_deltas) + if cfg.TEST.BBOX_REG else boxes) + pred_boxes /= im_scale + pred_boxes = box_utils.clip_tiled_boxes(pred_boxes, im.shape) + box_scores = np.zeros((pred_boxes.shape[0], 5)) + box_scores[:, 0:4] = pred_boxes + box_scores[:, 4] = scores + + for cls in range(1, cfg.MODEL.NUM_CLASSES): + inds = np.where(classes == cls - 1)[0] + if len(inds) > 0: + boxes_all[cls].extend(box_scores[inds, :]) + timers['im_detect_bbox'].toc() + + # Combine predictions across all levels and retain the top scoring by class + timers['misc_bbox'].tic() + detections = [] + for cls, boxes in boxes_all.items(): + cls_dets = np.vstack(boxes).astype(dtype=np.float32) + # do class specific nms here + keep = box_utils.nms(cls_dets, cfg.TEST.NMS) + cls_dets = cls_dets[keep, :] + out = np.zeros((len(keep), 6)) + out[:, 0:5] = cls_dets + out[:, 5].fill(cls) + detections.append(out) + + # detections (N, 6) format: + # detections[:, :4] - boxes + # detections[:, 4] - scores + # detections[:, 5] - classes + detections = np.vstack(detections) + # sort all again + inds = np.argsort(-detections[:, 4]) + detections = detections[inds[0:cfg.TEST.DETECTIONS_PER_IM], :] + + # Convert the detections to image cls_ format (see core/test_engine.py) + num_classes = cfg.MODEL.NUM_CLASSES + cls_boxes = [[] for _ in range(cfg.MODEL.NUM_CLASSES)] + for c in range(1, num_classes): + inds = np.where(detections[:, 5] == c)[0] + cls_boxes[c] = detections[inds, :5] + timers['misc_bbox'].toc() + + return cls_boxes diff --git a/lib/modeling/FPN.py b/lib/modeling/FPN.py index 03cd60bc..a551699c 100644 --- a/lib/modeling/FPN.py +++ b/lib/modeling/FPN.py @@ -140,7 +140,7 @@ def __init__(self, conv_body_func, fpn_level_info, P2only=False): self.extra_pyramid_modules = nn.ModuleList() 
dim_in = fpn_level_info.dims[0] for i in range(HIGHEST_BACKBONE_LVL + 1, max_level + 1): - self.extra_pyramid_modules( + self.extra_pyramid_modules.append( nn.Conv2d(dim_in, fpn_dim, 3, 2, 1) ) dim_in = fpn_dim @@ -214,7 +214,7 @@ def detectron_weight_mapping(self): }) if hasattr(self, 'extra_pyramid_modules'): - for i in len(self.extra_pyramid_modules): + for i in range(len(self.extra_pyramid_modules)): p_prefix = 'extra_pyramid_modules.%d' % i d_prefix = 'fpn_%d' % (HIGHEST_BACKBONE_LVL + 1 + i) mapping_to_detectron.update({ @@ -246,9 +246,9 @@ def forward(self, x): if hasattr(self, 'extra_pyramid_modules'): blob_in = conv_body_blobs[-1] - fpn_output_blobs.insert(0, self.extra_pyramid_modules(blob_in)) + fpn_output_blobs.insert(0, self.extra_pyramid_modules[0](blob_in)) for module in self.extra_pyramid_modules[1:]: - fpn_output_blobs.insert(0, module(F.relu(fpn_output_blobs[0], inplace=True))) + fpn_output_blobs.insert(0, module(F.relu(fpn_output_blobs[0]))) if self.P2only: # use only the finest level @@ -294,7 +294,7 @@ def forward(self, top_blob, lateral_blob): lat = self.conv_lateral(lateral_blob) # Top-down 2x upsampling # td = F.upsample(top_blob, size=lat.size()[2:], mode='bilinear') - td = F.upsample(top_blob, scale_factor=2, mode='nearest') + td = F.interpolate(top_blob, scale_factor=2, mode='nearest') # Sum lateral and top-down return lat + td diff --git a/lib/modeling/model_builder.py b/lib/modeling/model_builder.py index 0c7f1f49..59a274ef 100644 --- a/lib/modeling/model_builder.py +++ b/lib/modeling/model_builder.py @@ -12,6 +12,7 @@ from model.roi_crop.functions.roi_crop import RoICropFunction from modeling.roi_xfrom.roi_align.functions.roi_align import RoIAlignFunction import modeling.rpn_heads as rpn_heads +import modeling.retinanet_heads as retinanet_heads import modeling.fast_rcnn_heads as fast_rcnn_heads import modeling.mask_rcnn_heads as mask_rcnn_heads import modeling.keypoint_rcnn_heads as keypoint_rcnn_heads @@ -62,8 +63,9 @@ def wrapper(self, *args, **kwargs): with torch.no_grad(): return net_func(self, *args, **kwargs) else: - raise ValueError('You should call this function only on inference.' - 'Set the network in inference mode by net.eval().') + raise ValueError( + 'You should call this function only on inference.' + 'Set the network in inference mode by net.eval().') return wrapper @@ -84,42 +86,59 @@ def __init__(self): self.RPN = rpn_heads.generic_rpn_outputs( self.Conv_Body.dim_out, self.Conv_Body.spatial_scale) - if cfg.FPN.FPN_ON: - # Only supports case when RPN and ROI min levels are the same - assert cfg.FPN.RPN_MIN_LEVEL == cfg.FPN.ROI_MIN_LEVEL - # RPN max level can be >= to ROI max level - assert cfg.FPN.RPN_MAX_LEVEL >= cfg.FPN.ROI_MAX_LEVEL - # FPN RPN max level might be > FPN ROI max level in which case we - # need to discard some leading conv blobs (blobs are ordered from - # max/coarsest level to min/finest level) - self.num_roi_levels = cfg.FPN.ROI_MAX_LEVEL - cfg.FPN.ROI_MIN_LEVEL + 1 - - # Retain only the spatial scales that will be used for RoI heads. `Conv_Body.spatial_scale` - # may include extra scales that are used for RPN proposals, but not for RoI heads. 
- self.Conv_Body.spatial_scale = self.Conv_Body.spatial_scale[-self.num_roi_levels:] + if cfg.FPN.FPN_ON: + # Only supports case when RPN and ROI min levels are the same + assert cfg.FPN.RPN_MIN_LEVEL == cfg.FPN.ROI_MIN_LEVEL + # RPN max level can be >= to ROI max level + assert cfg.FPN.RPN_MAX_LEVEL >= cfg.FPN.ROI_MAX_LEVEL + # FPN RPN max level might be > FPN ROI max level in which case we + # need to discard some leading conv blobs (blobs are ordered from + # max/coarsest level to min/finest level) + self.num_roi_levels = cfg.FPN.ROI_MAX_LEVEL - cfg.FPN.ROI_MIN_LEVEL + 1 + + # Retain only the spatial scales that will be used for RoI heads. `Conv_Body.spatial_scale` + # may include extra scales that are used for RPN proposals, but + # not for RoI heads. + self.Conv_Body.spatial_scale = self.Conv_Body.spatial_scale[-self.num_roi_levels:] # BBOX Branch if not cfg.MODEL.RPN_ONLY: - self.Box_Head = get_func(cfg.FAST_RCNN.ROI_BOX_HEAD)( - self.RPN.dim_out, self.roi_feature_transform, self.Conv_Body.spatial_scale) - self.Box_Outs = fast_rcnn_heads.fast_rcnn_outputs( - self.Box_Head.dim_out) + if cfg.FAST_RCNN.ROI_BOX_HEAD is '': + # RetinaNet + self.Box_Outs = retinanet_heads.fpn_retinanet_outputs( + self.Conv_Body.dim_out, self.Conv_Body.spatial_scale) + else: + self.Box_Head = get_func( + cfg.FAST_RCNN.ROI_BOX_HEAD)( + self.RPN.dim_out, + self.roi_feature_transform, + self.Conv_Body.spatial_scale) + self.Box_Outs = fast_rcnn_heads.fast_rcnn_outputs( + self.Box_Head.dim_out) # Mask Branch if cfg.MODEL.MASK_ON: - self.Mask_Head = get_func(cfg.MRCNN.ROI_MASK_HEAD)( - self.RPN.dim_out, self.roi_feature_transform, self.Conv_Body.spatial_scale) + self.Mask_Head = get_func( + cfg.MRCNN.ROI_MASK_HEAD)( + self.RPN.dim_out, + self.roi_feature_transform, + self.Conv_Body.spatial_scale) if getattr(self.Mask_Head, 'SHARE_RES5', False): self.Mask_Head.share_res5_module(self.Box_Head.res5) - self.Mask_Outs = mask_rcnn_heads.mask_rcnn_outputs(self.Mask_Head.dim_out) + self.Mask_Outs = mask_rcnn_heads.mask_rcnn_outputs( + self.Mask_Head.dim_out) # Keypoints Branch if cfg.MODEL.KEYPOINTS_ON: - self.Keypoint_Head = get_func(cfg.KRCNN.ROI_KEYPOINTS_HEAD)( - self.RPN.dim_out, self.roi_feature_transform, self.Conv_Body.spatial_scale) + self.Keypoint_Head = get_func( + cfg.KRCNN.ROI_KEYPOINTS_HEAD)( + self.RPN.dim_out, + self.roi_feature_transform, + self.Conv_Body.spatial_scale) if getattr(self.Keypoint_Head, 'SHARE_RES5', False): self.Keypoint_Head.share_res5_module(self.Box_Head.res5) - self.Keypoint_Outs = keypoint_rcnn_heads.keypoint_outputs(self.Keypoint_Head.dim_out) + self.Keypoint_Outs = keypoint_rcnn_heads.keypoint_outputs( + self.Keypoint_Head.dim_out) self._init_modules() @@ -127,10 +146,16 @@ def _init_modules(self): if cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS: resnet_utils.load_pretrained_imagenet_weights(self) # Check if shared weights are equaled - if cfg.MODEL.MASK_ON and getattr(self.Mask_Head, 'SHARE_RES5', False): - assert compare_state_dict(self.Mask_Head.res5.state_dict(), self.Box_Head.res5.state_dict()) - if cfg.MODEL.KEYPOINTS_ON and getattr(self.Keypoint_Head, 'SHARE_RES5', False): - assert compare_state_dict(self.Keypoint_Head.res5.state_dict(), self.Box_Head.res5.state_dict()) + if cfg.MODEL.MASK_ON and getattr( + self.Mask_Head, 'SHARE_RES5', False): + assert compare_state_dict( + self.Mask_Head.res5.state_dict(), + self.Box_Head.res5.state_dict()) + if cfg.MODEL.KEYPOINTS_ON and getattr( + self.Keypoint_Head, 'SHARE_RES5', False): + assert compare_state_dict( + 
self.Keypoint_Head.res5.state_dict(), + self.Box_Head.res5.state_dict()) if cfg.TRAIN.FREEZE_CONV_BODY: for p in self.Conv_Body.parameters(): @@ -154,25 +179,30 @@ def _forward(self, data, im_info, roidb=None, **rpn_kwargs): blob_conv = self.Conv_Body(im_data) - rpn_ret = self.RPN(blob_conv, im_info, roidb) - # if self.training: # # can be used to infer fg/bg ratio # return_dict['rois_label'] = rpn_ret['labels_int32'] - if cfg.FPN.FPN_ON: + if cfg.RPN.RPN_ON: + rpn_ret = self.RPN(blob_conv, im_info, roidb) + + if cfg.FPN.FPN_ON and cfg.FAST_RCNN.ROI_BOX_HEAD is not '': # Retain only the blobs that will be used for RoI heads. `blob_conv` may include - # extra blobs that are used for RPN proposals, but not for RoI heads. + # extra blobs that are used for RPN proposals, but not for RoI + # heads. blob_conv = blob_conv[-self.num_roi_levels:] if not self.training: return_dict['blob_conv'] = blob_conv if not cfg.MODEL.RPN_ONLY: - if cfg.MODEL.SHARE_RES5 and self.training: - box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret) + if cfg.FAST_RCNN.ROI_BOX_HEAD is not '': + if cfg.MODEL.SHARE_RES5 and self.training: + box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret) + else: + box_feat = self.Box_Head(blob_conv, rpn_ret) else: - box_feat = self.Box_Head(blob_conv, rpn_ret) + box_feat = blob_conv cls_score, bbox_pred = self.Box_Outs(box_feat) else: # TODO: complete the returns for RPN only situation @@ -181,46 +211,62 @@ def _forward(self, data, im_info, roidb=None, **rpn_kwargs): if self.training: return_dict['losses'] = {} return_dict['metrics'] = {} - # rpn loss - rpn_kwargs.update(dict( - (k, rpn_ret[k]) for k in rpn_ret.keys() - if (k.startswith('rpn_cls_logits') or k.startswith('rpn_bbox_pred')) - )) - loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(**rpn_kwargs) - if cfg.FPN.FPN_ON: - for i, lvl in enumerate(range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1)): - return_dict['losses']['loss_rpn_cls_fpn%d' % lvl] = loss_rpn_cls[i] - return_dict['losses']['loss_rpn_bbox_fpn%d' % lvl] = loss_rpn_bbox[i] - else: - return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls - return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox - # bbox loss - loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses( - cls_score, bbox_pred, rpn_ret['labels_int32'], rpn_ret['bbox_targets'], - rpn_ret['bbox_inside_weights'], rpn_ret['bbox_outside_weights']) - return_dict['losses']['loss_cls'] = loss_cls - return_dict['losses']['loss_bbox'] = loss_bbox - return_dict['metrics']['accuracy_cls'] = accuracy_cls + if cfg.FAST_RCNN.ROI_BOX_HEAD is not '': + # rpn loss + rpn_kwargs.update(dict((k, rpn_ret[k]) for k in rpn_ret.keys() if ( + k.startswith('rpn_cls_logits') or k.startswith('rpn_bbox_pred')))) + loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses( + **rpn_kwargs) + if cfg.FPN.FPN_ON: + for i, lvl in enumerate( + range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1)): + return_dict['losses']['loss_rpn_cls_fpn%d' % + lvl] = loss_rpn_cls[i] + return_dict['losses']['loss_rpn_bbox_fpn%d' % + lvl] = loss_rpn_bbox[i] + else: + return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls + return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox + + # bbox loss + loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses( + cls_score, bbox_pred, rpn_ret['labels_int32'], rpn_ret['bbox_targets'], + rpn_ret['bbox_inside_weights'], rpn_ret['bbox_outside_weights']) + return_dict['losses']['loss_cls'] = loss_cls + return_dict['losses']['loss_bbox'] = loss_bbox + 
return_dict['metrics']['accuracy_cls'] = accuracy_cls + + if cfg.RETINANET.RETINANET_ON: + loss_retnet_cls, loss_retnet_bbox = retinanet_heads.add_fpn_retinanet_losses( + cls_score, bbox_pred, **rpn_kwargs) + for i, lvl in enumerate( + range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1)): + return_dict['losses']['loss_retnet_cls_fpn%d' % + lvl] = loss_retnet_cls[i] + return_dict['losses']['loss_retnet_bbox_fpn%d' % + lvl] = loss_retnet_bbox[i] if cfg.MODEL.MASK_ON: if getattr(self.Mask_Head, 'SHARE_RES5', False): - mask_feat = self.Mask_Head(res5_feat, rpn_ret, - roi_has_mask_int32=rpn_ret['roi_has_mask_int32']) + mask_feat = self.Mask_Head( + res5_feat, rpn_ret, roi_has_mask_int32=rpn_ret['roi_has_mask_int32']) else: mask_feat = self.Mask_Head(blob_conv, rpn_ret) mask_pred = self.Mask_Outs(mask_feat) # return_dict['mask_pred'] = mask_pred # mask loss - loss_mask = mask_rcnn_heads.mask_rcnn_losses(mask_pred, rpn_ret['masks_int32']) + loss_mask = mask_rcnn_heads.mask_rcnn_losses( + mask_pred, rpn_ret['masks_int32']) return_dict['losses']['loss_mask'] = loss_mask if cfg.MODEL.KEYPOINTS_ON: if getattr(self.Keypoint_Head, 'SHARE_RES5', False): # No corresponding keypoint head implemented yet (Neither in Detectron) - # Also, rpn need to generate the label 'roi_has_keypoints_int32' - kps_feat = self.Keypoint_Head(res5_feat, rpn_ret, - roi_has_keypoints_int32=rpn_ret['roi_has_keypoint_int32']) + # Also, rpn need to generate the label + # 'roi_has_keypoints_int32' + kps_feat = self.Keypoint_Head( + res5_feat, rpn_ret, roi_has_keypoints_int32=rpn_ret['roi_has_keypoint_int32']) else: kps_feat = self.Keypoint_Head(blob_conv, rpn_ret) kps_pred = self.Keypoint_Outs(kps_feat) @@ -243,14 +289,22 @@ def _forward(self, data, im_info, roidb=None, **rpn_kwargs): else: # Testing - return_dict['rois'] = rpn_ret['rois'] + if cfg.FAST_RCNN.ROI_BOX_HEAD is not '': + return_dict['rois'] = rpn_ret['rois'] return_dict['cls_score'] = cls_score return_dict['bbox_pred'] = bbox_pred return return_dict - def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoIPoolF', - resolution=7, spatial_scale=1. / 16., sampling_ratio=0): + def roi_feature_transform( + self, + blobs_in, + rpn_ret, + blob_rois='rois', + method='RoIPoolF', + resolution=7, + spatial_scale=1. / 16., + sampling_ratio=0): """Add the specified RoI pooling method. The sampling_ratio argument is supported for some, but not all, RoI transform methods. 
@@ -273,12 +327,16 @@ def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoI sc = spatial_scale[k_max - lvl] # in reversed order bl_rois = blob_rois + '_fpn' + str(lvl) if len(rpn_ret[bl_rois]): - rois = Variable(torch.from_numpy(rpn_ret[bl_rois])).cuda(device_id) + rois = Variable(torch.from_numpy( + rpn_ret[bl_rois])).cuda(device_id) if method == 'RoIPoolF': - # Warning!: Not check if implementation matches Detectron - xform_out = RoIPoolFunction(resolution, resolution, sc)(bl_in, rois) + # Warning!: Not check if implementation matches + # Detectron + xform_out = RoIPoolFunction( + resolution, resolution, sc)(bl_in, rois) elif method == 'RoICrop': - # Warning!: Not check if implementation matches Detectron + # Warning!: Not check if implementation matches + # Detectron grid_xy = net_utils.affine_grid_gen( rois, bl_in.size()[2:], self.grid_size) grid_yx = torch.stack( @@ -288,7 +346,8 @@ def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoI xform_out = F.max_pool2d(xform_out, 2, 2) elif method == 'RoIAlign': xform_out = RoIAlignFunction( - resolution, resolution, sc, sampling_ratio)(bl_in, rois) + resolution, resolution, sc, sampling_ratio)( + bl_in, rois) bl_out_list.append(xform_out) # The pooled features from all levels are concatenated along the @@ -299,7 +358,10 @@ def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoI device_id = xform_shuffled.get_device() restore_bl = rpn_ret[blob_rois + '_idx_restore_int32'] restore_bl = Variable( - torch.from_numpy(restore_bl.astype('int64', copy=False))).cuda(device_id) + torch.from_numpy( + restore_bl.astype( + 'int64', + copy=False))).cuda(device_id) xform_out = xform_shuffled[restore_bl] else: # Single feature level @@ -307,11 +369,14 @@ def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoI # (batch_idx, x1, y1, x2, y2) specifying an image batch index and a # rectangle (x1, y1, x2, y2) device_id = blobs_in.get_device() - rois = Variable(torch.from_numpy(rpn_ret[blob_rois])).cuda(device_id) + rois = Variable(torch.from_numpy( + rpn_ret[blob_rois])).cuda(device_id) if method == 'RoIPoolF': - xform_out = RoIPoolFunction(resolution, resolution, spatial_scale)(blobs_in, rois) + xform_out = RoIPoolFunction( + resolution, resolution, spatial_scale)(blobs_in, rois) elif method == 'RoICrop': - grid_xy = net_utils.affine_grid_gen(rois, blobs_in.size()[2:], self.grid_size) + grid_xy = net_utils.affine_grid_gen( + rois, blobs_in.size()[2:], self.grid_size) grid_yx = torch.stack( [grid_xy.data[:, :, :, 1], grid_xy.data[:, :, :, 0]], 3).contiguous() xform_out = RoICropFunction()(blobs_in, Variable(grid_yx).detach()) @@ -319,7 +384,12 @@ def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoI xform_out = F.max_pool2d(xform_out, 2, 2) elif method == 'RoIAlign': xform_out = RoIAlignFunction( - resolution, resolution, spatial_scale, sampling_ratio)(blobs_in, rois) + resolution, + resolution, + spatial_scale, + sampling_ratio)( + blobs_in, + rois) return xform_out @@ -329,7 +399,8 @@ def convbody_net(self, data): blob_conv = self.Conv_Body(data) if cfg.FPN.FPN_ON: # Retain only the blobs that will be used for RoI heads. `blob_conv` may include - # extra blobs that are used for RPN proposals, but not for RoI heads. + # extra blobs that are used for RPN proposals, but not for RoI + # heads. 
blob_conv = blob_conv[-self.num_roi_levels:] return blob_conv diff --git a/lib/modeling/retinanet_heads.py b/lib/modeling/retinanet_heads.py new file mode 100644 index 00000000..a4cdbb81 --- /dev/null +++ b/lib/modeling/retinanet_heads.py @@ -0,0 +1,207 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## + +"""RetinaNet model heads and losses. See: https://arxiv.org/abs/1708.02002.""" +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init +import utils.net as net_utils +import math + +from core.config import cfg + + +class fpn_retinanet_outputs(nn.Module): + """Add RetinaNet on FPN specific outputs.""" + + def __init__(self, dim_in, spatial_scales): + super().__init__() + self.dim_out = dim_in + self.dim_in = dim_in + self.spatial_scales = spatial_scales + self.dim_out = self.dim_in + self.num_anchors = len(cfg.RETINANET.ASPECT_RATIOS) * \ + cfg.RETINANET.SCALES_PER_OCTAVE + + # Create conv ops shared by all FPN levels + self.n_conv_fpn_cls_modules = nn.ModuleList() + self.n_conv_fpn_bbox_modules = nn.ModuleList() + for nconv in range(cfg.RETINANET.NUM_CONVS): + self.n_conv_fpn_cls_modules.append( + nn.Conv2d(self.dim_in, self.dim_out, 3, 1, 1)) + self.n_conv_fpn_bbox_modules.append( + nn.Conv2d(self.dim_in, self.dim_out, 3, 1, 1)) + + cls_pred_dim = cfg.MODEL.NUM_CLASSES if cfg.RETINANET.SOFTMAX \ + else (cfg.MODEL.NUM_CLASSES - 1) + + # unpacked bbox feature and add prediction layers + self.bbox_regr_dim = 4 * (cfg.MODEL.NUM_CLASSES - 1) \ + if cfg.RETINANET.CLASS_SPECIFIC_BBOX else 4 + + self.fpn_cls_score = nn.Conv2d(self.dim_out, + cls_pred_dim * self.num_anchors, 3, 1, 1) + self.fpn_bbox_score = nn.Conv2d(self.dim_out, + self.bbox_regr_dim * self.num_anchors, 3, 1, 1) + + self._init_weights() + + def _init_weights(self): + def init_func(m): + if isinstance(m, nn.Conv2d): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + for child_m in self.children(): + if isinstance(child_m, nn.ModuleList): + child_m.apply(init_func) + + init.normal_(self.fpn_cls_score.weight, std=0.01) + init.constant_(self.fpn_cls_score.bias, + -math.log((1 - cfg.RETINANET.PRIOR_PROB) / cfg.RETINANET.PRIOR_PROB)) + + init.normal_(self.fpn_bbox_score.weight, std=0.01) + init.constant_(self.fpn_bbox_score.bias, 0) + + def detectron_weight_mapping(self): + k_min = cfg.FPN.RPN_MIN_LEVEL + mapping_to_detectron = { + 'n_conv_fpn_cls_modules.0.weight': 'retnet_cls_conv_n0_fpn%d_w' % k_min, + 'n_conv_fpn_cls_modules.0.bias': 'retnet_cls_conv_n0_fpn%d_b' % k_min, + 'n_conv_fpn_cls_modules.1.weight': 'retnet_cls_conv_n1_fpn%d_w' % k_min, + 'n_conv_fpn_cls_modules.1.bias': 'retnet_cls_conv_n1_fpn%d_b' % k_min, + 'n_conv_fpn_cls_modules.2.weight': 'retnet_cls_conv_n2_fpn%d_w' % k_min, + 'n_conv_fpn_cls_modules.2.bias': 'retnet_cls_conv_n2_fpn%d_b' % k_min, + 'n_conv_fpn_cls_modules.3.weight': 'retnet_cls_conv_n3_fpn%d_w' % k_min, + 
'n_conv_fpn_cls_modules.3.bias': 'retnet_cls_conv_n3_fpn%d_b' % k_min, + + 'n_conv_fpn_bbox_modules.0.weight': 'retnet_bbox_conv_n0_fpn%d_w' % k_min, + 'n_conv_fpn_bbox_modules.0.bias': 'retnet_bbox_conv_n0_fpn%d_b' % k_min, + 'n_conv_fpn_bbox_modules.1.weight': 'retnet_bbox_conv_n1_fpn%d_w' % k_min, + 'n_conv_fpn_bbox_modules.1.bias': 'retnet_bbox_conv_n1_fpn%d_b' % k_min, + 'n_conv_fpn_bbox_modules.2.weight': 'retnet_bbox_conv_n2_fpn%d_w' % k_min, + 'n_conv_fpn_bbox_modules.2.bias': 'retnet_bbox_conv_n2_fpn%d_b' % k_min, + 'n_conv_fpn_bbox_modules.3.weight': 'retnet_bbox_conv_n3_fpn%d_w' % k_min, + 'n_conv_fpn_bbox_modules.3.bias': 'retnet_bbox_conv_n3_fpn%d_b' % k_min, + + 'fpn_cls_score.weight': 'retnet_cls_pred_fpn%d_w' % k_min, + 'fpn_cls_score.bias': 'retnet_cls_pred_fpn%d_b' % k_min, + 'fpn_bbox_score.weight': 'retnet_bbox_pred_fpn%d_w' % k_min, + 'fpn_bbox_score.bias': 'retnet_bbox_pred_fpn%d_b' % k_min + } + return mapping_to_detectron, [] + + def forward(self, blobs_in): + k_max = cfg.FPN.RPN_MAX_LEVEL # coarsest level of pyramid + k_min = cfg.FPN.RPN_MIN_LEVEL # finest level of pyramid + assert len(blobs_in) == k_max - k_min + 1 + bbox_feat_list = [] + cls_score = [] + bbox_pred = [] + + # ========================================================================== + # classification tower with logits and prob prediction + # ========================================================================== + for lvl in range(k_min, k_max + 1): + bl_in = blobs_in[k_max - lvl] # blobs_in is in reversed order + # classification tower stack convolution starts + for nconv in range(cfg.RETINANET.NUM_CONVS): + bl_out = self.n_conv_fpn_cls_modules[nconv](bl_in) + bl_in = F.relu(bl_out, inplace=True) + bl_feat = bl_in + + # cls tower stack convolution ends. 
Add the logits layer now + retnet_cls_pred = self.fpn_cls_score(bl_feat) + + if not self.training: + if cfg.RETINANET.SOFTMAX: + raise NotImplementedError("To be implemented") + else: # sigmoid + retnet_cls_probs = retnet_cls_pred.sigmoid() + cls_score.append(retnet_cls_probs) + else: + cls_score.append(retnet_cls_pred) + + if cfg.RETINANET.SHARE_CLS_BBOX_TOWER: + bbox_feat_list.append(bl_feat) + + # ========================================================================== + # bbox tower if not sharing features with the classification tower with + # logits and prob prediction + # ========================================================================== + if not cfg.RETINANET.SHARE_CLS_BBOX_TOWER: + for lvl in range(k_min, k_max + 1): + bl_in = blobs_in[k_max - lvl] # blobs_in is in reversed order + # classification tower stack convolution starts + for nconv in range(cfg.RETINANET.NUM_CONVS): + bl_out = self.n_conv_fpn_bbox_modules[nconv](bl_in) + bl_in = F.relu(bl_out, inplace=True) + # Add octave scales and aspect ratio + # At least 1 convolution for dealing different aspect ratios + bl_feat = bl_in + bbox_feat_list.append(bl_feat) + + # Depending on the features [shared/separate] for bbox, add prediction layer + for i, lvl in enumerate(range(k_min, k_max + 1)): + bl_feat = bbox_feat_list[i] + retnet_bbox_pred = self.fpn_bbox_score(bl_feat) + bbox_pred.append(retnet_bbox_pred) + + return cls_score, bbox_pred + + +def add_fpn_retinanet_losses(cls_score, bbox_pred, **kwargs): + k_max = cfg.FPN.RPN_MAX_LEVEL # coarsest level of pyramid + k_min = cfg.FPN.RPN_MIN_LEVEL # finest level of pyramid + + losses_cls = [] + losses_bbox = [] + for i, lvl in enumerate(range(k_min, k_max + 1)): + slvl = str(lvl) + h, w = cls_score[i].shape[2:] + retnet_cls_labels_fpn = kwargs['retnet_cls_labels_fpn' + + slvl][:, :, :h, :w] + retnet_bbox_targets_fpn = kwargs['retnet_roi_bbox_targets_fpn' + + slvl][:, :, :, :h, :w] + retnet_bbox_inside_weights_fpn = kwargs['retnet_bbox_inside_weights_wide_fpn' + + slvl][:, :, :, :h, :w] + retnet_fg_num = kwargs['retnet_fg_num'] + + # ========================================================================== + # bbox regression loss - SelectSmoothL1Loss for multiple anchors at a location + # ========================================================================== + bbox_loss = net_utils.select_smooth_l1_loss( + bbox_pred[i], retnet_bbox_targets_fpn, + retnet_bbox_inside_weights_fpn, + retnet_fg_num, + beta=cfg.RETINANET.BBOX_REG_BETA) + + # ========================================================================== + # cls loss - depends on softmax/sigmoid outputs + # ========================================================================== + if cfg.RETINANET.SOFTMAX: + raise NotImplementedError("To be implemented") + else: + cls_loss = net_utils.sigmoid_focal_loss( + cls_score[i], retnet_cls_labels_fpn.float(), + cfg.MODEL.NUM_CLASSES, retnet_fg_num, alpha=cfg.RETINANET.LOSS_ALPHA, + gamma=cfg.RETINANET.LOSS_GAMMA + ) + + losses_bbox.append(bbox_loss) + losses_cls.append(cls_loss) + + return losses_cls, losses_bbox diff --git a/lib/roi_data/minibatch.py b/lib/roi_data/minibatch.py index 65ef9e66..62bff744 100644 --- a/lib/roi_data/minibatch.py +++ b/lib/roi_data/minibatch.py @@ -4,6 +4,7 @@ from core.config import cfg import utils.blob as blob_utils import roi_data.rpn +import roi_data.retinanet def get_minibatch_blob_names(is_training=True): @@ -15,7 +16,9 @@ def get_minibatch_blob_names(is_training=True): # RPN-only or end-to-end Faster R-CNN blob_names += 
roi_data.rpn.get_rpn_blob_names(is_training=is_training) elif cfg.RETINANET.RETINANET_ON: - raise NotImplementedError + blob_names += roi_data.retinanet.get_retinanet_blob_names( + is_training=is_training + ) else: # Fast R-CNN like models trained on precomputed proposals blob_names += roi_data.fast_rcnn.get_fast_rcnn_blob_names( @@ -37,7 +40,13 @@ def get_minibatch(roidb): # RPN-only or end-to-end Faster/Mask R-CNN valid = roi_data.rpn.add_rpn_blobs(blobs, im_scales, roidb) elif cfg.RETINANET.RETINANET_ON: - raise NotImplementedError + im_width, im_height = im_blob.shape[3], im_blob.shape[2] + # im_width, im_height corresponds to the network input: padded image + # (if needed) width and height. We pass it as input and slice the data + # accordingly so that we don't need to use SampleAsOp + valid = roi_data.retinanet.add_retinanet_blobs( + blobs, im_scales, roidb, im_width, im_height + ) else: # Fast R-CNN like models trained on precomputed proposals valid = roi_data.fast_rcnn.add_fast_rcnn_blobs(blobs, im_scales, roidb) diff --git a/lib/roi_data/retinanet.py b/lib/roi_data/retinanet.py new file mode 100644 index 00000000..03b329bb --- /dev/null +++ b/lib/roi_data/retinanet.py @@ -0,0 +1,258 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## + +"""Compute minibatch blobs for training a RetinaNet network.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import numpy as np +import logging + +import utils.boxes as box_utils +import roi_data.data_utils as data_utils +from core.config import cfg + +logger = logging.getLogger(__name__) + + +def get_retinanet_blob_names(is_training=True): + """ + Returns blob names in the order in which they are read by the data + loader. + """ + # im_info: (height, width, image scale) + blob_names = ['im_info'] + assert cfg.FPN.FPN_ON, "RetinaNet uses FPN for dense detection" + # Same format as RPN blobs, but one per FPN level + if is_training: + blob_names += ['roidb', 'retnet_fg_num', 'retnet_bg_num'] + for lvl in range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1): + suffix = 'fpn{}'.format(lvl) + blob_names += [ + 'retnet_cls_labels_' + suffix, + 'retnet_roi_bbox_targets_' + suffix, + 'retnet_bbox_inside_weights_wide_' + suffix, + ] + return blob_names + + +def add_retinanet_blobs(blobs, im_scales, roidb, image_width, image_height): + """Add RetinaNet blobs.""" + # RetinaNet is applied to many feature levels, as in the FPN paper + k_max, k_min = cfg.FPN.RPN_MAX_LEVEL, cfg.FPN.RPN_MIN_LEVEL + scales_per_octave = cfg.RETINANET.SCALES_PER_OCTAVE + num_aspect_ratios = len(cfg.RETINANET.ASPECT_RATIOS) + aspect_ratios = cfg.RETINANET.ASPECT_RATIOS + anchor_scale = cfg.RETINANET.ANCHOR_SCALE + + # get anchors from all levels for all scales/aspect ratios + foas = [] + for lvl in range(k_min, k_max + 1): + stride = 2. 
** lvl + for octave in range(scales_per_octave): + octave_scale = 2 ** (octave / float(scales_per_octave)) + for idx in range(num_aspect_ratios): + anchor_sizes = (stride * octave_scale * anchor_scale,) + anchor_aspect_ratios = (aspect_ratios[idx],) + foa = data_utils.get_field_of_anchors( + stride, anchor_sizes, anchor_aspect_ratios, octave, idx) + foas.append(foa) + all_anchors = np.concatenate([f.field_of_anchors for f in foas]) + + blobs['retnet_fg_num'], blobs['retnet_bg_num'] = 0.0, 0.0 + for im_i, entry in enumerate(roidb): + scale = im_scales[im_i] + im_height = np.round(entry['height'] * scale) + im_width = np.round(entry['width'] * scale) + gt_inds = np.where( + (entry['gt_classes'] > 0) & (entry['is_crowd'] == 0))[0] + assert len(gt_inds) > 0, \ + 'Empty ground truth empty for image is not allowed. Please check.' + + gt_rois = entry['boxes'][gt_inds, :] * scale + gt_classes = entry['gt_classes'][gt_inds] + + im_info = np.array([[im_height, im_width, scale]], dtype=np.float32) + blobs['im_info'].append(im_info) + + retinanet_blobs, fg_num, bg_num = _get_retinanet_blobs( + foas, all_anchors, gt_rois, gt_classes, image_width, image_height) + for i, foa in enumerate(foas): + for k, v in retinanet_blobs[i].items(): + level = int(np.log2(foa.stride)) + key = '{}_fpn{}'.format(k, level) + blobs[key].append(v) + blobs['retnet_fg_num'] += fg_num + blobs['retnet_bg_num'] += bg_num + + blobs['retnet_fg_num'] = blobs['retnet_fg_num'].astype(np.float32) + blobs['retnet_bg_num'] = blobs['retnet_bg_num'].astype(np.float32) + + N = len(roidb) + for k, v in blobs.items(): + if isinstance(v, list) and len(v) > 0: + # compute number of anchors + A = int(len(v) / N) + # for the cls branch labels [per fpn level], + # we have blobs['retnet_cls_labels_fpn{}'] as a list until this step + # and length of this list is N x A where + # N = num_images, A = num_anchors for example, N = 2, A = 9 + # Each element of the list has the shape 1 x 1 x H x W where H, W are + # spatial dimension of curret fpn lvl. Let a{i} denote the element + # corresponding to anchor i [9 anchors total] in the list. + # The elements in the list are in order [[a0, ..., a9], [a0, ..., a9]] + # however the network will make predictions like 2 x (9 * 80) x H x W + # so we first concatenate the elements of each image to a numpy array + # and then concatenate the two images to get the 2 x 9 x H x W + + if k.find('retnet_cls_labels') >= 0 \ + or k.find('retnet_roi_bbox_targets') >= 0: + tmp = [] + # concat anchors within an image + for i in range(0, len(v), A): + tmp.append(np.concatenate(v[i: i + A], axis=1)) + # concat images + blobs[k] = np.concatenate(tmp, axis=0) + else: + # for the bbox branch elements [per FPN level], + # we have the targets and the fg boxes locations + # in the shape: M x 4 where M is the number of fg locations in a + # given image at the current FPN level. For the given level, + # the bbox predictions will be. 
The elements in the list are in + # order [[a0, ..., a9], [a0, ..., a9]] + # Concatenate them to form M x 4 + blobs[k] = np.expand_dims(np.concatenate(v, axis=0), axis=0) + + valid_keys = [ + 'has_visible_keypoints', 'boxes', 'segms', 'seg_areas', 'gt_classes', + 'gt_overlaps', 'is_crowd', 'box_to_gt_ind_map', 'gt_keypoints' + ] + minimal_roidb = [{} for _ in range(len(roidb))] + for i, e in enumerate(roidb): + for k in valid_keys: + if k in e: + minimal_roidb[i][k] = e[k] + # blobs['roidb'] = blob_utils.serialize(minimal_roidb) + blobs['roidb'] = minimal_roidb + + return True + + +def _get_retinanet_blobs( + foas, all_anchors, gt_boxes, gt_classes, im_width, im_height): + total_anchors = all_anchors.shape[0] + logger.debug('Getting mad blobs: im_height {} im_width: {}'.format( + im_height, im_width)) + + inds_inside = np.arange(all_anchors.shape[0]) + anchors = all_anchors + num_inside = len(inds_inside) + + logger.debug('total_anchors: {}'.format(total_anchors)) + logger.debug('inds_inside: {}'.format(num_inside)) + logger.debug('anchors.shape: {}'.format(anchors.shape)) + + # Compute anchor labels: + # label=1 is positive, 0 is negative, -1 is don't care (ignore) + labels = np.empty((num_inside,), dtype=np.float32) + labels.fill(-1) + if len(gt_boxes) > 0: + # Compute overlaps between the anchors and the gt boxes overlaps + anchor_by_gt_overlap = box_utils.bbox_overlaps(anchors, gt_boxes) + # Map from anchor to gt box that has highest overlap + anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(axis=1) + # For each anchor, amount of overlap with most overlapping gt box + anchor_to_gt_max = anchor_by_gt_overlap[ + np.arange(num_inside), anchor_to_gt_argmax] + + # Map from gt box to an anchor that has highest overlap + gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(axis=0) + # For each gt box, amount of overlap with most overlapping anchor + gt_to_anchor_max = anchor_by_gt_overlap[ + gt_to_anchor_argmax, np.arange(anchor_by_gt_overlap.shape[1])] + # Find all anchors that share the max overlap amount + # (this includes many ties) + anchors_with_max_overlap = np.where( + anchor_by_gt_overlap == gt_to_anchor_max)[0] + + # Fg label: for each gt use anchors with highest overlap + # (including ties) + gt_inds = anchor_to_gt_argmax[anchors_with_max_overlap] + labels[anchors_with_max_overlap] = gt_classes[gt_inds] + # Fg label: above threshold IOU + inds = anchor_to_gt_max >= cfg.RETINANET.POSITIVE_OVERLAP + gt_inds = anchor_to_gt_argmax[inds] + labels[inds] = gt_classes[gt_inds] + + fg_inds = np.where(labels >= 1)[0] + bg_inds = np.where(anchor_to_gt_max < cfg.RETINANET.NEGATIVE_OVERLAP)[0] + labels[bg_inds] = 0 + num_fg, num_bg = len(fg_inds), len(bg_inds) + + bbox_targets = np.zeros((num_inside, 4), dtype=np.float32) + bbox_targets[fg_inds, :] = data_utils.compute_targets( + anchors[fg_inds, :], gt_boxes[anchor_to_gt_argmax[fg_inds], :]) + + # Bbox regression loss has the form: + # loss(x) = weight_outside * L(weight_inside * x) + # Inside weights allow us to set zero loss on an element-wise basis + # Bbox regression is only trained on positive examples so we set their + # weights to 1.0 (or otherwise if config is different) and 0 otherwise + bbox_inside_weights = np.zeros((num_inside, 4), dtype=np.float32) + bbox_inside_weights[labels >= 1, :] = (1.0, 1.0, 1.0, 1.0) + + # Map up to original set of anchors + labels = data_utils.unmap(labels, total_anchors, inds_inside, fill=-1) + bbox_inside_weights = data_utils.unmap( + bbox_inside_weights, total_anchors, inds_inside, fill=0 + ) + 
bbox_targets = data_utils.unmap( + bbox_targets, total_anchors, inds_inside, fill=0) + + # Split the generated labels, etc. into labels per each field of anchors + blobs_out = [] + start_idx = 0 + for i, foa in enumerate(foas): + H = foa.field_size + W = foa.field_size + end_idx = start_idx + H * W + _labels = labels[start_idx:end_idx] + _bbox_targets = bbox_targets[start_idx:end_idx, :] + _bbox_inside_weights = bbox_inside_weights[start_idx:end_idx, :] + start_idx = end_idx + # labels output with shape (1, height, width) + _labels = _labels.reshape((1, 1, H, W)) + # bbox_targets output with shape (1, 4 * A, height, width) + _bbox_targets = _bbox_targets.reshape( + (1, 1, H, W, 4)).transpose(0, 1, 4, 2, 3) + + # bbox_inside_weights output with shape (1, 4 * A, height, width) + _bbox_inside_weights = _bbox_inside_weights.reshape( + (1, H, W, 4)).transpose(0, 3, 1, 2) + + blobs_out.append( + dict( + retnet_cls_labels=_labels.astype(np.int32), + retnet_roi_bbox_targets=_bbox_targets.astype(np.float32), + retnet_bbox_inside_weights_wide=_bbox_inside_weights + )) + out_num_fg = np.array([num_fg + 1.0], dtype=np.float32) + out_num_bg = ( + np.array([num_bg + 1.0]) * (cfg.MODEL.NUM_CLASSES - 1) + + out_num_fg * (cfg.MODEL.NUM_CLASSES - 2)) + return blobs_out, out_num_fg, out_num_bg diff --git a/lib/utils/net.py b/lib/utils/net.py index 32c5d705..3f5e7f74 100644 --- a/lib/utils/net.py +++ b/lib/utils/net.py @@ -24,7 +24,7 @@ def smooth_l1_loss(bbox_pred, bbox_targets, bbox_inside_weights, bbox_outside_we abs_in_box_diff = torch.abs(in_box_diff) smoothL1_sign = (abs_in_box_diff < beta).detach().float() in_loss_box = smoothL1_sign * 0.5 * torch.pow(in_box_diff, 2) / beta + \ - (1 - smoothL1_sign) * (abs_in_box_diff - (0.5 * beta)) + (1 - smoothL1_sign) * (abs_in_box_diff - (0.5 * beta)) out_loss_box = bbox_outside_weights * in_loss_box loss_box = out_loss_box N = loss_box.size(0) # batch size @@ -32,6 +32,50 @@ def smooth_l1_loss(bbox_pred, bbox_targets, bbox_inside_weights, bbox_outside_we return loss_box +def select_smooth_l1_loss(bbox_pred, bbox_targets, bbox_inside_weights, num_fg, beta=1.0): + bbox_pred = bbox_pred.reshape\ + ((bbox_pred.shape[0], bbox_targets.shape[1], 4, bbox_pred.shape[2], bbox_pred.shape[3])) + box_diff = bbox_pred - bbox_targets + in_box_diff = bbox_inside_weights * box_diff + abs_in_box_diff = torch.abs(in_box_diff) + smoothL1_sign = (abs_in_box_diff < beta).detach().float() + loss_box = smoothL1_sign * 0.5 * torch.pow(in_box_diff, 2) / beta + \ + (1 - smoothL1_sign) * (abs_in_box_diff - (0.5 * beta)) + loss_box = loss_box.view(-1).sum(0) / num_fg.sum() + return loss_box + + +def sigmoid_focal_loss(cls_preds, cls_targets, num_classes, num_fg, alpha=0.25, gamma=2): + masked_cls_preds = cls_preds.reshape(( + cls_preds.size(0), cls_targets.size(1), num_classes - 1, + cls_preds.size(2), cls_preds.size(3))).permute((0, 1, 3, 4, 2)).contiguous().\ + view(-1, num_classes-1) + masked_cls_targets = cls_targets.view(-1) + + weights = (masked_cls_targets >= 0).float() + weights = weights.unsqueeze(1) + + t = masked_cls_preds.data.new( + masked_cls_preds.size(0), num_classes).fill_(0) + ids = masked_cls_targets.view(-1, 1) % (num_classes) + t.scatter_(1, ids.long(), 1.) 
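+    # drop the background column (index 0) below, keeping a one-hot target over the num_classes - 1 foreground classes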
+ t = t[:, 1:] + + p = masked_cls_preds.sigmoid() + # w = alpha if t > 0 else 1-alpha + alpha_factor = alpha * t + (1 - alpha) * (1 - t) + # pt = p if t > 0 else 1-p + focal_weight = p * t + (1 - p) * (1 - t) + focal_weight = alpha_factor * (1 - focal_weight).pow(gamma) + + cls_loss = focal_weight * \ + F.binary_cross_entropy_with_logits( + masked_cls_preds, t, weight=weights, reduction='none') + cls_loss = cls_loss.sum() / num_fg.sum() + + return cls_loss + + def clip_gradient(model, clip_norm): """Computes a gradient clipping coefficient based on gradient norm.""" totalnorm = 0 @@ -62,7 +106,9 @@ def decay_learning_rate(optimizer, cur_lr, decay_rate): if cfg.SOLVER.TYPE in ['SGD']: if cfg.SOLVER.SCALE_MOMENTUM and cur_lr > 1e-7 and \ ratio > cfg.SOLVER.SCALE_MOMENTUM_THRESHOLD: - _CorrectMomentum(optimizer, param_group['params'], new_lr / cur_lr) + _CorrectMomentum( + optimizer, param_group['params'], new_lr / cur_lr) + def update_learning_rate(optimizer, cur_lr, new_lr): """Update learning rate""" @@ -119,15 +165,16 @@ def affine_grid_gen(rois, input_size, grid_size): width = input_size[1] zero = Variable(rois.data.new(rois.size(0), 1).zero_()) - theta = torch.cat([\ - (x2 - x1) / (width - 1), - zero, - (x1 + x2 - width + 1) / (width - 1), - zero, - (y2 - y1) / (height - 1), - (y1 + y2 - height + 1) / (height - 1)], 1).view(-1, 2, 3) + theta = torch.cat([ + (x2 - x1) / (width - 1), + zero, + (x1 + x2 - width + 1) / (width - 1), + zero, + (y2 - y1) / (height - 1), + (y1 + y2 - height + 1) / (height - 1)], 1).view(-1, 2, 3) - grid = F.affine_grid(theta, torch.Size((rois.size(0), 1, grid_size, grid_size))) + grid = F.affine_grid(theta, torch.Size( + (rois.size(0), 1, grid_size, grid_size))) return grid @@ -139,7 +186,8 @@ def save_ckpt(output_dir, args, model, optimizer): ckpt_dir = os.path.join(output_dir, 'ckpt') if not os.path.exists(ckpt_dir): os.makedirs(ckpt_dir) - save_name = os.path.join(ckpt_dir, 'model_{}_{}.pth'.format(args.epoch, args.step)) + save_name = os.path.join( + ckpt_dir, 'model_{}_{}.pth'.format(args.epoch, args.step)) if isinstance(model, mynn.DataParallel): model = model.module # TODO: (maybe) Do not save redundant shared params diff --git a/lib/utils/training_stats.py b/lib/utils/training_stats.py index 8cae8a18..610c4304 100644 --- a/lib/utils/training_stats.py +++ b/lib/utils/training_stats.py @@ -83,24 +83,24 @@ def UpdateIterStats(self, model_out, inner_iter=None): assert loss.shape[0] == cfg.NUM_GPUS loss = loss.mean(dim=0, keepdim=True) total_loss += loss - loss_data = loss.data[0] + loss_data = loss.data[0].item() model_out['losses'][k] = loss if cfg.FPN.FPN_ON: - if k.startswith('loss_rpn_cls_'): + if k.startswith('loss_rpn_cls_') or k.startswith('loss_retnet_cls_'): loss_rpn_cls_data += loss_data - elif k.startswith('loss_rpn_bbox_'): + elif k.startswith('loss_rpn_bbox_') or k.startswith('loss_retnet_bbox_'): loss_rpn_bbox_data += loss_data self.smoothed_losses[k].AddValue(loss_data) model_out['total_loss'] = total_loss # Add the total loss for back propagation - self.smoothed_total_loss.AddValue(total_loss.data[0]) + self.smoothed_total_loss.AddValue(total_loss.data[0].item()) if cfg.FPN.FPN_ON: self.smoothed_losses['loss_rpn_cls'].AddValue(loss_rpn_cls_data) self.smoothed_losses['loss_rpn_bbox'].AddValue(loss_rpn_bbox_data) for k, metric in model_out['metrics'].items(): metric = metric.mean(dim=0, keepdim=True) - self.smoothed_metrics[k].AddValue(metric.data[0]) + self.smoothed_metrics[k].AddValue(metric.data[0].item()) def _UpdateIterStats_inner(self, 
model_out, inner_iter): """Update tracked iteration statistics for the case of iter_size > 1""" @@ -125,13 +125,13 @@ def _UpdateIterStats_inner(self, model_out, inner_iter): assert loss.shape[0] == cfg.NUM_GPUS loss = loss.mean(dim=0, keepdim=True) total_loss += loss - loss_data = loss.data[0] + loss_data = loss.data[0].item() model_out['losses'][k] = loss if cfg.FPN.FPN_ON: - if k.startswith('loss_rpn_cls_'): + if k.startswith('loss_rpn_cls_') or k.startswith('loss_retnet_cls_'): loss_rpn_cls_data += loss_data - elif k.startswith('loss_rpn_bbox_'): + elif k.startswith('loss_rpn_bbox_') or k.startswith('loss_retnet_bbox_'): loss_rpn_bbox_data += loss_data self.inner_losses[k].append(loss_data) @@ -140,7 +140,7 @@ def _UpdateIterStats_inner(self, model_out, inner_iter): self.smoothed_losses[k].AddValue(loss_data) model_out['total_loss'] = total_loss # Add the total loss for back propagation - total_loss_data = total_loss.data[0] + total_loss_data = total_loss.data[0].item() self.inner_total_loss.append(total_loss_data) if cfg.FPN.FPN_ON: self.inner_loss_rpn_cls.append(loss_rpn_cls_data) @@ -156,7 +156,7 @@ def _UpdateIterStats_inner(self, model_out, inner_iter): for k, metric in model_out['metrics'].items(): metric = metric.mean(dim=0, keepdim=True) - metric_data = metric.data[0] + metric_data = metric.data[0].item() self.inner_metrics[k].append(metric_data) if inner_iter == (self.misc_args.iter_size - 1): metric_data = self._mean_and_reset_inner_list('inner_metrics', k)