From 762e3b53ca30e4cb0302a40330a609216236e541 Mon Sep 17 00:00:00 2001 From: Sun Jiahao <72679458+sunjiahao1999@users.noreply.github.com> Date: Thu, 28 Dec 2023 21:34:59 +0800 Subject: [PATCH] [Feature] Support DSVT training (#2738) Co-authored-by: JingweiZhang12 Co-authored-by: sjh --- .../models/dense_heads/centerpoint_head.py | 4 +- mmdet3d/models/necks/second_fpn.py | 18 +- mmdet3d/structures/bbox_3d/base_box3d.py | 13 +- projects/DSVT/README.md | 18 +- ...ecfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py | 146 +++++-- projects/DSVT/dsvt/__init__.py | 5 +- projects/DSVT/dsvt/disable_aug_hook.py | 69 ++++ projects/DSVT/dsvt/dsvt.py | 6 +- projects/DSVT/dsvt/dsvt_head.py | 391 +++++++++++++++++- projects/DSVT/dsvt/dynamic_pillar_vfe.py | 2 + projects/DSVT/dsvt/res_second.py | 19 +- projects/DSVT/dsvt/transforms_3d.py | 116 ++++++ projects/DSVT/dsvt/utils.py | 144 ++++++- tools/train.py | 10 + 14 files changed, 875 insertions(+), 86 deletions(-) create mode 100644 projects/DSVT/dsvt/disable_aug_hook.py create mode 100644 projects/DSVT/dsvt/transforms_3d.py diff --git a/mmdet3d/models/dense_heads/centerpoint_head.py b/mmdet3d/models/dense_heads/centerpoint_head.py index 12ba84234e..c3fc187964 100644 --- a/mmdet3d/models/dense_heads/centerpoint_head.py +++ b/mmdet3d/models/dense_heads/centerpoint_head.py @@ -101,7 +101,7 @@ def forward(self, x): Returns: dict[str: torch.Tensor]: contains the following keys: - -reg (torch.Tensor): 2D regression value with the + -reg (torch.Tensor): 2D regression value with the shape of [B, 2, H, W]. -height (torch.Tensor): Height value with the shape of [B, 1, H, W]. @@ -217,7 +217,7 @@ def forward(self, x): Returns: dict[str: torch.Tensor]: contains the following keys: - -reg (torch.Tensor): 2D regression value with the + -reg (torch.Tensor): 2D regression value with the shape of [B, 2, H, W]. -height (torch.Tensor): Height value with the shape of [B, 1, H, W]. diff --git a/mmdet3d/models/necks/second_fpn.py b/mmdet3d/models/necks/second_fpn.py index 90e57ec05c..d4dc590c15 100644 --- a/mmdet3d/models/necks/second_fpn.py +++ b/mmdet3d/models/necks/second_fpn.py @@ -21,6 +21,10 @@ class SECONDFPN(BaseModule): upsample_cfg (dict): Config dict of upsample layers. conv_cfg (dict): Config dict of conv layers. use_conv_for_no_stride (bool): Whether to use conv when stride is 1. + init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`], + optional): Initialization config dict. Defaults to + [dict(type='Kaiming', layer='ConvTranspose2d'), + dict(type='Constant', layer='NaiveSyncBatchNorm2d', val=1.0)]. """ def __init__(self, @@ -31,7 +35,13 @@ def __init__(self, upsample_cfg=dict(type='deconv', bias=False), conv_cfg=dict(type='Conv2d', bias=False), use_conv_for_no_stride=False, - init_cfg=None): + init_cfg=[ + dict(type='Kaiming', layer='ConvTranspose2d'), + dict( + type='Constant', + layer='NaiveSyncBatchNorm2d', + val=1.0) + ]): # if for GroupNorm, # cfg is dict(type='GN', num_groups=num_groups, eps=1e-3, affine=True) super(SECONDFPN, self).__init__(init_cfg=init_cfg) @@ -64,12 +74,6 @@ def __init__(self, deblocks.append(deblock) self.deblocks = nn.ModuleList(deblocks) - if init_cfg is None: - self.init_cfg = [ - dict(type='Kaiming', layer='ConvTranspose2d'), - dict(type='Constant', layer='NaiveSyncBatchNorm2d', val=1.0) - ] - def forward(self, x): """Forward function. 
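A note on the `SECONDFPN` change above: the default initialization moves out of a post-`__init__` conditional and into the constructor signature, so the default is now visible in the API docs while user-supplied configs are still handled by `BaseModule`. A minimal sketch of the unchanged user-facing behavior (channel values and the checkpoint path are illustrative, not from the PR):

```python
from mmdet3d.models.necks import SECONDFPN

# No init_cfg passed: the new signature default applies, i.e. Kaiming init
# for ConvTranspose2d layers and constant 1.0 for NaiveSyncBatchNorm2d.
neck = SECONDFPN(
    in_channels=[128, 128, 256],
    out_channels=[128, 128, 128],
    upsample_strides=[1, 2, 4])

# An explicit init_cfg still overrides the default, as before
# ('path/to/ckpt.pth' is a placeholder).
custom = SECONDFPN(
    in_channels=[128, 128, 256],
    out_channels=[128, 128, 128],
    upsample_strides=[1, 2, 4],
    init_cfg=dict(type='Pretrained', checkpoint='path/to/ckpt.pth'))
```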
diff --git a/mmdet3d/structures/bbox_3d/base_box3d.py b/mmdet3d/structures/bbox_3d/base_box3d.py
index 50b092c06e..7fb703c731 100644
--- a/mmdet3d/structures/bbox_3d/base_box3d.py
+++ b/mmdet3d/structures/bbox_3d/base_box3d.py
@@ -275,12 +275,13 @@ def in_range_3d(
            Tensor: A binary vector indicating whether each point is inside
                the reference range.
        """
-        in_range_flags = ((self.tensor[:, 0] > box_range[0])
-                          & (self.tensor[:, 1] > box_range[1])
-                          & (self.tensor[:, 2] > box_range[2])
-                          & (self.tensor[:, 0] < box_range[3])
-                          & (self.tensor[:, 1] < box_range[4])
-                          & (self.tensor[:, 2] < box_range[5]))
+        gravity_center = self.gravity_center
+        in_range_flags = ((gravity_center[:, 0] > box_range[0])
+                          & (gravity_center[:, 1] > box_range[1])
+                          & (gravity_center[:, 2] > box_range[2])
+                          & (gravity_center[:, 0] < box_range[3])
+                          & (gravity_center[:, 1] < box_range[4])
+                          & (gravity_center[:, 2] < box_range[5]))
        return in_range_flags

    @abstractmethod
diff --git a/projects/DSVT/README.md b/projects/DSVT/README.md
index a4b45b570d..d60e49abe5 100644
--- a/projects/DSVT/README.md
+++ b/projects/DSVT/README.md
@@ -57,17 +57,25 @@ python tools/test.py projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-

 ### Training commands

-The support of training DSVT is on the way.
+In MMDetection3D's root directory, run the following command to train the model:
+
+```bash
+tools/dist_train.sh projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py 8 --sync_bn torch
+```

 ## Results and models

 ### Waymo

-| Middle Encoder | Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | Mem (GB) | Inf time (fps) | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
-| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
-| [DSVT](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | [ResSECOND](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | 5 | voxel (0.32) | ✓ | × | | | 75.2 | 72.2 | 68.9 | 66.1 | |
+| Middle Encoder | Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
+| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
+| [DSVT](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | [ResSECOND](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | 5 | voxel (0.32) | ✓ | × | 75.5 | 72.4 | 69.2 | 66.3 | [log](
 dict:
        """Forward function for CenterPoint.
@@ -66,7 +98,298 @@ def loss(self, pts_feats: List[Tensor],

        Returns:
            dict: Losses of each branch.
        """
-        pass
+        outs = self(pts_feats)
+        batch_gt_instance_3d = []
+        for data_sample in batch_data_samples:
+            batch_gt_instance_3d.append(data_sample.gt_instances_3d)
+        losses = self.loss_by_feat(outs, batch_gt_instance_3d)
+        return losses
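Returning to the `in_range_3d` change above: the filter now tests each box's gravity center rather than its bottom center (`tensor[:, :3]`), so a box whose bottom sits just outside the z range is kept as long as its 3D center lies inside. A minimal sketch of the new behavior (box and range values are illustrative, not from the PR):

```python
import torch

from mmdet3d.structures import LiDARInstance3DBoxes

# One box as (x, y, z, l, w, h, yaw) with bottom center at z = -0.1;
# its gravity center sits at z = -0.1 + 1.8 / 2 = 0.8.
boxes = LiDARInstance3DBoxes(
    torch.tensor([[10.0, 2.0, -0.1, 4.0, 1.9, 1.8, 0.0]]))

# The range's z interval starts at 0.0: the bottom center fails the old
# tensor-based check, but the gravity center passes the new one.
mask = boxes.in_range_3d([0.0, -40.0, 0.0, 70.0, 40.0, 2.0])
print(mask)  # tensor([True]) under the gravity-center based check
```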
""" - pass + outs = self(pts_feats) + batch_gt_instance_3d = [] + for data_sample in batch_data_samples: + batch_gt_instance_3d.append(data_sample.gt_instances_3d) + losses = self.loss_by_feat(outs, batch_gt_instance_3d) + return losses + + def _decode_all_preds(self, + pred_dict, + point_cloud_range=None, + voxel_size=None): + batch_size, _, H, W = pred_dict['reg'].shape + + batch_center = pred_dict['reg'].permute(0, 2, 3, 1).contiguous().view( + batch_size, H * W, 2) # (B, H, W, 2) + batch_center_z = pred_dict['height'].permute( + 0, 2, 3, 1).contiguous().view(batch_size, H * W, 1) # (B, H, W, 1) + batch_dim = pred_dict['dim'].exp().permute( + 0, 2, 3, 1).contiguous().view(batch_size, H * W, 3) # (B, H, W, 3) + batch_rot_cos = pred_dict['rot'][:, 0].unsqueeze(dim=1).permute( + 0, 2, 3, 1).contiguous().view(batch_size, H * W, 1) # (B, H, W, 1) + batch_rot_sin = pred_dict['rot'][:, 1].unsqueeze(dim=1).permute( + 0, 2, 3, 1).contiguous().view(batch_size, H * W, 1) # (B, H, W, 1) + batch_vel = pred_dict['vel'].permute(0, 2, 3, 1).contiguous().view( + batch_size, H * W, 2) if 'vel' in pred_dict.keys() else None + + angle = torch.atan2(batch_rot_sin, batch_rot_cos) # (B, H*W, 1) + + ys, xs = torch.meshgrid([ + torch.arange( + 0, H, device=batch_center.device, dtype=batch_center.dtype), + torch.arange( + 0, W, device=batch_center.device, dtype=batch_center.dtype) + ]) + ys = ys.view(1, H, W).repeat(batch_size, 1, 1) + xs = xs.view(1, H, W).repeat(batch_size, 1, 1) + xs = xs.view(batch_size, -1, 1) + batch_center[:, :, 0:1] + ys = ys.view(batch_size, -1, 1) + batch_center[:, :, 1:2] + + xs = xs * voxel_size[0] + point_cloud_range[0] + ys = ys * voxel_size[1] + point_cloud_range[1] + + box_part_list = [xs, ys, batch_center_z, batch_dim, angle] + if batch_vel is not None: + box_part_list.append(batch_vel) + + box_preds = torch.cat((box_part_list), + dim=-1).view(batch_size, H, W, -1) + + return box_preds + + def _transpose_and_gather_feat(self, feat, ind): + feat = feat.permute(0, 2, 3, 1).contiguous() + feat = feat.view(feat.size(0), -1, feat.size(3)) + feat = self._gather_feat(feat, ind) + return feat + + def calc_iou_loss(self, iou_preds, batch_box_preds, mask, ind, gt_boxes): + """ + Args: + iou_preds: (batch x 1 x h x w) + batch_box_preds: (batch x (7 or 9) x h x w) + mask: (batch x max_objects) + ind: (batch x max_objects) + gt_boxes: List of batch groundtruth boxes. + + Returns: + Tensor: IoU Loss. 
+        """
+        if mask.sum() == 0:
+            return iou_preds.new_zeros((1))
+
+        mask = mask.bool()
+        selected_iou_preds = self._transpose_and_gather_feat(iou_preds,
+                                                             ind)[mask]
+
+        selected_box_preds = self._transpose_and_gather_feat(
+            batch_box_preds, ind)[mask]
+        gt_boxes = torch.cat(gt_boxes)
+        assert gt_boxes.size(0) == selected_box_preds.size(0)
+        iou_target = boxes_iou3d(selected_box_preds[:, 0:7], gt_boxes[:, 0:7])
+        iou_target = torch.diag(iou_target).view(-1)
+        iou_target = iou_target * 2 - 1  # [0, 1] ==> [-1, 1]
+
+        loss = self.loss_iou(selected_iou_preds.view(-1), iou_target)
+        loss = loss / torch.clamp(mask.sum(), min=1e-4)
+        return loss
+
+    def calc_iou_reg_loss(self, batch_box_preds, mask, ind, gt_boxes):
+        if mask.sum() == 0:
+            return batch_box_preds.new_zeros((1))
+
+        mask = mask.bool()
+
+        selected_box_preds = self._transpose_and_gather_feat(
+            batch_box_preds, ind)[mask]
+        gt_boxes = torch.cat(gt_boxes)
+        assert gt_boxes.size(0) == selected_box_preds.size(0)
+        loss = self.loss_iou_reg(selected_box_preds[:, 0:7], gt_boxes[:, 0:7])
+
+        return loss
+
+    def get_targets(
+        self,
+        batch_gt_instances_3d: List[InstanceData],
+    ) -> Tuple[List[Tensor]]:
+        """Generate targets.
+
+        How each output is transformed:
+
+        Each nested list is transposed so that all same-index elements in
+        each sub-list (1, ..., N) become the new sub-lists.
+            [ [a0, a1, a2, ... ], [b0, b1, b2, ... ], ... ]
+            ==> [ [a0, b0, ... ], [a1, b1, ... ], [a2, b2, ... ] ]
+
+        The new transposed nested list is converted into a list of N
+        tensors generated by concatenating tensors in the new sub-lists.
+            [ tensor0, tensor1, tensor2, ... ]
+
+        Args:
+            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
+                gt_instances. It usually includes ``bboxes_3d`` and
+                ``labels_3d`` attributes.
+
+        Returns:
+            tuple[list[torch.Tensor]]: Tuple of target including
+                the following results in order.
+
+            - list[torch.Tensor]: Heatmap scores.
+            - list[torch.Tensor]: Ground truth boxes.
+            - list[torch.Tensor]: Indexes indicating the
+                position of the valid boxes.
+            - list[torch.Tensor]: Masks indicating which
+                boxes are valid.
+        """
+        heatmaps, anno_boxes, inds, masks, task_gt_bboxes = multi_apply(
+            self.get_targets_single, batch_gt_instances_3d)
+        # Transpose heatmaps
+        heatmaps = list(map(list, zip(*heatmaps)))
+        heatmaps = [torch.stack(hms_) for hms_ in heatmaps]
+        # Transpose anno_boxes
+        anno_boxes = list(map(list, zip(*anno_boxes)))
+        anno_boxes = [torch.stack(anno_boxes_) for anno_boxes_ in anno_boxes]
+        # Transpose inds
+        inds = list(map(list, zip(*inds)))
+        inds = [torch.stack(inds_) for inds_ in inds]
+        # Transpose masks
+        masks = list(map(list, zip(*masks)))
+        masks = [torch.stack(masks_) for masks_ in masks]
+        # Transpose task_gt_bboxes
+        task_gt_bboxes = list(map(list, zip(*task_gt_bboxes)))
+        return heatmaps, anno_boxes, inds, masks, task_gt_bboxes
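The transpose idiom in `get_targets` is worth spelling out: `multi_apply` returns per-sample lists of per-task tensors, and `zip(*...)` flips the nesting to per-task lists of per-sample tensors before stacking along a new batch dimension. A standalone sketch (toy shapes, not from the PR):

```python
import torch

# Two samples, each producing heatmaps for three task heads.
per_sample = [
    [torch.zeros(1, 4, 4), torch.zeros(2, 4, 4), torch.zeros(1, 4, 4)],
    [torch.ones(1, 4, 4), torch.ones(2, 4, 4), torch.ones(1, 4, 4)],
]

# Flip the nesting: sample -> task becomes task -> sample.
per_task = list(map(list, zip(*per_sample)))

# Stack each task's tensors along a new batch dimension.
batched = [torch.stack(task) for task in per_task]
print([t.shape for t in batched])
# [torch.Size([2, 1, 4, 4]), torch.Size([2, 2, 4, 4]), torch.Size([2, 1, 4, 4])]
```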
+    def get_targets_single(self,
+                           gt_instances_3d: InstanceData) -> Tuple[Tensor]:
+        """Generate training targets for a single sample.
+
+        Args:
+            gt_instances_3d (:obj:`InstanceData`): Gt_instances_3d of
+                single data sample. It usually includes
+                ``bboxes_3d`` and ``labels_3d`` attributes.
+
+        Returns:
+            tuple[list[torch.Tensor]]: Tuple of target including
+                the following results in order.
+
+            - list[torch.Tensor]: Heatmap scores.
+            - list[torch.Tensor]: Ground truth boxes.
+            - list[torch.Tensor]: Indexes indicating the position
+                of the valid boxes.
+            - list[torch.Tensor]: Masks indicating which boxes
+                are valid.
+        """
+        gt_labels_3d = gt_instances_3d.labels_3d
+        gt_bboxes_3d = gt_instances_3d.bboxes_3d
+        device = gt_labels_3d.device
+        gt_bboxes_3d = torch.cat(
+            (gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]),
+            dim=1).to(device)
+        max_objs = self.train_cfg['max_objs'] * self.train_cfg['dense_reg']
+        grid_size = torch.tensor(self.train_cfg['grid_size']).to(device)
+        pc_range = torch.tensor(self.train_cfg['point_cloud_range'])
+        voxel_size = torch.tensor(self.train_cfg['voxel_size'])
+
+        feature_map_size = grid_size[:2] // self.train_cfg['out_size_factor']
+
+        # reorganize the gt_dict by tasks
+        task_masks = []
+        flag = 0
+        for class_name in self.class_names:
+            task_masks.append([
+                torch.where(gt_labels_3d == class_name.index(i) + flag)
+                for i in class_name
+            ])
+            flag += len(class_name)
+
+        task_boxes = []
+        task_classes = []
+        flag2 = 0
+        for idx, mask in enumerate(task_masks):
+            task_box = []
+            task_class = []
+            for m in mask:
+                task_box.append(gt_bboxes_3d[m])
+                # 0 is background for each task, so we need to add 1 here.
+                task_class.append(gt_labels_3d[m] + 1 - flag2)
+            task_boxes.append(torch.cat(task_box, axis=0).to(device))
+            task_classes.append(torch.cat(task_class).long().to(device))
+            flag2 += len(mask)
+        draw_gaussian = draw_heatmap_gaussian
+        heatmaps, anno_boxes, inds, masks = [], [], [], []
+
+        for idx, task_head in enumerate(self.task_heads):
+            heatmap = gt_bboxes_3d.new_zeros(
+                (len(self.class_names[idx]), feature_map_size[1],
+                 feature_map_size[0]))
+
+            anno_box = gt_bboxes_3d.new_zeros((max_objs, 8),
+                                              dtype=torch.float32)
+
+            ind = gt_labels_3d.new_zeros((max_objs), dtype=torch.int64)
+            mask = gt_bboxes_3d.new_zeros((max_objs), dtype=torch.uint8)
+
+            num_objs = min(task_boxes[idx].shape[0], max_objs)
+
+            for k in range(num_objs):
+                cls_id = task_classes[idx][k] - 1
+
+                length = task_boxes[idx][k][3]
+                width = task_boxes[idx][k][4]
+                length = length / voxel_size[0] / self.train_cfg[
+                    'out_size_factor']
+                width = width / voxel_size[1] / self.train_cfg[
+                    'out_size_factor']
+
+                if width > 0 and length > 0:
+                    radius = gaussian_radius(
+                        (width, length),
+                        min_overlap=self.train_cfg['gaussian_overlap'])
+                    radius = max(self.train_cfg['min_radius'], int(radius))
+
+                    # be really careful about the coordinate system of
+                    # your box annotation.
+ x, y, z = task_boxes[idx][k][0], task_boxes[idx][k][ + 1], task_boxes[idx][k][2] + + coor_x = ( + x - pc_range[0] + ) / voxel_size[0] / self.train_cfg['out_size_factor'] + coor_y = ( + y - pc_range[1] + ) / voxel_size[1] / self.train_cfg['out_size_factor'] + + center = torch.tensor([coor_x, coor_y], + dtype=torch.float32, + device=device) + center_int = center.to(torch.int32) + + # throw out not in range objects to avoid out of array + # area when creating the heatmap + if not (0 <= center_int[0] < feature_map_size[0] + and 0 <= center_int[1] < feature_map_size[1]): + continue + + draw_gaussian(heatmap[cls_id], center_int, radius) + + new_idx = k + x, y = center_int[0], center_int[1] + + assert (y * feature_map_size[0] + x < + feature_map_size[0] * feature_map_size[1]) + + ind[new_idx] = y * feature_map_size[0] + x + mask[new_idx] = 1 + # TODO: support other outdoor dataset + rot = task_boxes[idx][k][6] + box_dim = task_boxes[idx][k][3:6] + if self.norm_bbox: + box_dim = box_dim.log() + anno_box[new_idx] = torch.cat([ + center - torch.tensor([x, y], device=device), + z.unsqueeze(0), box_dim, + torch.cos(rot).unsqueeze(0), + torch.sin(rot).unsqueeze(0) + ]) + + heatmaps.append(heatmap) + anno_boxes.append(anno_box) + masks.append(mask) + inds.append(ind) + return heatmaps, anno_boxes, inds, masks, task_boxes def loss_by_feat(self, preds_dicts: Tuple[List[dict]], batch_gt_instances_3d: List[InstanceData], *args, @@ -79,13 +402,72 @@ def loss_by_feat(self, preds_dicts: Tuple[List[dict]], tasks head, and the internal list indicate different FPN level. batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of - gt_instances. It usually includes ``bboxes_3d`` and\ + gt_instances_3d. It usually includes ``bboxes_3d`` and ``labels_3d`` attributes. Returns: dict[str,torch.Tensor]: Loss of heatmap and bbox of each task. 
""" - pass + heatmaps, anno_boxes, inds, masks, task_gt_bboxes = self.get_targets( + batch_gt_instances_3d) + loss_dict = dict() + for task_id, preds_dict in enumerate(preds_dicts): + # heatmap focal loss + preds_dict[0]['heatmap'] = clip_sigmoid(preds_dict[0]['heatmap']) + num_pos = heatmaps[task_id].eq(1).float().sum().item() + loss_heatmap = self.loss_cls( + preds_dict[0]['heatmap'], + heatmaps[task_id], + avg_factor=max(num_pos, 1)) + target_box = anno_boxes[task_id] + # reconstruct the anno_box from multiple reg heads + preds_dict[0]['anno_box'] = torch.cat( + (preds_dict[0]['reg'], preds_dict[0]['height'], + preds_dict[0]['dim'], preds_dict[0]['rot']), + dim=1) + + # Regression loss for dimension, offset, height, rotation + ind = inds[task_id] + num = masks[task_id].float().sum() + pred = preds_dict[0]['anno_box'].permute(0, 2, 3, 1).contiguous() + pred = pred.view(pred.size(0), -1, pred.size(3)) + pred = self._gather_feat(pred, ind) + mask = masks[task_id].unsqueeze(2).expand_as(target_box).float() + isnotnan = (~torch.isnan(target_box)).float() + mask *= isnotnan + + code_weights = self.train_cfg.get('code_weights', None) + bbox_weights = mask * mask.new_tensor(code_weights) + loss_bbox = self.loss_bbox( + pred, target_box, bbox_weights, avg_factor=(num + 1e-4)) + loss_dict[f'task{task_id}.loss_heatmap'] = loss_heatmap + loss_dict[f'task{task_id}.loss_bbox'] = loss_bbox + + if 'iou' in preds_dict[0]: + batch_box_preds = self._decode_all_preds( + pred_dict=preds_dict[0], + point_cloud_range=self.train_cfg['point_cloud_range'], + voxel_size=self.train_cfg['voxel_size'] + ) # (B, H, W, 7 or 9) + + batch_box_preds_for_iou = batch_box_preds.permute( + 0, 3, 1, 2) # (B, 7 or 9, H, W) + loss_dict[f'task{task_id}.loss_iou'] = self.calc_iou_loss( + iou_preds=preds_dict[0]['iou'], + batch_box_preds=batch_box_preds_for_iou.clone().detach(), + mask=masks[task_id], + ind=ind, + gt_boxes=task_gt_bboxes[task_id]) + + if self.loss_iou_reg is not None: + loss_dict[f'task{task_id}.loss_reg_iou'] = \ + self.calc_iou_reg_loss( + batch_box_preds=batch_box_preds_for_iou, + mask=masks[task_id], + ind=ind, + gt_boxes=task_gt_bboxes[task_id]) + + return loss_dict def predict(self, pts_feats: Tuple[torch.Tensor], @@ -158,6 +540,7 @@ def predict_by_feat(self, preds_dicts: Tuple[List[dict]], else: batch_dim = preds_dict[0]['dim'] + # It's different from CenterHead batch_rotc = preds_dict[0]['rot'][:, 0].unsqueeze(1) batch_rots = preds_dict[0]['rot'][:, 1].unsqueeze(1) batch_iou = (preds_dict[0]['iou'] + diff --git a/projects/DSVT/dsvt/dynamic_pillar_vfe.py b/projects/DSVT/dsvt/dynamic_pillar_vfe.py index 97c75aaf00..3fc5266c56 100644 --- a/projects/DSVT/dsvt/dynamic_pillar_vfe.py +++ b/projects/DSVT/dsvt/dynamic_pillar_vfe.py @@ -1,4 +1,5 @@ # modified from https://github.com/Haiyang-W/DSVT +import numpy as np import torch import torch.nn as nn import torch_scatter @@ -76,6 +77,7 @@ def __init__(self, with_distance, use_absolute_xyz, use_norm, num_filters, self.voxel_x = voxel_size[0] self.voxel_y = voxel_size[1] self.voxel_z = voxel_size[2] + point_cloud_range = np.array(point_cloud_range).astype(np.float32) self.x_offset = self.voxel_x / 2 + point_cloud_range[0] self.y_offset = self.voxel_y / 2 + point_cloud_range[1] self.z_offset = self.voxel_z / 2 + point_cloud_range[2] diff --git a/projects/DSVT/dsvt/res_second.py b/projects/DSVT/dsvt/res_second.py index e1ddc1be6c..f8775e34e8 100644 --- a/projects/DSVT/dsvt/res_second.py +++ b/projects/DSVT/dsvt/res_second.py @@ -1,7 +1,5 @@ # modified from 
https://github.com/Haiyang-W/DSVT - -import warnings -from typing import Optional, Sequence, Tuple +from typing import Sequence, Tuple from mmengine.model import BaseModule from torch import Tensor @@ -78,8 +76,8 @@ class ResSECOND(BaseModule): out_channels (list[int]): Output channels for multi-scale feature maps. blocks_nums (list[int]): Number of blocks in each stage. layer_strides (list[int]): Strides of each stage. - norm_cfg (dict): Config dict of normalization layers. - conv_cfg (dict): Config dict of convolutional layers. + init_cfg (dict, optional): Config for weight initialization. + Defaults to None. """ def __init__(self, @@ -87,8 +85,7 @@ def __init__(self, out_channels: Sequence[int] = [128, 128, 256], blocks_nums: Sequence[int] = [1, 2, 2], layer_strides: Sequence[int] = [2, 2, 2], - init_cfg: OptMultiConfig = None, - pretrained: Optional[str] = None) -> None: + init_cfg: OptMultiConfig = None) -> None: super(ResSECOND, self).__init__(init_cfg=init_cfg) assert len(layer_strides) == len(blocks_nums) assert len(out_channels) == len(blocks_nums) @@ -108,14 +105,6 @@ def __init__(self, BasicResBlock(out_channels[i], out_channels[i])) blocks.append(nn.Sequential(*cur_layers)) self.blocks = nn.Sequential(*blocks) - assert not (init_cfg and pretrained), \ - 'init_cfg and pretrained cannot be setting at the same time' - if isinstance(pretrained, str): - warnings.warn('DeprecationWarning: pretrained is a deprecated, ' - 'please use "init_cfg" instead') - self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) - else: - self.init_cfg = dict(type='Kaiming', layer='Conv2d') def forward(self, x: Tensor) -> Tuple[Tensor, ...]: """Forward function. diff --git a/projects/DSVT/dsvt/transforms_3d.py b/projects/DSVT/dsvt/transforms_3d.py new file mode 100644 index 0000000000..ff0c9a2314 --- /dev/null +++ b/projects/DSVT/dsvt/transforms_3d.py @@ -0,0 +1,116 @@ +from typing import List + +import numpy as np +from mmcv import BaseTransform + +from mmdet3d.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class ObjectRangeFilter3D(BaseTransform): + """Filter objects by the range. It differs from `ObjectRangeFilter` by + using `in_range_3d` instead of `in_range_bev`. + + Required Keys: + + - gt_bboxes_3d + + Modified Keys: + + - gt_bboxes_3d + + Args: + point_cloud_range (list[float]): Point cloud range. + """ + + def __init__(self, point_cloud_range: List[float]) -> None: + self.pcd_range = np.array(point_cloud_range, dtype=np.float32) + + def transform(self, input_dict: dict) -> dict: + """Transform function to filter objects by the range. + + Args: + input_dict (dict): Result dict from loading pipeline. + + Returns: + dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d' + keys are updated in the result dict. 
+        """
+        gt_bboxes_3d = input_dict['gt_bboxes_3d']
+        gt_labels_3d = input_dict['gt_labels_3d']
+        mask = gt_bboxes_3d.in_range_3d(self.pcd_range)
+        gt_bboxes_3d = gt_bboxes_3d[mask]
+        # mask is a torch tensor but gt_labels_3d is still numpy array
+        # using mask to index gt_labels_3d will cause bug when
+        # len(gt_labels_3d) == 1, where mask=1 will be interpreted
+        # as gt_labels_3d[1] and cause out of index error
+        gt_labels_3d = gt_labels_3d[mask.numpy().astype(bool)]
+
+        # limit rad to [-pi, pi]
+        gt_bboxes_3d.limit_yaw(offset=0.5, period=2 * np.pi)
+        input_dict['gt_bboxes_3d'] = gt_bboxes_3d
+        input_dict['gt_labels_3d'] = gt_labels_3d
+
+        return input_dict
+
+    def __repr__(self) -> str:
+        """str: Return a string that describes the module."""
+        repr_str = self.__class__.__name__
+        repr_str += f'(point_cloud_range={self.pcd_range.tolist()})'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class PointsRangeFilter3D(BaseTransform):
+    """Filter points by the range. It differs from `PointsRangeFilter` by
+    using `in_range_bev` instead of `in_range_3d`.
+
+    Required Keys:
+
+    - points
+    - pts_instance_mask (optional)
+
+    Modified Keys:
+
+    - points
+    - pts_instance_mask (optional)
+
+    Args:
+        point_cloud_range (list[float]): Point cloud range.
+    """
+
+    def __init__(self, point_cloud_range: List[float]) -> None:
+        self.pcd_range = np.array(point_cloud_range, dtype=np.float32)
+
+    def transform(self, input_dict: dict) -> dict:
+        """Transform function to filter points by the range.
+
+        Args:
+            input_dict (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Results after filtering, 'points', 'pts_instance_mask'
+                and 'pts_semantic_mask' keys are updated in the result dict.
+        """
+        points = input_dict['points']
+        points_mask = points.in_range_bev(self.pcd_range[[0, 1, 3, 4]])
+        clean_points = points[points_mask]
+        input_dict['points'] = clean_points
+        points_mask = points_mask.numpy()
+
+        pts_instance_mask = input_dict.get('pts_instance_mask', None)
+        pts_semantic_mask = input_dict.get('pts_semantic_mask', None)
+
+        if pts_instance_mask is not None:
+            input_dict['pts_instance_mask'] = pts_instance_mask[points_mask]
+
+        if pts_semantic_mask is not None:
+            input_dict['pts_semantic_mask'] = pts_semantic_mask[points_mask]
+
+        return input_dict
+
+    def __repr__(self) -> str:
+        """str: Return a string that describes the module."""
+        repr_str = self.__class__.__name__
+        repr_str += f'(point_cloud_range={self.pcd_range.tolist()})'
+        return repr_str
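Since both filters are registered in `TRANSFORMS`, they drop into a data pipeline config exactly like the stock filters they replace. A hedged sketch of how a DSVT-style train pipeline might reference them (the surrounding entries and the range values are illustrative, not copied from the PR's config):

```python
point_cloud_range = [-74.88, -74.88, -2.0, 74.88, 74.88, 4.0]

train_pipeline = [
    # ... loading and augmentation transforms ...
    # BEV-based point filtering; keeps the full z extent of each point.
    dict(type='PointsRangeFilter3D', point_cloud_range=point_cloud_range),
    # Gravity-center-based box filtering via the patched in_range_3d.
    dict(type='ObjectRangeFilter3D', point_cloud_range=point_cloud_range),
    # ... formatting transforms ...
]
```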
diff --git a/projects/DSVT/dsvt/utils.py b/projects/DSVT/dsvt/utils.py
index 7c40383ce7..706ee04280 100644
--- a/projects/DSVT/dsvt/utils.py
+++ b/projects/DSVT/dsvt/utils.py
@@ -3,10 +3,11 @@
 import numpy as np
 import torch
 import torch.nn as nn
+from mmdet.models.losses.utils import weighted_loss
 from torch import Tensor

 from mmdet3d.models.task_modules import CenterPointBBoxCoder
-from mmdet3d.registry import TASK_UTILS
+from mmdet3d.registry import MODELS, TASK_UTILS
 from .ops.ingroup_inds.ingroup_inds_op import ingroup_inds

 get_inner_win_inds_cuda = ingroup_inds
@@ -266,7 +267,7 @@ def decode(self,
        thresh_mask = final_scores > self.score_threshold

        if self.post_center_range is not None:
-            self.post_center_range = torch.as_tensor(
+            self.post_center_range = torch.as_tensor(
                self.post_center_range, device=heat.device)
            mask = (final_box_preds[..., :3] >=
                    self.post_center_range[:3]).all(2)
@@ -298,3 +299,142 @@ def decode(self,
            'support post_center_range is not None for now!')

    return predictions_dicts
+
+
+def center_to_corner2d(center, dim):
+    corners_norm = torch.tensor(
+        [[-0.5, -0.5], [-0.5, 0.5], [0.5, 0.5], [0.5, -0.5]],
+        device=dim.device).type_as(center)  # (4, 2)
+    corners = dim.view([-1, 1, 2]) * corners_norm.view([1, 4, 2])  # (N, 4, 2)
+    corners = corners + center.view(-1, 1, 2)
+    return corners
+
+
+@weighted_loss
+def diou3d_loss(pred_boxes, gt_boxes, eps: float = 1e-7):
+    """Calculate the Distance-IoU loss between two sets of 3D boxes.
+
+    Modified from https://github.com/agent-sgs/PillarNet/blob/master/det3d/core/utils/center_utils.py  # noqa
+
+    Args:
+        pred_boxes (Tensor): Predicted boxes with shape (N, 7).
+        gt_boxes (Tensor): Ground truth boxes with shape (N, 7).
+        eps (float): Epsilon to avoid division by zero. Defaults to 1e-7.
+
+    Returns:
+        Tensor: Distance-IoU Loss.
+    """
+    assert pred_boxes.shape[0] == gt_boxes.shape[0]
+
+    qcorners = center_to_corner2d(pred_boxes[:, :2],
+                                  pred_boxes[:, 3:5])  # (N, 4, 2)
+    gcorners = center_to_corner2d(gt_boxes[:, :2],
+                                  gt_boxes[:, 3:5])  # (N, 4, 2)
+
+    inter_max_xy = torch.minimum(qcorners[:, 2], gcorners[:, 2])
+    inter_min_xy = torch.maximum(qcorners[:, 0], gcorners[:, 0])
+    out_max_xy = torch.maximum(qcorners[:, 2], gcorners[:, 2])
+    out_min_xy = torch.minimum(qcorners[:, 0], gcorners[:, 0])
+
+    # calculate area
+    volume_pred_boxes = pred_boxes[:, 3] * pred_boxes[:, 4] * pred_boxes[:, 5]
+    volume_gt_boxes = gt_boxes[:, 3] * gt_boxes[:, 4] * gt_boxes[:, 5]
+
+    inter_h = torch.minimum(
+        pred_boxes[:, 2] + 0.5 * pred_boxes[:, 5],
+        gt_boxes[:, 2] + 0.5 * gt_boxes[:, 5]) - torch.maximum(
+            pred_boxes[:, 2] - 0.5 * pred_boxes[:, 5],
+            gt_boxes[:, 2] - 0.5 * gt_boxes[:, 5])
+    inter_h = torch.clamp(inter_h, min=0)
+
+    inter = torch.clamp((inter_max_xy - inter_min_xy), min=0)
+    volume_inter = inter[:, 0] * inter[:, 1] * inter_h
+    volume_union = volume_gt_boxes + volume_pred_boxes - volume_inter + eps
+
+    # squared center distance between pred and gt boxes
+    inter_diag = torch.pow(gt_boxes[:, 0:3] - pred_boxes[:, 0:3], 2).sum(-1)
+
+    outer_h = torch.maximum(
+        gt_boxes[:, 2] + 0.5 * gt_boxes[:, 5],
+        pred_boxes[:, 2] + 0.5 * pred_boxes[:, 5]) - torch.minimum(
+            gt_boxes[:, 2] - 0.5 * gt_boxes[:, 5],
+            pred_boxes[:, 2] - 0.5 * pred_boxes[:, 5])
+    outer_h = torch.clamp(outer_h, min=0)
+    outer = torch.clamp((out_max_xy - out_min_xy), min=0)
+    outer_diag = outer[:, 0]**2 + outer[:, 1]**2 + outer_h**2 + eps
+
+    dious = volume_inter / volume_union - inter_diag / outer_diag
+    dious = torch.clamp(dious, min=-1.0, max=1.0)
+
+    loss = 1 - dious
+
+    return loss
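As a sanity check on the math: for identical boxes the IoU term is 1 and the center-distance penalty is 0, so the loss is about 0; shifting the prediction both lowers the IoU and adds the distance penalty. A toy check (the import path is assumed from the file location; note this BEV-corner DIoU ignores yaw):

```python
import torch

# Hypothetical import path; adjust to wherever the project utils live.
from projects.DSVT.dsvt.utils import diou3d_loss

# Boxes as (x, y, z, l, w, h, yaw); yaw is ignored by this axis-aligned DIoU.
gt = torch.tensor([[0.0, 0.0, 0.0, 4.0, 2.0, 1.5, 0.0]])

# A perfect prediction: IoU = 1, center distance = 0 -> loss close to 0.
print(diou3d_loss(gt.clone(), gt, reduction='none'))  # ~tensor([0.])

# Shift the prediction 2 m in x: IoU drops to 1/3 and the center-distance
# penalty kicks in, so the loss grows to roughly 0.76.
pred = gt.clone()
pred[0, 0] = 2.0
print(diou3d_loss(pred, gt, reduction='none'))
```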
+
+
+@MODELS.register_module()
+class DIoU3DLoss(nn.Module):
+    r"""3D bounding box implementation of `Distance-IoU Loss: Faster and
+    Better Learning for Bounding Box Regression
+    <https://arxiv.org/abs/1911.08287>`_.
+
+    Code is modified from https://github.com/Zzh-tju/DIoU.
+
+    Args:
+        eps (float): Epsilon to avoid division by zero. Defaults to 1e-6.
+        reduction (str): Options are "none", "mean" and "sum".
+            Defaults to "mean".
+        loss_weight (float): Weight of loss. Defaults to 1.0.
+    """
+
+    def __init__(self,
+                 eps: float = 1e-6,
+                 reduction: str = 'mean',
+                 loss_weight: float = 1.0) -> None:
+        super().__init__()
+        self.eps = eps
+        self.reduction = reduction
+        self.loss_weight = loss_weight
+
+    def forward(self,
+                pred: Tensor,
+                target: Tensor,
+                weight: Optional[Tensor] = None,
+                avg_factor: Optional[int] = None,
+                reduction_override: Optional[str] = None,
+                **kwargs) -> Tensor:
+        """Forward function.
+
+        Args:
+            pred (Tensor): Predicted 3D boxes of format
+                (x, y, z, l, w, h, yaw), shape (n, 7).
+            target (Tensor): The learning target of the prediction,
+                shape (n, 7).
+            weight (Optional[Tensor], optional): The weight of loss for each
+                prediction. Defaults to None.
+            avg_factor (Optional[int], optional): Average factor that is used
+                to average the loss. Defaults to None.
+            reduction_override (Optional[str], optional): The reduction method
+                used to override the original reduction method of the loss.
+                Defaults to None. Options are "none", "mean" and "sum".
+
+        Returns:
+            Tensor: Loss tensor.
+        """
+        if weight is not None and not torch.any(weight > 0):
+            if pred.dim() == weight.dim() + 1:
+                weight = weight.unsqueeze(1)
+            return (pred * weight).sum()  # 0
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
+        if weight is not None and weight.dim() > 1:
+            # TODO: remove this in the future
+            # reduce the weight of shape (n, 7) to (n,) to match the
+            # diou3d_loss of shape (n,)
+            assert weight.shape == pred.shape
+            weight = weight.mean(-1)
+        loss = self.loss_weight * diou3d_loss(
+            pred,
+            target,
+            weight,
+            eps=self.eps,
+            reduction=reduction,
+            avg_factor=avg_factor,
+            **kwargs)
+        return loss
diff --git a/tools/train.py b/tools/train.py
index b2ced54b05..6b9c3b0842 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -21,6 +21,12 @@ def parse_args():
        action='store_true',
        default=False,
        help='enable automatic-mixed-precision training')
+    parser.add_argument(
+        '--sync_bn',
+        choices=['none', 'torch', 'mmcv'],
+        default='none',
+        help='convert all BatchNorm layers in the model to SyncBatchNorm '
+        '(SyncBN) or mmcv.ops.sync_bn.SyncBatchNorm (MMSyncBN) layers.')
    parser.add_argument(
        '--auto-scale-lr',
        action='store_true',
@@ -98,6 +104,10 @@ def main():
        cfg.optim_wrapper.type = 'AmpOptimWrapper'
        cfg.optim_wrapper.loss_scale = 'dynamic'

+    # convert BatchNorm layers
+    if args.sync_bn != 'none':
+        cfg.sync_bn = args.sync_bn
+
    # enable automatically scaling LR
    if args.auto_scale_lr:
        if 'auto_scale_lr' in cfg and \
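For reference, the `torch` choice corresponds to PyTorch's built-in BatchNorm-to-SyncBatchNorm conversion. The snippet below is a generic illustration of that conversion, not the code path this PR adds; `tools/train.py` only stores the choice in `cfg.sync_bn`, which is consumed later in the training flow:

```python
import torch.nn as nn

# PyTorch's converter swaps every BatchNorm*D module for SyncBatchNorm.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU())

# The converted model requires an initialized distributed process group
# when it actually runs in training mode.
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(type(sync_model[1]).__name__)  # SyncBatchNorm
```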