diff --git a/BENCHMARK.md b/BENCHMARK.md
index 5af7fc56..d0bbd63e 100644
--- a/BENCHMARK.md
+++ b/BENCHMARK.md
@@ -29,6 +29,72 @@
ARl | AR for large objects: area > 96² |
+## RetinaNet
+### retinanet-R-50-FPN_1x
+
+- Training command:
+
+ ```
+ python tools/train_net_step.py \
+ --dataset coco2017 --cfg configs/baselines/retinanet_R-50-FPN_1x.yaml \
+ --bs 8 --iter_size 1 --use_tfboard
+ ```
+ on four V100 GPUs.
+
+
+**Box**
+
+| source | AP50:95 | AP50 | AP75 | APs | APm | APl | AR1 | AR10 | AR100 | ARs | ARm | ARl |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| PyTorch | 35.3 | 54.6 | 37.9 | 19.4 | 39.1 | 47.5 | 30.7 | 48.9 | 51.8 | 32.4 | 56.3 | 67.4 |
+| Detectron | 35.7 | 54.7 | 38.5 | 19.5 | 39.9 | 47.5 | 30.7 | 49.1 | 52.0 | 32.0 | 56.9 | 68.0 |
+
+- Total loss comparison:
+
+ 
+
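+- Evaluation command (a sketch; this assumes the repo's `tools/test_net.py` takes the same `--dataset`/`--cfg` arguments as the training script plus a `--load_ckpt` checkpoint path, so verify against the actual test script):
+
+  ```
+  python tools/test_net.py \
+    --dataset coco2017 --cfg configs/baselines/retinanet_R-50-FPN_1x.yaml \
+    --load_ckpt {path/to/checkpoint}.pth
+  ```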
+
## Faster-RCNN
### e2e_faster_rcnn-R-50-FPN_1x
diff --git a/configs/baselines/retinanet_R-50-FPN_1x.yaml b/configs/baselines/retinanet_R-50-FPN_1x.yaml
new file mode 100644
index 00000000..989807ac
--- /dev/null
+++ b/configs/baselines/retinanet_R-50-FPN_1x.yaml
@@ -0,0 +1,41 @@
+DEBUG: False
+MODEL:
+ TYPE: retinanet
+ CONV_BODY: FPN.fpn_ResNet50_conv5_body
+ NUM_CLASSES: 81
+RESNETS:
+ IMAGENET_PRETRAINED_WEIGHTS: 'data/pretrained_networks/resnet50_caffe.pth'
+NUM_GPUS: 8
+SOLVER:
+ WEIGHT_DECAY: 0.0001
+ LR_POLICY: steps_with_decay
+ BASE_LR: 0.01
+ GAMMA: 0.1
+ MAX_ITER: 90000
+ STEPS: [0, 60000, 80000]
+FPN:
+ FPN_ON: True
+ MULTILEVEL_RPN: True
+ RPN_MAX_LEVEL: 7
+ RPN_MIN_LEVEL: 3
+ COARSEST_STRIDE: 128
+ EXTRA_CONV_LEVELS: True
+RETINANET:
+ RETINANET_ON: True
+ NUM_CONVS: 4
+ ASPECT_RATIOS: (1.0, 2.0, 0.5)
+ SCALES_PER_OCTAVE: 3
+ ANCHOR_SCALE: 4
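+  # Focal loss hyper-parameters (alpha/gamma from the RetinaNet paper)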
+ LOSS_GAMMA: 2.0
+ LOSS_ALPHA: 0.25
+TRAIN:
+ SCALES: (800,)
+ MAX_SIZE: 1333
+ RPN_STRADDLE_THRESH: -1 # default 0
+TEST:
+ SCALE: 800
+ MAX_SIZE: 1333
+ NMS: 0.5
+ RPN_PRE_NMS_TOP_N: 10000 # Per FPN level
+ RPN_POST_NMS_TOP_N: 2000
+OUTPUT_DIR: .
\ No newline at end of file
diff --git a/demo/loss_retinanet_R-50-FPN_1x.jpg b/demo/loss_retinanet_R-50-FPN_1x.jpg
new file mode 100644
index 00000000..921df780
Binary files /dev/null and b/demo/loss_retinanet_R-50-FPN_1x.jpg differ
diff --git a/lib/core/test.py b/lib/core/test.py
index eac67355..0dff62e3 100644
--- a/lib/core/test.py
+++ b/lib/core/test.py
@@ -45,6 +45,7 @@
import utils.fpn as fpn_utils
import utils.image as image_utils
import utils.keypoints as keypoint_utils
+import core.test_retinanet as test_retinanet
def im_detect_all(model, im, box_proposals=None, timers=None):
@@ -62,6 +63,13 @@ def im_detect_all(model, im, box_proposals=None, timers=None):
if timers is None:
timers = defaultdict(Timer)
+ # Handle RetinaNet testing separately for now
+ if cfg.RETINANET.RETINANET_ON:
+ timers['im_detect_bbox'].tic()
+ cls_boxes = test_retinanet.im_detect_bbox(model, im, timers)
+ timers['im_detect_bbox'].toc()
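+        # RetinaNet is box-only: return None for segms and keypoints so the
+        # im_detect_all interface stays unchanged.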
+ return cls_boxes, None, None
+
timers['im_detect_bbox'].tic()
if cfg.TEST.BBOX_AUG.ENABLED:
scores, boxes, im_scale, blob_conv = im_detect_bbox_aug(
diff --git a/lib/core/test_retinanet.py b/lib/core/test_retinanet.py
new file mode 100644
index 00000000..6d2f1dc9
--- /dev/null
+++ b/lib/core/test_retinanet.py
@@ -0,0 +1,196 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+
+"""Test a RetinaNet network on an image database"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import numpy as np
+import logging
+from collections import defaultdict
+
+from torch.autograd import Variable
+import torch
+
+from core.config import cfg
+from modeling.generate_anchors import generate_anchors
+from utils.timer import Timer
+import utils.blob as blob_utils
+import utils.boxes as box_utils
+import roi_data.data_utils as data_utils
+
+logger = logging.getLogger(__name__)
+
+
+def _create_cell_anchors():
+ """
+ Generate all types of anchors for all fpn levels/scales/aspect ratios.
+ This function is called only once at the beginning of inference.
+ """
+ k_max, k_min = cfg.FPN.RPN_MAX_LEVEL, cfg.FPN.RPN_MIN_LEVEL
+ scales_per_octave = cfg.RETINANET.SCALES_PER_OCTAVE
+ aspect_ratios = cfg.RETINANET.ASPECT_RATIOS
+ anchor_scale = cfg.RETINANET.ANCHOR_SCALE
+ A = scales_per_octave * len(aspect_ratios)
+
+ anchors = {}
+ for lvl in range(k_min, k_max + 1):
+ # create cell anchors array
+ stride = 2. ** lvl
+ cell_anchors = np.zeros((A, 4))
+ a = 0
+ for octave in range(scales_per_octave):
+ octave_scale = 2 ** (octave / float(scales_per_octave))
+ for aspect in aspect_ratios:
+ anchor_sizes = (stride * octave_scale * anchor_scale, )
+ anchor_aspect_ratios = (aspect, )
+ cell_anchors[a, :] = generate_anchors(
+ stride=stride, sizes=anchor_sizes,
+ aspect_ratios=anchor_aspect_ratios)
+ a += 1
+ anchors[lvl] = cell_anchors
+ return anchors
+
+
+def im_detect_bbox(model, im, timers=None):
+ """Generate RetinaNet detections on a single image."""
+ if timers is None:
+ timers = defaultdict(Timer)
+ # Although anchors are input independent and could be precomputed,
+ # recomputing them per image only brings a small overhead
+ anchors = _create_cell_anchors()
+ timers['im_detect_bbox'].tic()
+ k_max, k_min = cfg.FPN.RPN_MAX_LEVEL, cfg.FPN.RPN_MIN_LEVEL
+ A = cfg.RETINANET.SCALES_PER_OCTAVE * len(cfg.RETINANET.ASPECT_RATIOS)
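+    # With the baseline config (3 scales per octave x 3 aspect ratios),
+    # A = 9 anchors are evaluated at every spatial location of each FPN level.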
+ inputs = {}
+ inputs['data'], im_scale, inputs['im_info'] = \
+ blob_utils.get_image_blob(im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE)
+
+ if cfg.PYTORCH_VERSION_LESS_THAN_040:
+ inputs['data'] = [
+ Variable(torch.from_numpy(inputs['data']), volatile=True)]
+ inputs['im_info'] = [
+ Variable(torch.from_numpy(inputs['im_info']), volatile=True)]
+ else:
+ inputs['data'] = [torch.from_numpy(inputs['data'])]
+ inputs['im_info'] = [torch.from_numpy(inputs['im_info'])]
+
+ return_dict = model(**inputs)
+ cls_probs = return_dict['cls_score']
+ box_preds = return_dict['bbox_pred']
+
+ # here the boxes_all are [x0, y0, x1, y1, score]
+ boxes_all = defaultdict(list)
+
+ cnt = 0
+ for lvl in range(k_min, k_max + 1):
+ # create cell anchors array
+ stride = 2. ** lvl
+ cell_anchors = anchors[lvl]
+
+ # fetch per level probability
+ cls_prob = cls_probs[cnt].data.cpu().numpy()
+ box_pred = box_preds[cnt].data.cpu().numpy()
+ cls_prob = cls_prob.reshape((
+ cls_prob.shape[0], A, int(cls_prob.shape[1] / A),
+ cls_prob.shape[2], cls_prob.shape[3]))
+ box_pred = box_pred.reshape((
+ box_pred.shape[0], A, 4, box_pred.shape[2], box_pred.shape[3]))
+ cnt += 1
+
+ if cfg.RETINANET.SOFTMAX:
+ cls_prob = cls_prob[:, :, 1::, :, :]
+
+ cls_prob_ravel = cls_prob.ravel()
+ # In some cases [especially for very small img sizes], it's possible that
+ # candidate_ind is empty if we impose threshold 0.05 at all levels. This
+ # will lead to errors since no detections are found for this image. Hence,
+ # for lvl 7 which has small spatial resolution, we take the threshold 0.0
+ th = cfg.RETINANET.INFERENCE_TH if lvl < k_max else 0.0
+ candidate_inds = np.where(cls_prob_ravel > th)[0]
+ if (len(candidate_inds) == 0):
+ continue
+
+ pre_nms_topn = min(cfg.RETINANET.PRE_NMS_TOP_N, len(candidate_inds))
+ inds = np.argpartition(
+ cls_prob_ravel[candidate_inds], -pre_nms_topn)[-pre_nms_topn:]
+ inds = candidate_inds[inds]
+
+ inds_5d = np.array(np.unravel_index(inds, cls_prob.shape)).transpose()
+ classes = inds_5d[:, 2]
+ anchor_ids, y, x = inds_5d[:, 1], inds_5d[:, 3], inds_5d[:, 4]
+ scores = cls_prob[:, anchor_ids, classes, y, x]
+
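+        # Map the selected feature-map locations (x, y) back to image
+        # coordinates: scale by the FPN stride, then offset by the cell anchors.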
+ boxes = np.column_stack((x, y, x, y)).astype(dtype=np.float32)
+ boxes *= stride
+ boxes += cell_anchors[anchor_ids, :]
+
+ if not cfg.RETINANET.CLASS_SPECIFIC_BBOX:
+ box_deltas = box_pred[0, anchor_ids, :, y, x]
+ else:
+ box_cls_inds = classes * 4
+ box_deltas = np.vstack(
+ [box_pred[0, ind:ind + 4, yi, xi]
+ for ind, yi, xi in zip(box_cls_inds, y, x)]
+ )
+ pred_boxes = (
+ box_utils.bbox_transform(boxes, box_deltas)
+ if cfg.TEST.BBOX_REG else boxes)
+ pred_boxes /= im_scale
+ pred_boxes = box_utils.clip_tiled_boxes(pred_boxes, im.shape)
+ box_scores = np.zeros((pred_boxes.shape[0], 5))
+ box_scores[:, 0:4] = pred_boxes
+ box_scores[:, 4] = scores
+
+ for cls in range(1, cfg.MODEL.NUM_CLASSES):
+ inds = np.where(classes == cls - 1)[0]
+ if len(inds) > 0:
+ boxes_all[cls].extend(box_scores[inds, :])
+ timers['im_detect_bbox'].toc()
+
+ # Combine predictions across all levels and retain the top scoring by class
+ timers['misc_bbox'].tic()
+ detections = []
+ for cls, boxes in boxes_all.items():
+ cls_dets = np.vstack(boxes).astype(dtype=np.float32)
+ # do class specific nms here
+ keep = box_utils.nms(cls_dets, cfg.TEST.NMS)
+ cls_dets = cls_dets[keep, :]
+ out = np.zeros((len(keep), 6))
+ out[:, 0:5] = cls_dets
+ out[:, 5].fill(cls)
+ detections.append(out)
+
+ # detections (N, 6) format:
+ # detections[:, :4] - boxes
+ # detections[:, 4] - scores
+ # detections[:, 5] - classes
+ detections = np.vstack(detections)
+ # sort all again
+ inds = np.argsort(-detections[:, 4])
+ detections = detections[inds[0:cfg.TEST.DETECTIONS_PER_IM], :]
+
+ # Convert the detections to image cls_ format (see core/test_engine.py)
+ num_classes = cfg.MODEL.NUM_CLASSES
+ cls_boxes = [[] for _ in range(cfg.MODEL.NUM_CLASSES)]
+ for c in range(1, num_classes):
+ inds = np.where(detections[:, 5] == c)[0]
+ cls_boxes[c] = detections[inds, :5]
+ timers['misc_bbox'].toc()
+
+ return cls_boxes
diff --git a/lib/modeling/FPN.py b/lib/modeling/FPN.py
index 03cd60bc..a551699c 100644
--- a/lib/modeling/FPN.py
+++ b/lib/modeling/FPN.py
@@ -140,7 +140,7 @@ def __init__(self, conv_body_func, fpn_level_info, P2only=False):
self.extra_pyramid_modules = nn.ModuleList()
dim_in = fpn_level_info.dims[0]
for i in range(HIGHEST_BACKBONE_LVL + 1, max_level + 1):
- self.extra_pyramid_modules(
+ self.extra_pyramid_modules.append(
nn.Conv2d(dim_in, fpn_dim, 3, 2, 1)
)
dim_in = fpn_dim
@@ -214,7 +214,7 @@ def detectron_weight_mapping(self):
})
if hasattr(self, 'extra_pyramid_modules'):
- for i in len(self.extra_pyramid_modules):
+ for i in range(len(self.extra_pyramid_modules)):
p_prefix = 'extra_pyramid_modules.%d' % i
d_prefix = 'fpn_%d' % (HIGHEST_BACKBONE_LVL + 1 + i)
mapping_to_detectron.update({
@@ -246,9 +246,9 @@ def forward(self, x):
if hasattr(self, 'extra_pyramid_modules'):
blob_in = conv_body_blobs[-1]
- fpn_output_blobs.insert(0, self.extra_pyramid_modules(blob_in))
+ fpn_output_blobs.insert(0, self.extra_pyramid_modules[0](blob_in))
for module in self.extra_pyramid_modules[1:]:
- fpn_output_blobs.insert(0, module(F.relu(fpn_output_blobs[0], inplace=True)))
+ fpn_output_blobs.insert(0, module(F.relu(fpn_output_blobs[0])))
if self.P2only:
# use only the finest level
@@ -294,7 +294,7 @@ def forward(self, top_blob, lateral_blob):
lat = self.conv_lateral(lateral_blob)
# Top-down 2x upsampling
# td = F.upsample(top_blob, size=lat.size()[2:], mode='bilinear')
- td = F.upsample(top_blob, scale_factor=2, mode='nearest')
+ td = F.interpolate(top_blob, scale_factor=2, mode='nearest')
# Sum lateral and top-down
return lat + td
diff --git a/lib/modeling/model_builder.py b/lib/modeling/model_builder.py
index 0c7f1f49..59a274ef 100644
--- a/lib/modeling/model_builder.py
+++ b/lib/modeling/model_builder.py
@@ -12,6 +12,7 @@
from model.roi_crop.functions.roi_crop import RoICropFunction
from modeling.roi_xfrom.roi_align.functions.roi_align import RoIAlignFunction
import modeling.rpn_heads as rpn_heads
+import modeling.retinanet_heads as retinanet_heads
import modeling.fast_rcnn_heads as fast_rcnn_heads
import modeling.mask_rcnn_heads as mask_rcnn_heads
import modeling.keypoint_rcnn_heads as keypoint_rcnn_heads
@@ -62,8 +63,9 @@ def wrapper(self, *args, **kwargs):
with torch.no_grad():
return net_func(self, *args, **kwargs)
else:
- raise ValueError('You should call this function only on inference.'
- 'Set the network in inference mode by net.eval().')
+ raise ValueError(
+                'You should call this function only for inference. '
+                'Set the network in inference mode by net.eval().')
return wrapper
@@ -84,42 +86,59 @@ def __init__(self):
self.RPN = rpn_heads.generic_rpn_outputs(
self.Conv_Body.dim_out, self.Conv_Body.spatial_scale)
- if cfg.FPN.FPN_ON:
- # Only supports case when RPN and ROI min levels are the same
- assert cfg.FPN.RPN_MIN_LEVEL == cfg.FPN.ROI_MIN_LEVEL
- # RPN max level can be >= to ROI max level
- assert cfg.FPN.RPN_MAX_LEVEL >= cfg.FPN.ROI_MAX_LEVEL
- # FPN RPN max level might be > FPN ROI max level in which case we
- # need to discard some leading conv blobs (blobs are ordered from
- # max/coarsest level to min/finest level)
- self.num_roi_levels = cfg.FPN.ROI_MAX_LEVEL - cfg.FPN.ROI_MIN_LEVEL + 1
-
- # Retain only the spatial scales that will be used for RoI heads. `Conv_Body.spatial_scale`
- # may include extra scales that are used for RPN proposals, but not for RoI heads.
- self.Conv_Body.spatial_scale = self.Conv_Body.spatial_scale[-self.num_roi_levels:]
+ if cfg.FPN.FPN_ON:
+ # Only supports case when RPN and ROI min levels are the same
+ assert cfg.FPN.RPN_MIN_LEVEL == cfg.FPN.ROI_MIN_LEVEL
+ # RPN max level can be >= to ROI max level
+ assert cfg.FPN.RPN_MAX_LEVEL >= cfg.FPN.ROI_MAX_LEVEL
+ # FPN RPN max level might be > FPN ROI max level in which case we
+ # need to discard some leading conv blobs (blobs are ordered from
+ # max/coarsest level to min/finest level)
+ self.num_roi_levels = cfg.FPN.ROI_MAX_LEVEL - cfg.FPN.ROI_MIN_LEVEL + 1
+
+ # Retain only the spatial scales that will be used for RoI heads. `Conv_Body.spatial_scale`
+ # may include extra scales that are used for RPN proposals, but
+ # not for RoI heads.
+ self.Conv_Body.spatial_scale = self.Conv_Body.spatial_scale[-self.num_roi_levels:]
# BBOX Branch
if not cfg.MODEL.RPN_ONLY:
- self.Box_Head = get_func(cfg.FAST_RCNN.ROI_BOX_HEAD)(
- self.RPN.dim_out, self.roi_feature_transform, self.Conv_Body.spatial_scale)
- self.Box_Outs = fast_rcnn_heads.fast_rcnn_outputs(
- self.Box_Head.dim_out)
+            if cfg.FAST_RCNN.ROI_BOX_HEAD == '':
+ # RetinaNet
+ self.Box_Outs = retinanet_heads.fpn_retinanet_outputs(
+ self.Conv_Body.dim_out, self.Conv_Body.spatial_scale)
+ else:
+ self.Box_Head = get_func(
+ cfg.FAST_RCNN.ROI_BOX_HEAD)(
+ self.RPN.dim_out,
+ self.roi_feature_transform,
+ self.Conv_Body.spatial_scale)
+ self.Box_Outs = fast_rcnn_heads.fast_rcnn_outputs(
+ self.Box_Head.dim_out)
# Mask Branch
if cfg.MODEL.MASK_ON:
- self.Mask_Head = get_func(cfg.MRCNN.ROI_MASK_HEAD)(
- self.RPN.dim_out, self.roi_feature_transform, self.Conv_Body.spatial_scale)
+ self.Mask_Head = get_func(
+ cfg.MRCNN.ROI_MASK_HEAD)(
+ self.RPN.dim_out,
+ self.roi_feature_transform,
+ self.Conv_Body.spatial_scale)
if getattr(self.Mask_Head, 'SHARE_RES5', False):
self.Mask_Head.share_res5_module(self.Box_Head.res5)
- self.Mask_Outs = mask_rcnn_heads.mask_rcnn_outputs(self.Mask_Head.dim_out)
+ self.Mask_Outs = mask_rcnn_heads.mask_rcnn_outputs(
+ self.Mask_Head.dim_out)
# Keypoints Branch
if cfg.MODEL.KEYPOINTS_ON:
- self.Keypoint_Head = get_func(cfg.KRCNN.ROI_KEYPOINTS_HEAD)(
- self.RPN.dim_out, self.roi_feature_transform, self.Conv_Body.spatial_scale)
+ self.Keypoint_Head = get_func(
+ cfg.KRCNN.ROI_KEYPOINTS_HEAD)(
+ self.RPN.dim_out,
+ self.roi_feature_transform,
+ self.Conv_Body.spatial_scale)
if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
self.Keypoint_Head.share_res5_module(self.Box_Head.res5)
- self.Keypoint_Outs = keypoint_rcnn_heads.keypoint_outputs(self.Keypoint_Head.dim_out)
+ self.Keypoint_Outs = keypoint_rcnn_heads.keypoint_outputs(
+ self.Keypoint_Head.dim_out)
self._init_modules()
@@ -127,10 +146,16 @@ def _init_modules(self):
if cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS:
resnet_utils.load_pretrained_imagenet_weights(self)
# Check if shared weights are equaled
- if cfg.MODEL.MASK_ON and getattr(self.Mask_Head, 'SHARE_RES5', False):
- assert compare_state_dict(self.Mask_Head.res5.state_dict(), self.Box_Head.res5.state_dict())
- if cfg.MODEL.KEYPOINTS_ON and getattr(self.Keypoint_Head, 'SHARE_RES5', False):
- assert compare_state_dict(self.Keypoint_Head.res5.state_dict(), self.Box_Head.res5.state_dict())
+ if cfg.MODEL.MASK_ON and getattr(
+ self.Mask_Head, 'SHARE_RES5', False):
+ assert compare_state_dict(
+ self.Mask_Head.res5.state_dict(),
+ self.Box_Head.res5.state_dict())
+ if cfg.MODEL.KEYPOINTS_ON and getattr(
+ self.Keypoint_Head, 'SHARE_RES5', False):
+ assert compare_state_dict(
+ self.Keypoint_Head.res5.state_dict(),
+ self.Box_Head.res5.state_dict())
if cfg.TRAIN.FREEZE_CONV_BODY:
for p in self.Conv_Body.parameters():
@@ -154,25 +179,30 @@ def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
blob_conv = self.Conv_Body(im_data)
- rpn_ret = self.RPN(blob_conv, im_info, roidb)
-
# if self.training:
# # can be used to infer fg/bg ratio
# return_dict['rois_label'] = rpn_ret['labels_int32']
- if cfg.FPN.FPN_ON:
+ if cfg.RPN.RPN_ON:
+ rpn_ret = self.RPN(blob_conv, im_info, roidb)
+
+        if cfg.FPN.FPN_ON and cfg.FAST_RCNN.ROI_BOX_HEAD != '':
# Retain only the blobs that will be used for RoI heads. `blob_conv` may include
- # extra blobs that are used for RPN proposals, but not for RoI heads.
+ # extra blobs that are used for RPN proposals, but not for RoI
+ # heads.
blob_conv = blob_conv[-self.num_roi_levels:]
if not self.training:
return_dict['blob_conv'] = blob_conv
if not cfg.MODEL.RPN_ONLY:
- if cfg.MODEL.SHARE_RES5 and self.training:
- box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
+            if cfg.FAST_RCNN.ROI_BOX_HEAD != '':
+ if cfg.MODEL.SHARE_RES5 and self.training:
+ box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
+ else:
+ box_feat = self.Box_Head(blob_conv, rpn_ret)
else:
- box_feat = self.Box_Head(blob_conv, rpn_ret)
+ box_feat = blob_conv
cls_score, bbox_pred = self.Box_Outs(box_feat)
else:
# TODO: complete the returns for RPN only situation
@@ -181,46 +211,62 @@ def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
if self.training:
return_dict['losses'] = {}
return_dict['metrics'] = {}
- # rpn loss
- rpn_kwargs.update(dict(
- (k, rpn_ret[k]) for k in rpn_ret.keys()
- if (k.startswith('rpn_cls_logits') or k.startswith('rpn_bbox_pred'))
- ))
- loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(**rpn_kwargs)
- if cfg.FPN.FPN_ON:
- for i, lvl in enumerate(range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1)):
- return_dict['losses']['loss_rpn_cls_fpn%d' % lvl] = loss_rpn_cls[i]
- return_dict['losses']['loss_rpn_bbox_fpn%d' % lvl] = loss_rpn_bbox[i]
- else:
- return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
- return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox
- # bbox loss
- loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
- cls_score, bbox_pred, rpn_ret['labels_int32'], rpn_ret['bbox_targets'],
- rpn_ret['bbox_inside_weights'], rpn_ret['bbox_outside_weights'])
- return_dict['losses']['loss_cls'] = loss_cls
- return_dict['losses']['loss_bbox'] = loss_bbox
- return_dict['metrics']['accuracy_cls'] = accuracy_cls
+            if cfg.FAST_RCNN.ROI_BOX_HEAD != '':
+ # rpn loss
+ rpn_kwargs.update(dict((k, rpn_ret[k]) for k in rpn_ret.keys() if (
+ k.startswith('rpn_cls_logits') or k.startswith('rpn_bbox_pred'))))
+ loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
+ **rpn_kwargs)
+ if cfg.FPN.FPN_ON:
+ for i, lvl in enumerate(
+ range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1)):
+ return_dict['losses']['loss_rpn_cls_fpn%d' %
+ lvl] = loss_rpn_cls[i]
+ return_dict['losses']['loss_rpn_bbox_fpn%d' %
+ lvl] = loss_rpn_bbox[i]
+ else:
+ return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
+ return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox
+
+ # bbox loss
+ loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
+ cls_score, bbox_pred, rpn_ret['labels_int32'], rpn_ret['bbox_targets'],
+ rpn_ret['bbox_inside_weights'], rpn_ret['bbox_outside_weights'])
+ return_dict['losses']['loss_cls'] = loss_cls
+ return_dict['losses']['loss_bbox'] = loss_bbox
+ return_dict['metrics']['accuracy_cls'] = accuracy_cls
+
+ if cfg.RETINANET.RETINANET_ON:
+ loss_retnet_cls, loss_retnet_bbox = retinanet_heads.add_fpn_retinanet_losses(
+ cls_score, bbox_pred, **rpn_kwargs)
+ for i, lvl in enumerate(
+ range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1)):
+ return_dict['losses']['loss_retnet_cls_fpn%d' %
+ lvl] = loss_retnet_cls[i]
+ return_dict['losses']['loss_retnet_bbox_fpn%d' %
+ lvl] = loss_retnet_bbox[i]
if cfg.MODEL.MASK_ON:
if getattr(self.Mask_Head, 'SHARE_RES5', False):
- mask_feat = self.Mask_Head(res5_feat, rpn_ret,
- roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
+ mask_feat = self.Mask_Head(
+ res5_feat, rpn_ret, roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
else:
mask_feat = self.Mask_Head(blob_conv, rpn_ret)
mask_pred = self.Mask_Outs(mask_feat)
# return_dict['mask_pred'] = mask_pred
# mask loss
- loss_mask = mask_rcnn_heads.mask_rcnn_losses(mask_pred, rpn_ret['masks_int32'])
+ loss_mask = mask_rcnn_heads.mask_rcnn_losses(
+ mask_pred, rpn_ret['masks_int32'])
return_dict['losses']['loss_mask'] = loss_mask
if cfg.MODEL.KEYPOINTS_ON:
if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
# No corresponding keypoint head implemented yet (Neither in Detectron)
- # Also, rpn need to generate the label 'roi_has_keypoints_int32'
- kps_feat = self.Keypoint_Head(res5_feat, rpn_ret,
- roi_has_keypoints_int32=rpn_ret['roi_has_keypoint_int32'])
+ # Also, rpn need to generate the label
+ # 'roi_has_keypoints_int32'
+ kps_feat = self.Keypoint_Head(
+ res5_feat, rpn_ret, roi_has_keypoints_int32=rpn_ret['roi_has_keypoint_int32'])
else:
kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
kps_pred = self.Keypoint_Outs(kps_feat)
@@ -243,14 +289,22 @@ def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
else:
# Testing
- return_dict['rois'] = rpn_ret['rois']
+            if cfg.FAST_RCNN.ROI_BOX_HEAD != '':
+ return_dict['rois'] = rpn_ret['rois']
return_dict['cls_score'] = cls_score
return_dict['bbox_pred'] = bbox_pred
return return_dict
- def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoIPoolF',
- resolution=7, spatial_scale=1. / 16., sampling_ratio=0):
+ def roi_feature_transform(
+ self,
+ blobs_in,
+ rpn_ret,
+ blob_rois='rois',
+ method='RoIPoolF',
+ resolution=7,
+ spatial_scale=1. / 16.,
+ sampling_ratio=0):
"""Add the specified RoI pooling method. The sampling_ratio argument
is supported for some, but not all, RoI transform methods.
@@ -273,12 +327,16 @@ def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoI
sc = spatial_scale[k_max - lvl] # in reversed order
bl_rois = blob_rois + '_fpn' + str(lvl)
if len(rpn_ret[bl_rois]):
- rois = Variable(torch.from_numpy(rpn_ret[bl_rois])).cuda(device_id)
+ rois = Variable(torch.from_numpy(
+ rpn_ret[bl_rois])).cuda(device_id)
if method == 'RoIPoolF':
- # Warning!: Not check if implementation matches Detectron
- xform_out = RoIPoolFunction(resolution, resolution, sc)(bl_in, rois)
+ # Warning!: Not check if implementation matches
+ # Detectron
+ xform_out = RoIPoolFunction(
+ resolution, resolution, sc)(bl_in, rois)
elif method == 'RoICrop':
- # Warning!: Not check if implementation matches Detectron
+ # Warning!: Not check if implementation matches
+ # Detectron
grid_xy = net_utils.affine_grid_gen(
rois, bl_in.size()[2:], self.grid_size)
grid_yx = torch.stack(
@@ -288,7 +346,8 @@ def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoI
xform_out = F.max_pool2d(xform_out, 2, 2)
elif method == 'RoIAlign':
xform_out = RoIAlignFunction(
- resolution, resolution, sc, sampling_ratio)(bl_in, rois)
+ resolution, resolution, sc, sampling_ratio)(
+ bl_in, rois)
bl_out_list.append(xform_out)
# The pooled features from all levels are concatenated along the
@@ -299,7 +358,10 @@ def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoI
device_id = xform_shuffled.get_device()
restore_bl = rpn_ret[blob_rois + '_idx_restore_int32']
restore_bl = Variable(
- torch.from_numpy(restore_bl.astype('int64', copy=False))).cuda(device_id)
+ torch.from_numpy(
+ restore_bl.astype(
+ 'int64',
+ copy=False))).cuda(device_id)
xform_out = xform_shuffled[restore_bl]
else:
# Single feature level
@@ -307,11 +369,14 @@ def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoI
# (batch_idx, x1, y1, x2, y2) specifying an image batch index and a
# rectangle (x1, y1, x2, y2)
device_id = blobs_in.get_device()
- rois = Variable(torch.from_numpy(rpn_ret[blob_rois])).cuda(device_id)
+ rois = Variable(torch.from_numpy(
+ rpn_ret[blob_rois])).cuda(device_id)
if method == 'RoIPoolF':
- xform_out = RoIPoolFunction(resolution, resolution, spatial_scale)(blobs_in, rois)
+ xform_out = RoIPoolFunction(
+ resolution, resolution, spatial_scale)(blobs_in, rois)
elif method == 'RoICrop':
- grid_xy = net_utils.affine_grid_gen(rois, blobs_in.size()[2:], self.grid_size)
+ grid_xy = net_utils.affine_grid_gen(
+ rois, blobs_in.size()[2:], self.grid_size)
grid_yx = torch.stack(
[grid_xy.data[:, :, :, 1], grid_xy.data[:, :, :, 0]], 3).contiguous()
xform_out = RoICropFunction()(blobs_in, Variable(grid_yx).detach())
@@ -319,7 +384,12 @@ def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoI
xform_out = F.max_pool2d(xform_out, 2, 2)
elif method == 'RoIAlign':
xform_out = RoIAlignFunction(
- resolution, resolution, spatial_scale, sampling_ratio)(blobs_in, rois)
+ resolution,
+ resolution,
+ spatial_scale,
+ sampling_ratio)(
+ blobs_in,
+ rois)
return xform_out
@@ -329,7 +399,8 @@ def convbody_net(self, data):
blob_conv = self.Conv_Body(data)
if cfg.FPN.FPN_ON:
# Retain only the blobs that will be used for RoI heads. `blob_conv` may include
- # extra blobs that are used for RPN proposals, but not for RoI heads.
+ # extra blobs that are used for RPN proposals, but not for RoI
+ # heads.
blob_conv = blob_conv[-self.num_roi_levels:]
return blob_conv
diff --git a/lib/modeling/retinanet_heads.py b/lib/modeling/retinanet_heads.py
new file mode 100644
index 00000000..a4cdbb81
--- /dev/null
+++ b/lib/modeling/retinanet_heads.py
@@ -0,0 +1,207 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+
+"""RetinaNet model heads and losses. See: https://arxiv.org/abs/1708.02002."""
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import init
+import utils.net as net_utils
+import math
+
+from core.config import cfg
+
+
+class fpn_retinanet_outputs(nn.Module):
+ """Add RetinaNet on FPN specific outputs."""
+
+ def __init__(self, dim_in, spatial_scales):
+ super().__init__()
+ self.dim_out = dim_in
+ self.dim_in = dim_in
+ self.spatial_scales = spatial_scales
+ self.dim_out = self.dim_in
+ self.num_anchors = len(cfg.RETINANET.ASPECT_RATIOS) * \
+ cfg.RETINANET.SCALES_PER_OCTAVE
+
+ # Create conv ops shared by all FPN levels
+ self.n_conv_fpn_cls_modules = nn.ModuleList()
+ self.n_conv_fpn_bbox_modules = nn.ModuleList()
+ for nconv in range(cfg.RETINANET.NUM_CONVS):
+ self.n_conv_fpn_cls_modules.append(
+ nn.Conv2d(self.dim_in, self.dim_out, 3, 1, 1))
+ self.n_conv_fpn_bbox_modules.append(
+ nn.Conv2d(self.dim_in, self.dim_out, 3, 1, 1))
+
+ cls_pred_dim = cfg.MODEL.NUM_CLASSES if cfg.RETINANET.SOFTMAX \
+ else (cfg.MODEL.NUM_CLASSES - 1)
+
+ # unpacked bbox feature and add prediction layers
+ self.bbox_regr_dim = 4 * (cfg.MODEL.NUM_CLASSES - 1) \
+ if cfg.RETINANET.CLASS_SPECIFIC_BBOX else 4
+
+ self.fpn_cls_score = nn.Conv2d(self.dim_out,
+ cls_pred_dim * self.num_anchors, 3, 1, 1)
+ self.fpn_bbox_score = nn.Conv2d(self.dim_out,
+ self.bbox_regr_dim * self.num_anchors, 3, 1, 1)
+
+ self._init_weights()
+
+ def _init_weights(self):
+ def init_func(m):
+ if isinstance(m, nn.Conv2d):
+ init.normal_(m.weight, std=0.01)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+ for child_m in self.children():
+ if isinstance(child_m, nn.ModuleList):
+ child_m.apply(init_func)
+
+ init.normal_(self.fpn_cls_score.weight, std=0.01)
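+        # Bias init so that sigmoid(bias) == cfg.RETINANET.PRIOR_PROB, i.e. every
+        # anchor starts with a low foreground score (the prior-probability
+        # initialization from the RetinaNet paper).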
+ init.constant_(self.fpn_cls_score.bias,
+ -math.log((1 - cfg.RETINANET.PRIOR_PROB) / cfg.RETINANET.PRIOR_PROB))
+
+ init.normal_(self.fpn_bbox_score.weight, std=0.01)
+ init.constant_(self.fpn_bbox_score.bias, 0)
+
+ def detectron_weight_mapping(self):
+ k_min = cfg.FPN.RPN_MIN_LEVEL
+ mapping_to_detectron = {
+ 'n_conv_fpn_cls_modules.0.weight': 'retnet_cls_conv_n0_fpn%d_w' % k_min,
+ 'n_conv_fpn_cls_modules.0.bias': 'retnet_cls_conv_n0_fpn%d_b' % k_min,
+ 'n_conv_fpn_cls_modules.1.weight': 'retnet_cls_conv_n1_fpn%d_w' % k_min,
+ 'n_conv_fpn_cls_modules.1.bias': 'retnet_cls_conv_n1_fpn%d_b' % k_min,
+ 'n_conv_fpn_cls_modules.2.weight': 'retnet_cls_conv_n2_fpn%d_w' % k_min,
+ 'n_conv_fpn_cls_modules.2.bias': 'retnet_cls_conv_n2_fpn%d_b' % k_min,
+ 'n_conv_fpn_cls_modules.3.weight': 'retnet_cls_conv_n3_fpn%d_w' % k_min,
+ 'n_conv_fpn_cls_modules.3.bias': 'retnet_cls_conv_n3_fpn%d_b' % k_min,
+
+ 'n_conv_fpn_bbox_modules.0.weight': 'retnet_bbox_conv_n0_fpn%d_w' % k_min,
+ 'n_conv_fpn_bbox_modules.0.bias': 'retnet_bbox_conv_n0_fpn%d_b' % k_min,
+ 'n_conv_fpn_bbox_modules.1.weight': 'retnet_bbox_conv_n1_fpn%d_w' % k_min,
+ 'n_conv_fpn_bbox_modules.1.bias': 'retnet_bbox_conv_n1_fpn%d_b' % k_min,
+ 'n_conv_fpn_bbox_modules.2.weight': 'retnet_bbox_conv_n2_fpn%d_w' % k_min,
+ 'n_conv_fpn_bbox_modules.2.bias': 'retnet_bbox_conv_n2_fpn%d_b' % k_min,
+ 'n_conv_fpn_bbox_modules.3.weight': 'retnet_bbox_conv_n3_fpn%d_w' % k_min,
+ 'n_conv_fpn_bbox_modules.3.bias': 'retnet_bbox_conv_n3_fpn%d_b' % k_min,
+
+ 'fpn_cls_score.weight': 'retnet_cls_pred_fpn%d_w' % k_min,
+ 'fpn_cls_score.bias': 'retnet_cls_pred_fpn%d_b' % k_min,
+ 'fpn_bbox_score.weight': 'retnet_bbox_pred_fpn%d_w' % k_min,
+ 'fpn_bbox_score.bias': 'retnet_bbox_pred_fpn%d_b' % k_min
+ }
+ return mapping_to_detectron, []
+
+ def forward(self, blobs_in):
+ k_max = cfg.FPN.RPN_MAX_LEVEL # coarsest level of pyramid
+ k_min = cfg.FPN.RPN_MIN_LEVEL # finest level of pyramid
+ assert len(blobs_in) == k_max - k_min + 1
+ bbox_feat_list = []
+ cls_score = []
+ bbox_pred = []
+
+ # ==========================================================================
+ # classification tower with logits and prob prediction
+ # ==========================================================================
+ for lvl in range(k_min, k_max + 1):
+ bl_in = blobs_in[k_max - lvl] # blobs_in is in reversed order
+ # classification tower stack convolution starts
+ for nconv in range(cfg.RETINANET.NUM_CONVS):
+ bl_out = self.n_conv_fpn_cls_modules[nconv](bl_in)
+ bl_in = F.relu(bl_out, inplace=True)
+ bl_feat = bl_in
+
+ # cls tower stack convolution ends. Add the logits layer now
+ retnet_cls_pred = self.fpn_cls_score(bl_feat)
+
+ if not self.training:
+ if cfg.RETINANET.SOFTMAX:
+ raise NotImplementedError("To be implemented")
+ else: # sigmoid
+ retnet_cls_probs = retnet_cls_pred.sigmoid()
+ cls_score.append(retnet_cls_probs)
+ else:
+ cls_score.append(retnet_cls_pred)
+
+ if cfg.RETINANET.SHARE_CLS_BBOX_TOWER:
+ bbox_feat_list.append(bl_feat)
+
+ # ==========================================================================
+ # bbox tower if not sharing features with the classification tower with
+ # logits and prob prediction
+ # ==========================================================================
+ if not cfg.RETINANET.SHARE_CLS_BBOX_TOWER:
+ for lvl in range(k_min, k_max + 1):
+ bl_in = blobs_in[k_max - lvl] # blobs_in is in reversed order
+                # bbox tower stack convolution starts
+ for nconv in range(cfg.RETINANET.NUM_CONVS):
+ bl_out = self.n_conv_fpn_bbox_modules[nconv](bl_in)
+ bl_in = F.relu(bl_out, inplace=True)
+ # Add octave scales and aspect ratio
+ # At least 1 convolution for dealing different aspect ratios
+ bl_feat = bl_in
+ bbox_feat_list.append(bl_feat)
+
+ # Depending on the features [shared/separate] for bbox, add prediction layer
+ for i, lvl in enumerate(range(k_min, k_max + 1)):
+ bl_feat = bbox_feat_list[i]
+ retnet_bbox_pred = self.fpn_bbox_score(bl_feat)
+ bbox_pred.append(retnet_bbox_pred)
+
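+        # cls_score[i] has shape (N, A * cls_pred_dim, H_i, W_i) and bbox_pred[i]
+        # has shape (N, A * bbox_regr_dim, H_i, W_i), ordered from the finest
+        # level (k_min) to the coarsest (k_max).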
+ return cls_score, bbox_pred
+
+
+def add_fpn_retinanet_losses(cls_score, bbox_pred, **kwargs):
+ k_max = cfg.FPN.RPN_MAX_LEVEL # coarsest level of pyramid
+ k_min = cfg.FPN.RPN_MIN_LEVEL # finest level of pyramid
+
+ losses_cls = []
+ losses_bbox = []
+ for i, lvl in enumerate(range(k_min, k_max + 1)):
+ slvl = str(lvl)
+ h, w = cls_score[i].shape[2:]
+ retnet_cls_labels_fpn = kwargs['retnet_cls_labels_fpn' +
+ slvl][:, :, :h, :w]
+ retnet_bbox_targets_fpn = kwargs['retnet_roi_bbox_targets_fpn' +
+ slvl][:, :, :, :h, :w]
+ retnet_bbox_inside_weights_fpn = kwargs['retnet_bbox_inside_weights_wide_fpn' +
+ slvl][:, :, :, :h, :w]
+ retnet_fg_num = kwargs['retnet_fg_num']
+
+ # ==========================================================================
+ # bbox regression loss - SelectSmoothL1Loss for multiple anchors at a location
+ # ==========================================================================
+ bbox_loss = net_utils.select_smooth_l1_loss(
+ bbox_pred[i], retnet_bbox_targets_fpn,
+ retnet_bbox_inside_weights_fpn,
+ retnet_fg_num,
+ beta=cfg.RETINANET.BBOX_REG_BETA)
+
+ # ==========================================================================
+ # cls loss - depends on softmax/sigmoid outputs
+ # ==========================================================================
+ if cfg.RETINANET.SOFTMAX:
+ raise NotImplementedError("To be implemented")
+ else:
+ cls_loss = net_utils.sigmoid_focal_loss(
+ cls_score[i], retnet_cls_labels_fpn.float(),
+ cfg.MODEL.NUM_CLASSES, retnet_fg_num, alpha=cfg.RETINANET.LOSS_ALPHA,
+ gamma=cfg.RETINANET.LOSS_GAMMA
+ )
+
+ losses_bbox.append(bbox_loss)
+ losses_cls.append(cls_loss)
+
+ return losses_cls, losses_bbox
diff --git a/lib/roi_data/minibatch.py b/lib/roi_data/minibatch.py
index 65ef9e66..62bff744 100644
--- a/lib/roi_data/minibatch.py
+++ b/lib/roi_data/minibatch.py
@@ -4,6 +4,7 @@
from core.config import cfg
import utils.blob as blob_utils
import roi_data.rpn
+import roi_data.retinanet
def get_minibatch_blob_names(is_training=True):
@@ -15,7 +16,9 @@ def get_minibatch_blob_names(is_training=True):
# RPN-only or end-to-end Faster R-CNN
blob_names += roi_data.rpn.get_rpn_blob_names(is_training=is_training)
elif cfg.RETINANET.RETINANET_ON:
- raise NotImplementedError
+ blob_names += roi_data.retinanet.get_retinanet_blob_names(
+ is_training=is_training
+ )
else:
# Fast R-CNN like models trained on precomputed proposals
blob_names += roi_data.fast_rcnn.get_fast_rcnn_blob_names(
@@ -37,7 +40,13 @@ def get_minibatch(roidb):
# RPN-only or end-to-end Faster/Mask R-CNN
valid = roi_data.rpn.add_rpn_blobs(blobs, im_scales, roidb)
elif cfg.RETINANET.RETINANET_ON:
- raise NotImplementedError
+ im_width, im_height = im_blob.shape[3], im_blob.shape[2]
+ # im_width, im_height corresponds to the network input: padded image
+ # (if needed) width and height. We pass it as input and slice the data
+ # accordingly so that we don't need to use SampleAsOp
+ valid = roi_data.retinanet.add_retinanet_blobs(
+ blobs, im_scales, roidb, im_width, im_height
+ )
else:
# Fast R-CNN like models trained on precomputed proposals
valid = roi_data.fast_rcnn.add_fast_rcnn_blobs(blobs, im_scales, roidb)
diff --git a/lib/roi_data/retinanet.py b/lib/roi_data/retinanet.py
new file mode 100644
index 00000000..03b329bb
--- /dev/null
+++ b/lib/roi_data/retinanet.py
@@ -0,0 +1,258 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+
+"""Compute minibatch blobs for training a RetinaNet network."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import numpy as np
+import logging
+
+import utils.boxes as box_utils
+import roi_data.data_utils as data_utils
+from core.config import cfg
+
+logger = logging.getLogger(__name__)
+
+
+def get_retinanet_blob_names(is_training=True):
+ """
+ Returns blob names in the order in which they are read by the data
+ loader.
+ """
+ # im_info: (height, width, image scale)
+ blob_names = ['im_info']
+ assert cfg.FPN.FPN_ON, "RetinaNet uses FPN for dense detection"
+ # Same format as RPN blobs, but one per FPN level
+ if is_training:
+ blob_names += ['roidb', 'retnet_fg_num', 'retnet_bg_num']
+ for lvl in range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1):
+ suffix = 'fpn{}'.format(lvl)
+ blob_names += [
+ 'retnet_cls_labels_' + suffix,
+ 'retnet_roi_bbox_targets_' + suffix,
+ 'retnet_bbox_inside_weights_wide_' + suffix,
+ ]
+ return blob_names
+
+
+def add_retinanet_blobs(blobs, im_scales, roidb, image_width, image_height):
+ """Add RetinaNet blobs."""
+ # RetinaNet is applied to many feature levels, as in the FPN paper
+ k_max, k_min = cfg.FPN.RPN_MAX_LEVEL, cfg.FPN.RPN_MIN_LEVEL
+ scales_per_octave = cfg.RETINANET.SCALES_PER_OCTAVE
+ num_aspect_ratios = len(cfg.RETINANET.ASPECT_RATIOS)
+ aspect_ratios = cfg.RETINANET.ASPECT_RATIOS
+ anchor_scale = cfg.RETINANET.ANCHOR_SCALE
+
+ # get anchors from all levels for all scales/aspect ratios
+ foas = []
+ for lvl in range(k_min, k_max + 1):
+ stride = 2. ** lvl
+ for octave in range(scales_per_octave):
+ octave_scale = 2 ** (octave / float(scales_per_octave))
+ for idx in range(num_aspect_ratios):
+ anchor_sizes = (stride * octave_scale * anchor_scale,)
+ anchor_aspect_ratios = (aspect_ratios[idx],)
+ foa = data_utils.get_field_of_anchors(
+ stride, anchor_sizes, anchor_aspect_ratios, octave, idx)
+ foas.append(foa)
+ all_anchors = np.concatenate([f.field_of_anchors for f in foas])
+
+ blobs['retnet_fg_num'], blobs['retnet_bg_num'] = 0.0, 0.0
+ for im_i, entry in enumerate(roidb):
+ scale = im_scales[im_i]
+ im_height = np.round(entry['height'] * scale)
+ im_width = np.round(entry['width'] * scale)
+ gt_inds = np.where(
+ (entry['gt_classes'] > 0) & (entry['is_crowd'] == 0))[0]
+ assert len(gt_inds) > 0, \
+            'Empty ground truth for image is not allowed. Please check.'
+
+ gt_rois = entry['boxes'][gt_inds, :] * scale
+ gt_classes = entry['gt_classes'][gt_inds]
+
+ im_info = np.array([[im_height, im_width, scale]], dtype=np.float32)
+ blobs['im_info'].append(im_info)
+
+ retinanet_blobs, fg_num, bg_num = _get_retinanet_blobs(
+ foas, all_anchors, gt_rois, gt_classes, image_width, image_height)
+ for i, foa in enumerate(foas):
+ for k, v in retinanet_blobs[i].items():
+ level = int(np.log2(foa.stride))
+ key = '{}_fpn{}'.format(k, level)
+ blobs[key].append(v)
+ blobs['retnet_fg_num'] += fg_num
+ blobs['retnet_bg_num'] += bg_num
+
+ blobs['retnet_fg_num'] = blobs['retnet_fg_num'].astype(np.float32)
+ blobs['retnet_bg_num'] = blobs['retnet_bg_num'].astype(np.float32)
+
+ N = len(roidb)
+ for k, v in blobs.items():
+ if isinstance(v, list) and len(v) > 0:
+ # compute number of anchors
+ A = int(len(v) / N)
+ # for the cls branch labels [per fpn level],
+ # we have blobs['retnet_cls_labels_fpn{}'] as a list until this step
+ # and length of this list is N x A where
+ # N = num_images, A = num_anchors for example, N = 2, A = 9
+ # Each element of the list has the shape 1 x 1 x H x W where H, W are
+            # spatial dimension of current fpn lvl. Let a{i} denote the element
+ # corresponding to anchor i [9 anchors total] in the list.
+ # The elements in the list are in order [[a0, ..., a9], [a0, ..., a9]]
+ # however the network will make predictions like 2 x (9 * 80) x H x W
+ # so we first concatenate the elements of each image to a numpy array
+ # and then concatenate the two images to get the 2 x 9 x H x W
+
+ if k.find('retnet_cls_labels') >= 0 \
+ or k.find('retnet_roi_bbox_targets') >= 0:
+ tmp = []
+ # concat anchors within an image
+ for i in range(0, len(v), A):
+ tmp.append(np.concatenate(v[i: i + A], axis=1))
+ # concat images
+ blobs[k] = np.concatenate(tmp, axis=0)
+ else:
+                # for the bbox branch elements [per FPN level], we have the
+                # targets and the fg box locations in the shape: M x 4, where
+                # M is the number of fg locations in a given image at the
+                # current FPN level. The elements in the list are in order
+                # [[a0, ..., a9], [a0, ..., a9]]; concatenate them to form M x 4.
+ blobs[k] = np.expand_dims(np.concatenate(v, axis=0), axis=0)
+
+ valid_keys = [
+ 'has_visible_keypoints', 'boxes', 'segms', 'seg_areas', 'gt_classes',
+ 'gt_overlaps', 'is_crowd', 'box_to_gt_ind_map', 'gt_keypoints'
+ ]
+ minimal_roidb = [{} for _ in range(len(roidb))]
+ for i, e in enumerate(roidb):
+ for k in valid_keys:
+ if k in e:
+ minimal_roidb[i][k] = e[k]
+ # blobs['roidb'] = blob_utils.serialize(minimal_roidb)
+ blobs['roidb'] = minimal_roidb
+
+ return True
+
+
+def _get_retinanet_blobs(
+ foas, all_anchors, gt_boxes, gt_classes, im_width, im_height):
+ total_anchors = all_anchors.shape[0]
+ logger.debug('Getting mad blobs: im_height {} im_width: {}'.format(
+ im_height, im_width))
+
+ inds_inside = np.arange(all_anchors.shape[0])
+ anchors = all_anchors
+ num_inside = len(inds_inside)
+
+ logger.debug('total_anchors: {}'.format(total_anchors))
+ logger.debug('inds_inside: {}'.format(num_inside))
+ logger.debug('anchors.shape: {}'.format(anchors.shape))
+
+ # Compute anchor labels:
+    # labels >= 1 store the matched gt class (positive), 0 is negative, -1 is don't care (ignore)
+ labels = np.empty((num_inside,), dtype=np.float32)
+ labels.fill(-1)
+ if len(gt_boxes) > 0:
+ # Compute overlaps between the anchors and the gt boxes overlaps
+ anchor_by_gt_overlap = box_utils.bbox_overlaps(anchors, gt_boxes)
+ # Map from anchor to gt box that has highest overlap
+ anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(axis=1)
+ # For each anchor, amount of overlap with most overlapping gt box
+ anchor_to_gt_max = anchor_by_gt_overlap[
+ np.arange(num_inside), anchor_to_gt_argmax]
+
+ # Map from gt box to an anchor that has highest overlap
+ gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(axis=0)
+ # For each gt box, amount of overlap with most overlapping anchor
+ gt_to_anchor_max = anchor_by_gt_overlap[
+ gt_to_anchor_argmax, np.arange(anchor_by_gt_overlap.shape[1])]
+ # Find all anchors that share the max overlap amount
+ # (this includes many ties)
+ anchors_with_max_overlap = np.where(
+ anchor_by_gt_overlap == gt_to_anchor_max)[0]
+
+ # Fg label: for each gt use anchors with highest overlap
+ # (including ties)
+ gt_inds = anchor_to_gt_argmax[anchors_with_max_overlap]
+ labels[anchors_with_max_overlap] = gt_classes[gt_inds]
+ # Fg label: above threshold IOU
+ inds = anchor_to_gt_max >= cfg.RETINANET.POSITIVE_OVERLAP
+ gt_inds = anchor_to_gt_argmax[inds]
+ labels[inds] = gt_classes[gt_inds]
+
+ fg_inds = np.where(labels >= 1)[0]
+ bg_inds = np.where(anchor_to_gt_max < cfg.RETINANET.NEGATIVE_OVERLAP)[0]
+ labels[bg_inds] = 0
+ num_fg, num_bg = len(fg_inds), len(bg_inds)
+
+ bbox_targets = np.zeros((num_inside, 4), dtype=np.float32)
+ bbox_targets[fg_inds, :] = data_utils.compute_targets(
+ anchors[fg_inds, :], gt_boxes[anchor_to_gt_argmax[fg_inds], :])
+
+ # Bbox regression loss has the form:
+ # loss(x) = weight_outside * L(weight_inside * x)
+ # Inside weights allow us to set zero loss on an element-wise basis
+ # Bbox regression is only trained on positive examples so we set their
+ # weights to 1.0 (or otherwise if config is different) and 0 otherwise
+ bbox_inside_weights = np.zeros((num_inside, 4), dtype=np.float32)
+ bbox_inside_weights[labels >= 1, :] = (1.0, 1.0, 1.0, 1.0)
+
+ # Map up to original set of anchors
+ labels = data_utils.unmap(labels, total_anchors, inds_inside, fill=-1)
+ bbox_inside_weights = data_utils.unmap(
+ bbox_inside_weights, total_anchors, inds_inside, fill=0
+ )
+ bbox_targets = data_utils.unmap(
+ bbox_targets, total_anchors, inds_inside, fill=0)
+
+ # Split the generated labels, etc. into labels per each field of anchors
+ blobs_out = []
+ start_idx = 0
+ for i, foa in enumerate(foas):
+ H = foa.field_size
+ W = foa.field_size
+ end_idx = start_idx + H * W
+ _labels = labels[start_idx:end_idx]
+ _bbox_targets = bbox_targets[start_idx:end_idx, :]
+ _bbox_inside_weights = bbox_inside_weights[start_idx:end_idx, :]
+ start_idx = end_idx
+ # labels output with shape (1, height, width)
+ _labels = _labels.reshape((1, 1, H, W))
+ # bbox_targets output with shape (1, 4 * A, height, width)
+ _bbox_targets = _bbox_targets.reshape(
+ (1, 1, H, W, 4)).transpose(0, 1, 4, 2, 3)
+
+ # bbox_inside_weights output with shape (1, 4 * A, height, width)
+ _bbox_inside_weights = _bbox_inside_weights.reshape(
+ (1, H, W, 4)).transpose(0, 3, 1, 2)
+
+ blobs_out.append(
+ dict(
+ retnet_cls_labels=_labels.astype(np.int32),
+ retnet_roi_bbox_targets=_bbox_targets.astype(np.float32),
+ retnet_bbox_inside_weights_wide=_bbox_inside_weights
+ ))
+ out_num_fg = np.array([num_fg + 1.0], dtype=np.float32)
+ out_num_bg = (
+ np.array([num_bg + 1.0]) * (cfg.MODEL.NUM_CLASSES - 1) +
+ out_num_fg * (cfg.MODEL.NUM_CLASSES - 2))
+ return blobs_out, out_num_fg, out_num_bg
diff --git a/lib/utils/net.py b/lib/utils/net.py
index 32c5d705..3f5e7f74 100644
--- a/lib/utils/net.py
+++ b/lib/utils/net.py
@@ -24,7 +24,7 @@ def smooth_l1_loss(bbox_pred, bbox_targets, bbox_inside_weights, bbox_outside_we
abs_in_box_diff = torch.abs(in_box_diff)
smoothL1_sign = (abs_in_box_diff < beta).detach().float()
in_loss_box = smoothL1_sign * 0.5 * torch.pow(in_box_diff, 2) / beta + \
- (1 - smoothL1_sign) * (abs_in_box_diff - (0.5 * beta))
+ (1 - smoothL1_sign) * (abs_in_box_diff - (0.5 * beta))
out_loss_box = bbox_outside_weights * in_loss_box
loss_box = out_loss_box
N = loss_box.size(0) # batch size
@@ -32,6 +32,50 @@ def smooth_l1_loss(bbox_pred, bbox_targets, bbox_inside_weights, bbox_outside_we
return loss_box
+def select_smooth_l1_loss(bbox_pred, bbox_targets, bbox_inside_weights, num_fg, beta=1.0):
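+    # Smooth L1 applied only where bbox_inside_weights is non-zero (foreground
+    # anchor locations), summed and normalized by the total number of fg anchors.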
+ bbox_pred = bbox_pred.reshape\
+ ((bbox_pred.shape[0], bbox_targets.shape[1], 4, bbox_pred.shape[2], bbox_pred.shape[3]))
+ box_diff = bbox_pred - bbox_targets
+ in_box_diff = bbox_inside_weights * box_diff
+ abs_in_box_diff = torch.abs(in_box_diff)
+ smoothL1_sign = (abs_in_box_diff < beta).detach().float()
+ loss_box = smoothL1_sign * 0.5 * torch.pow(in_box_diff, 2) / beta + \
+ (1 - smoothL1_sign) * (abs_in_box_diff - (0.5 * beta))
+ loss_box = loss_box.view(-1).sum(0) / num_fg.sum()
+ return loss_box
+
+
+def sigmoid_focal_loss(cls_preds, cls_targets, num_classes, num_fg, alpha=0.25, gamma=2):
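+    # Sigmoid focal loss (https://arxiv.org/abs/1708.02002):
+    #   FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)
+    # implemented below as a per-element modulating weight on binary
+    # cross-entropy with logits, normalized by the number of fg anchors.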
+ masked_cls_preds = cls_preds.reshape((
+ cls_preds.size(0), cls_targets.size(1), num_classes - 1,
+ cls_preds.size(2), cls_preds.size(3))).permute((0, 1, 3, 4, 2)).contiguous().\
+ view(-1, num_classes-1)
+ masked_cls_targets = cls_targets.view(-1)
+
+ weights = (masked_cls_targets >= 0).float()
+ weights = weights.unsqueeze(1)
+
+ t = masked_cls_preds.data.new(
+ masked_cls_preds.size(0), num_classes).fill_(0)
+ ids = masked_cls_targets.view(-1, 1) % (num_classes)
+ t.scatter_(1, ids.long(), 1.)
+ t = t[:, 1:]
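+    # t now has (num_classes - 1) foreground columns (background column dropped);
+    # ignore labels (-1) contribute no loss because `weights` zeroes them out.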
+
+ p = masked_cls_preds.sigmoid()
+ # w = alpha if t > 0 else 1-alpha
+ alpha_factor = alpha * t + (1 - alpha) * (1 - t)
+ # pt = p if t > 0 else 1-p
+ focal_weight = p * t + (1 - p) * (1 - t)
+ focal_weight = alpha_factor * (1 - focal_weight).pow(gamma)
+
+ cls_loss = focal_weight * \
+ F.binary_cross_entropy_with_logits(
+ masked_cls_preds, t, weight=weights, reduction='none')
+ cls_loss = cls_loss.sum() / num_fg.sum()
+
+ return cls_loss
+
+
def clip_gradient(model, clip_norm):
"""Computes a gradient clipping coefficient based on gradient norm."""
totalnorm = 0
@@ -62,7 +106,9 @@ def decay_learning_rate(optimizer, cur_lr, decay_rate):
if cfg.SOLVER.TYPE in ['SGD']:
if cfg.SOLVER.SCALE_MOMENTUM and cur_lr > 1e-7 and \
ratio > cfg.SOLVER.SCALE_MOMENTUM_THRESHOLD:
- _CorrectMomentum(optimizer, param_group['params'], new_lr / cur_lr)
+ _CorrectMomentum(
+ optimizer, param_group['params'], new_lr / cur_lr)
+
def update_learning_rate(optimizer, cur_lr, new_lr):
"""Update learning rate"""
@@ -119,15 +165,16 @@ def affine_grid_gen(rois, input_size, grid_size):
width = input_size[1]
zero = Variable(rois.data.new(rois.size(0), 1).zero_())
- theta = torch.cat([\
- (x2 - x1) / (width - 1),
- zero,
- (x1 + x2 - width + 1) / (width - 1),
- zero,
- (y2 - y1) / (height - 1),
- (y1 + y2 - height + 1) / (height - 1)], 1).view(-1, 2, 3)
+ theta = torch.cat([
+ (x2 - x1) / (width - 1),
+ zero,
+ (x1 + x2 - width + 1) / (width - 1),
+ zero,
+ (y2 - y1) / (height - 1),
+ (y1 + y2 - height + 1) / (height - 1)], 1).view(-1, 2, 3)
- grid = F.affine_grid(theta, torch.Size((rois.size(0), 1, grid_size, grid_size)))
+ grid = F.affine_grid(theta, torch.Size(
+ (rois.size(0), 1, grid_size, grid_size)))
return grid
@@ -139,7 +186,8 @@ def save_ckpt(output_dir, args, model, optimizer):
ckpt_dir = os.path.join(output_dir, 'ckpt')
if not os.path.exists(ckpt_dir):
os.makedirs(ckpt_dir)
- save_name = os.path.join(ckpt_dir, 'model_{}_{}.pth'.format(args.epoch, args.step))
+ save_name = os.path.join(
+ ckpt_dir, 'model_{}_{}.pth'.format(args.epoch, args.step))
if isinstance(model, mynn.DataParallel):
model = model.module
# TODO: (maybe) Do not save redundant shared params
diff --git a/lib/utils/training_stats.py b/lib/utils/training_stats.py
index 8cae8a18..610c4304 100644
--- a/lib/utils/training_stats.py
+++ b/lib/utils/training_stats.py
@@ -83,24 +83,24 @@ def UpdateIterStats(self, model_out, inner_iter=None):
assert loss.shape[0] == cfg.NUM_GPUS
loss = loss.mean(dim=0, keepdim=True)
total_loss += loss
- loss_data = loss.data[0]
+ loss_data = loss.data[0].item()
model_out['losses'][k] = loss
if cfg.FPN.FPN_ON:
- if k.startswith('loss_rpn_cls_'):
+ if k.startswith('loss_rpn_cls_') or k.startswith('loss_retnet_cls_'):
loss_rpn_cls_data += loss_data
- elif k.startswith('loss_rpn_bbox_'):
+ elif k.startswith('loss_rpn_bbox_') or k.startswith('loss_retnet_bbox_'):
loss_rpn_bbox_data += loss_data
self.smoothed_losses[k].AddValue(loss_data)
model_out['total_loss'] = total_loss # Add the total loss for back propagation
- self.smoothed_total_loss.AddValue(total_loss.data[0])
+ self.smoothed_total_loss.AddValue(total_loss.data[0].item())
if cfg.FPN.FPN_ON:
self.smoothed_losses['loss_rpn_cls'].AddValue(loss_rpn_cls_data)
self.smoothed_losses['loss_rpn_bbox'].AddValue(loss_rpn_bbox_data)
for k, metric in model_out['metrics'].items():
metric = metric.mean(dim=0, keepdim=True)
- self.smoothed_metrics[k].AddValue(metric.data[0])
+ self.smoothed_metrics[k].AddValue(metric.data[0].item())
def _UpdateIterStats_inner(self, model_out, inner_iter):
"""Update tracked iteration statistics for the case of iter_size > 1"""
@@ -125,13 +125,13 @@ def _UpdateIterStats_inner(self, model_out, inner_iter):
assert loss.shape[0] == cfg.NUM_GPUS
loss = loss.mean(dim=0, keepdim=True)
total_loss += loss
- loss_data = loss.data[0]
+ loss_data = loss.data[0].item()
model_out['losses'][k] = loss
if cfg.FPN.FPN_ON:
- if k.startswith('loss_rpn_cls_'):
+ if k.startswith('loss_rpn_cls_') or k.startswith('loss_retnet_cls_'):
loss_rpn_cls_data += loss_data
- elif k.startswith('loss_rpn_bbox_'):
+ elif k.startswith('loss_rpn_bbox_') or k.startswith('loss_retnet_bbox_'):
loss_rpn_bbox_data += loss_data
self.inner_losses[k].append(loss_data)
@@ -140,7 +140,7 @@ def _UpdateIterStats_inner(self, model_out, inner_iter):
self.smoothed_losses[k].AddValue(loss_data)
model_out['total_loss'] = total_loss # Add the total loss for back propagation
- total_loss_data = total_loss.data[0]
+ total_loss_data = total_loss.data[0].item()
self.inner_total_loss.append(total_loss_data)
if cfg.FPN.FPN_ON:
self.inner_loss_rpn_cls.append(loss_rpn_cls_data)
@@ -156,7 +156,7 @@ def _UpdateIterStats_inner(self, model_out, inner_iter):
for k, metric in model_out['metrics'].items():
metric = metric.mean(dim=0, keepdim=True)
- metric_data = metric.data[0]
+ metric_data = metric.data[0].item()
self.inner_metrics[k].append(metric_data)
if inner_iter == (self.misc_args.iter_size - 1):
metric_data = self._mean_and_reset_inner_list('inner_metrics', k)