Skip to content

Commit

Permalink
merge dev-1.x
Browse files Browse the repository at this point in the history
  • Loading branch information
sunjiahao1999 committed Dec 28, 2023
2 parents 43963f8 + 762e3b5 commit 3ea3f9f
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 25 deletions.
18 changes: 13 additions & 5 deletions projects/DSVT/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,25 @@ python tools/test.py projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-

### Training commands

The support of training DSVT is on the way.
In MMDetection3D's root directory, run the following command to train the model:

```bash
tools/dist_train.sh projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py 8 --sync_bn torch
```

## Results and models

### Waymo

| Middle Encoder | Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | Mem (GB) | Inf time (fps) | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
| :------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------: | :-----------: | :---------------------: | :-------------: | :----------: | :------: | :------------: | :----: | :-----: | :----: | :---------: | :------: |
| [DSVT](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | [ResSECOND](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | 5 | voxel (0.32) || × | | | 75.2 | 72.2 | 68.9 | 66.1 | |
| Middle Encoder | Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
| :------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------: | :-----------: | :---------------------: | :-------------: | :----------: | :----: | :-----: | :----: | :---------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| [DSVT](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | [ResSECOND](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | 5 | voxel (0.32) || × | 75.5 | 72.4 | 69.2 | 66.3 | [log](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/dsvt/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class_20230917_102130.log) |

**Note**:

- `ResSECOND` denotes that the base block in SECOND has residual layers.

**Note** that `ResSECOND` denotes the base block in SECOND has residual layers.
- Regrettably, we are unable to provide the pre-trained model weights due to the [Waymo Dataset License Agreement](https://waymo.com/open/terms/), so we only provide the training logs as shown above.

## Citation

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,8 @@
rot_range=[-0.78539816, 0.78539816],
scale_ratio_range=[0.95, 1.05],
translation_std=[0.5, 0.5, 0.5]),
dict(type='DSVTPointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='DSVTObjectRangeFilter', point_cloud_range=point_cloud_range),
# dict(type='ObjectNameFilter', classes=class_names),
dict(type='PointsRangeFilter3D', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter3D', point_cloud_range=point_cloud_range),
dict(type='PointShuffle'),
dict(
type='Pack3DDetInputs',
Expand All @@ -176,7 +175,7 @@
norm_intensity=True,
norm_elongation=True,
backend_args=backend_args),
dict(type='DSVTPointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='PointsRangeFilter3D', point_cloud_range=point_cloud_range),
dict(
type='Pack3DDetInputs',
keys=['points'],
Expand All @@ -188,12 +187,11 @@
batch_size=1,
num_workers=4,
persistent_workers=True,
# sampler=dict(type='DefaultSampler', shuffle=False),
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='waymo_train.pkl',
ann_file='waymo_infos_train.pkl',
data_prefix=dict(pts='training/velodyne', sweeps='training/velodyne'),
pipeline=train_pipeline,
modality=input_modality,
Expand Down Expand Up @@ -227,7 +225,7 @@
val_evaluator = dict(
type='WaymoMetric',
waymo_bin_file='./data/waymo/waymo_format/gt.bin',
result_prefix='./dsvt_pred.bin')
result_prefix='./dsvt_pred')
test_evaluator = val_evaluator

# vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]
Expand Down Expand Up @@ -281,6 +279,51 @@
# runtime settings
train_cfg = dict(by_epoch=True, max_epochs=12, val_interval=1)

# schedules
# Base learning rate; both LR schedulers below derive their `eta_min` from it.
lr = 1e-5
# AdamW optimizer wrapped in an OptimWrapper, with gradient clipping
# configured as max_norm=10 using the L2 norm (norm_type=2).
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=lr, weight_decay=0.05, betas=(0.9, 0.99)),
clip_grad=dict(max_norm=10, norm_type=2))
# Two-phase schedule split at epoch 1.2 of 12 (T_max: 1.2 + 10.8 = 12).
# Boundaries are fractional epochs with by_epoch=True and
# convert_to_iter_based=True -- presumably converted to iteration counts by
# the runner; TODO confirm against the MMEngine scheduler documentation.
param_scheduler = [
# LR phase 1 (epochs 0 - 1.2): cosine schedule whose eta_min is
# lr * 100 (= 1e-3), i.e. larger than the base lr, so this phase
# presumably ramps the learning rate up -- confirm scheduler semantics.
dict(
type='CosineAnnealingLR',
T_max=1.2,
eta_min=lr * 100,
begin=0,
end=1.2,
by_epoch=True,
convert_to_iter_based=True),
# LR phase 2 (epochs 1.2 - 12): cosine schedule towards
# eta_min = lr * 1e-4 (= 1e-9).
dict(
type='CosineAnnealingLR',
T_max=10.8,
eta_min=lr * 1e-4,
begin=1.2,
end=12,
by_epoch=True,
convert_to_iter_based=True),
# momentum scheduler
# Momentum phase 1 (epochs 0 - 1.2): cosine schedule with eta_min=0.85.
# NOTE(review): which AdamW field is treated as "momentum" (presumably
# the first beta) is decided by the scheduler implementation -- confirm.
dict(
type='CosineAnnealingMomentum',
T_max=1.2,
eta_min=0.85,
begin=0,
end=1.2,
by_epoch=True,
convert_to_iter_based=True),
# Momentum phase 2 (epochs 1.2 - 12): cosine schedule with eta_min=0.95.
dict(
type='CosineAnnealingMomentum',
T_max=10.8,
eta_min=0.95,
begin=1.2,
end=12,
by_epoch=True,
convert_to_iter_based=True)
]

# runtime settings
train_cfg = dict(by_epoch=True, max_epochs=12, val_interval=1)

# runtime settings
val_cfg = dict()
test_cfg = dict()
Expand Down
4 changes: 2 additions & 2 deletions projects/DSVT/dsvt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from .dynamic_pillar_vfe import DynamicPillarVFE3D
from .map2bev import PointPillarsScatter3D
from .res_second import ResSECOND
from .transforms_3d import DSVTObjectRangeFilter, DSVTPointsRangeFilter
from .transforms_3d import ObjectRangeFilter3D, PointsRangeFilter3D
from .utils import DSVTBBoxCoder

__all__ = [
'DSVTCenterHead', 'DSVT', 'DSVTMiddleEncoder', 'DynamicPillarVFE3D',
'PointPillarsScatter3D', 'ResSECOND', 'DSVTBBoxCoder',
'DSVTObjectRangeFilter', 'DSVTPointsRangeFilter', 'DisableAugHook'
'ObjectRangeFilter3D', 'PointsRangeFilter3D', 'DisableAugHook'
]
7 changes: 4 additions & 3 deletions projects/DSVT/dsvt/dsvt_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ class DSVTCenterHead(CenterHead):
"""

def __init__(self,
*args,
loss_iou=dict(
type='mmdet.L1Loss', reduction='mean', loss_weight=1),
loss_reg_iou=None,
*args,
**kwargs):
super(DSVTCenterHead, self).__init__(*args, **kwargs)
self.loss_iou = MODELS.build(loss_iou)
Expand Down Expand Up @@ -165,6 +165,7 @@ def calc_iou_loss(self, iou_preds, batch_box_preds, mask, ind, gt_boxes):
gt_boxes: List of batch groundtruth boxes.
Returns:
Tensor: IoU Loss.
"""
if mask.sum() == 0:
return iou_preds.new_zeros((1))
Expand Down Expand Up @@ -256,7 +257,7 @@ def get_targets_single(self,
"""Generate training targets for a single sample.
Args:
gt_instances_3d (:obj:`InstanceData`): Gt_instances of
gt_instances_3d (:obj:`InstanceData`): Gt_instances_3d of
single data sample. It usually includes
``bboxes_3d`` and ``labels_3d`` attributes.
Expand Down Expand Up @@ -401,7 +402,7 @@ def loss_by_feat(self, preds_dicts: Tuple[List[dict]],
tasks head, and the internal list indicate different
FPN level.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances. It usually includes ``bboxes_3d`` and
gt_instances_3d. It usually includes ``bboxes_3d`` and
``labels_3d`` attributes.
Returns:
Expand Down
1 change: 0 additions & 1 deletion projects/DSVT/dsvt/dynamic_pillar_vfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def __init__(self, with_distance, use_absolute_xyz, use_norm, num_filters,
self.voxel_x = voxel_size[0]
self.voxel_y = voxel_size[1]
self.voxel_z = voxel_size[2]
# TODO: remove it after 对齐精度
point_cloud_range = np.array(point_cloud_range).astype(np.float32)
self.x_offset = self.voxel_x / 2 + point_cloud_range[0]
self.y_offset = self.voxel_y / 2 + point_cloud_range[1]
Expand Down
2 changes: 2 additions & 0 deletions projects/DSVT/dsvt/res_second.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class ResSECOND(BaseModule):
out_channels (list[int]): Output channels for multi-scale feature maps.
blocks_nums (list[int]): Number of blocks in each stage.
layer_strides (list[int]): Strides of each stage.
init_cfg (dict, optional): Config for weight initialization.
Defaults to None.
"""

def __init__(self,
Expand Down
4 changes: 2 additions & 2 deletions projects/DSVT/dsvt/transforms_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


@TRANSFORMS.register_module()
class DSVTObjectRangeFilter(BaseTransform):
class ObjectRangeFilter3D(BaseTransform):
"""Filter objects by the range. It differs from `ObjectRangeFilter` by
using `in_range_3d` instead of `in_range_bev`.
Expand Down Expand Up @@ -61,7 +61,7 @@ def __repr__(self) -> str:


@TRANSFORMS.register_module()
class DSVTPointsRangeFilter(BaseTransform):
class PointsRangeFilter3D(BaseTransform):
"""Filter points by the range. It differs from `PointRangeFilter` by using
`in_range_bev` instead of `in_range_3d`.
Expand Down
11 changes: 6 additions & 5 deletions projects/DSVT/dsvt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,13 +313,13 @@ def center_to_corner2d(center, dim):
@weighted_loss
def diou3d_loss(pred_boxes, gt_boxes, eps: float = 1e-7):
"""
https://github.com/agent-sgs/PillarNet/blob/master/det3d/core/utils/center_utils.py # noqa
modified from https://github.com/agent-sgs/PillarNet/blob/master/det3d/core/utils/center_utils.py # noqa
Args:
pred_boxes (N, 7):
gt_boxes (N, 7):
Returns:
_type_: _description_
Tensor: Distance-IoU Loss.
"""
assert pred_boxes.shape[0] == gt_boxes.shape[0]

Expand Down Expand Up @@ -371,14 +371,15 @@ def diou3d_loss(pred_boxes, gt_boxes, eps: float = 1e-7):
@MODELS.register_module()
class DIoU3DLoss(nn.Module):
r"""3D bboxes Implementation of `Distance-IoU Loss: Faster and Better
Learning for Bounding Box Regression https://arxiv.org/abs/1911.08287`_.
Learning for Bounding Box Regression <https://arxiv.org/abs/1911.08287>`_.
Code is modified from https://github.com/Zzh-tju/DIoU.
Args:
eps (float): Epsilon to avoid log(0).
eps (float): Epsilon to avoid log(0). Defaults to 1e-6.
reduction (str): Options are "none", "mean" and "sum".
loss_weight (float): Weight of loss.
Defaults to "mean".
loss_weight (float): Weight of loss. Defaults to 1.0.
"""

def __init__(self,
Expand Down

0 comments on commit 3ea3f9f

Please sign in to comment.