
[Bug] Training DSVT model on Nuscenes datasets got low mAP #2879

Open · YangChen1234567 opened this issue Jan 11, 2024 · 19 comments

@YangChen1234567 commented Jan 11, 2024

Prerequisite

Task

I'm using the official example scripts/configs for the officially supported tasks/models/datasets.

Branch

1.x branch https://github.com/open-mmlab/mmdetection3d/tree/dev-1.x

Environment

torch: 1.10.1
cuda: 11.4
spconv: 2.3.6
mmdet3d: 1.2.0
mmdet: 3.1.0
mmcv: 2.0.1
mmengine: 0.8.5

Hi~ I used your DSVT code to train a model on the nuScenes dataset, but the mAP and NDS are very low. Why? How should I modify the config file? And I wonder if you will support the DSVT training process on nuScenes (with a TransFusion head)?

Reproduces the problem - command or script

bash tools/dist_train.sh projects/DSVT/configs/dsvt_voxel03_res-second_secfpn_centerheadDSVT_8xb1-cyclic-20e_nuscenes.py 1

Reproduces the problem - error message

Here is the config file I used and the evaluation results on the nuScenes dataset.

[screenshot: evaluation results]

_base_ = ['../../../configs/_base_/default_runtime.py']
custom_imports = dict(
    imports=['projects.DSVT.dsvt'], allow_failed_imports=False)

voxel_size = [0.3, 0.3, 8.0]
point_cloud_range = [-54.0, -54.0, -5.0, 54.0, 54.0, 3.0]
class_names = [
    'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
    'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
]
grid_size = [360, 360, 1]
data_root = '/data/datasets/nuscenes/'
data_prefix = dict(
    pts='samples/LIDAR_TOP',
    CAM_FRONT='samples/CAM_FRONT',
    CAM_FRONT_LEFT='samples/CAM_FRONT_LEFT',
    CAM_FRONT_RIGHT='samples/CAM_FRONT_RIGHT',
    CAM_BACK='samples/CAM_BACK',
    CAM_BACK_RIGHT='samples/CAM_BACK_RIGHT',
    CAM_BACK_LEFT='samples/CAM_BACK_LEFT',
    sweeps='sweeps/LIDAR_TOP')
metainfo = dict(classes=class_names)
input_modality = dict(use_lidar=True, use_camera=False)
backend_args = None

model = dict(
    type='DSVT',
    data_preprocessor=dict(type='Det3DDataPreprocessor', voxel=False),
    voxel_encoder=dict(
        type='DynamicPillarVFE3D',
        with_distance=False,
        use_absolute_xyz=True,
        use_norm=True,
        num_filters=[192, 192],
        num_point_features=5,
        voxel_size=voxel_size,
        grid_size=grid_size,
        point_cloud_range=point_cloud_range),
    middle_encoder=dict(
        type='DSVTMiddleEncoder',
        input_layer=dict(
            sparse_shape=grid_size,
            downsample_stride=[],
            dim_model=[192],
            set_info=[[90, 4]],
            window_shape=[[30, 30, 1]],
            hybrid_factor=[1, 1, 1], # x, y, z
            shift_list=[[[0, 0, 0], [15, 15, 0]]],
            normalize_pos=False),
        set_info=[[90, 4]],
        dim_model=[192],
        dim_feedforward=[384],
        stage_num=1,
        nhead=[8],
        conv_out_channel=192,
        output_shape=[360, 360],
        dropout=0.,
        activation='gelu'),
    map2bev=dict(
        type='PointPillarsScatter3D',
        output_shape=grid_size,
        num_bev_feats=192),
    backbone=dict(
        type='ResSECOND',
        in_channels=192,
        out_channels=[128, 128, 256],
        blocks_nums=[1, 2, 2],
        layer_strides=[1, 2, 2]),
    neck=dict(
        type='SECONDFPN',
        in_channels=[128, 128, 256],
        out_channels=[128, 128, 128],
        upsample_strides=[1, 2, 4],
        norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
        upsample_cfg=dict(type='deconv', bias=False),
        use_conv_for_no_stride=False),
    bbox_head=dict(
        type='DSVTCenterHead',
        in_channels=sum([128, 128, 128]),
        tasks=[dict(num_class=10, class_names=class_names)],
        # tasks=[
        #     dict(num_class=1, class_names=['car']),
        #     dict(num_class=2, class_names=['truck', 'construction_vehicle']),
        #     dict(num_class=2, class_names=['bus', 'trailer']),
        #     dict(num_class=1, class_names=['barrier']),
        #     dict(num_class=2, class_names=['motorcycle', 'bicycle']),
        #     dict(num_class=2, class_names=['pedestrian', 'traffic_cone']),
        # ],
        common_heads=dict(
            reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), iou=(1, 2)),
        share_conv_channel=64,
        conv_cfg=dict(type='Conv2d'),
        norm_cfg=dict(type='BN2d', eps=1e-3, momentum=0.01),
        bbox_coder=dict(
            type='DSVTBBoxCoder',
            pc_range=point_cloud_range,
            max_num=500,
            post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
            score_threshold=0.1,
            out_size_factor=1,
            voxel_size=voxel_size[:2],
            code_size=7),
        separate_head=dict(
            type='SeparateHead',
            init_bias=-2.19,
            final_kernel=3,
            norm_cfg=dict(type='BN2d', eps=1e-3, momentum=0.01)),
        loss_cls=dict(
            type='mmdet.GaussianFocalLoss', reduction='mean', loss_weight=1.0),
        loss_bbox=dict(type='mmdet.L1Loss', reduction='mean', loss_weight=2.0),
        loss_iou=dict(type='mmdet.L1Loss', reduction='sum', loss_weight=1.0),
        loss_reg_iou=dict(
            type='mmdet3d.DIoU3DLoss', reduction='mean', loss_weight=2.0),
        norm_bbox=True),
    # model training and testing settings
    train_cfg=dict(
        grid_size=grid_size,
        voxel_size=voxel_size,
        point_cloud_range=point_cloud_range,
        out_size_factor=1,
        dense_reg=1,
        gaussian_overlap=0.1,
        max_objs=500,
        min_radius=2,
        code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]),
    test_cfg=dict(
        max_per_img=500,
        max_pool_nms=False,
        min_radius=[4, 12, 10, 1, 0.85, 0.175],
        iou_rectifier=[[0.68, 0.71, 0.65, 0.65, 0.65, 0.65, 0.65, 0.65, 0.65, 0.65]],
        pc_range=[-61.2, -61.2],
        out_size_factor=1,
        voxel_size=voxel_size[:2],
        nms_type='rotate',
        multi_class_nms=True,
        pre_max_size=[[4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096]],
        post_max_size=[[500, 500, 500, 500, 500, 500, 500, 500, 500, 500]],
        nms_thr=[[0.7, 0.7, 0.7, 0.7, 0.7, 0.6, 0.55, 0.55, 0.55, 0.55]]))

db_sampler = dict(
    data_root=data_root,
    info_path=data_root + 'nuscenes_dbinfos_train.pkl',
    rate=1.0,
    prepare=dict(
        filter_by_difficulty=[-1],
        filter_by_min_points=dict(
            car=5,
            truck=5,
            bus=5,
            trailer=5,
            construction_vehicle=5,
            traffic_cone=5,
            barrier=5,
            motorcycle=5,
            bicycle=5,
            pedestrian=5)),
    classes=class_names,
    sample_groups=dict(
        car=2,
        truck=3,
        construction_vehicle=7,
        bus=4,
        trailer=6,
        barrier=2,
        motorcycle=6,
        bicycle=6,
        pedestrian=2,
        traffic_cone=2),
    points_loader=dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        load_dim=5,
        use_dim=[0, 1, 2, 3, 4],
        backend_args=backend_args))

train_pipeline = [
    dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        load_dim=5,
        use_dim=5,
        backend_args=backend_args),
    dict(
        type='LoadPointsFromMultiSweeps',
        sweeps_num=9,
        load_dim=5,
        use_dim=5,
        pad_empty_sweeps=True,
        remove_close=True,
        backend_args=backend_args),
    dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
    dict(type='ObjectSample', db_sampler=db_sampler),
    dict(
        type='RandomFlip3D',
        sync_2d=False,
        flip_ratio_bev_horizontal=0.5,
        flip_ratio_bev_vertical=0.5),
    dict(
        type='GlobalRotScaleTrans',
        scale_ratio_range=[0.9, 1.1],
        rot_range=[-0.78539816, 0.78539816],
        translation_std=0.5),
    dict(type='PointsRangeFilter3D', point_cloud_range=point_cloud_range),
    dict(type='ObjectRangeFilter3D', point_cloud_range=point_cloud_range),
    dict(
        type='ObjectNameFilter',
        classes=[
            'car', 'truck', 'construction_vehicle', 'bus', 'trailer',
            'barrier', 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
        ]),
    dict(type='PointShuffle'),
    dict(
        type='Pack3DDetInputs',
        keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]

test_pipeline = [
    dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        load_dim=5,
        use_dim=5,
        backend_args=backend_args),
    dict(
        type='LoadPointsFromMultiSweeps',
        sweeps_num=9,
        load_dim=5,
        use_dim=5,
        pad_empty_sweeps=True,
        remove_close=True,
        backend_args=backend_args),
    dict(type='PointsRangeFilter3D', point_cloud_range=point_cloud_range),
    dict(
        type='Pack3DDetInputs',
        keys=['points'],
        meta_keys=['box_type_3d', 'sample_idx', 'context_name', 'timestamp'])
]

dataset_type = 'NuScenesDataset'
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='mmdet_pkl/nuscenes_infos_train.pkl',
        pipeline=train_pipeline,
        metainfo=metainfo,
        modality=input_modality,
        test_mode=False,
        data_prefix=data_prefix,
        use_valid_flag=True,
        # we use box_type_3d='LiDAR' in kitti and nuscenes dataset
        # and box_type_3d='Depth' in sunrgbd and scannet dataset.
        box_type_3d='LiDAR'))
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='mmdet_pkl/nuscenes_infos_val.pkl',
        pipeline=test_pipeline,
        metainfo=metainfo,
        modality=input_modality,
        data_prefix=data_prefix,
        test_mode=True,
        box_type_3d='LiDAR',
        backend_args=backend_args))
test_dataloader = val_dataloader

val_evaluator = dict(
    type='NuScenesMetric',
    data_root=data_root,
    ann_file=data_root + 'mmdet_pkl/nuscenes_infos_val.pkl',
    metric='bbox',
    backend_args=backend_args)
test_evaluator = val_evaluator

vis_backends = [dict(type='LocalVisBackend'),
                dict(type='TensorboardVisBackend')]
visualizer = dict(
    type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')

# schedules
lr = 1e-4
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=lr, weight_decay=0.05, betas=(0.9, 0.99)),
    clip_grad=dict(max_norm=10, norm_type=2))
param_scheduler = [
    # learning rate scheduler
    # During the first 8 epochs, learning rate increases from 0 to lr * 10
    # during the next 12 epochs, learning rate decreases from lr * 10 to
    # lr * 1e-4
    dict(
        type='CosineAnnealingLR',
        T_max=8,
        eta_min=lr * 10,
        begin=0,
        end=8,
        by_epoch=True,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingLR',
        T_max=12,
        eta_min=lr * 1e-4,
        begin=8,
        end=20,
        by_epoch=True,
        convert_to_iter_based=True),
    # momentum scheduler
    # During the first 8 epochs, momentum increases from 0 to 0.85 / 0.95
    # during the next 12 epochs, momentum increases from 0.85 / 0.95 to 1
    dict(
        type='CosineAnnealingMomentum',
        T_max=8,
        eta_min=0.85 / 0.95,
        begin=0,
        end=8,
        by_epoch=True,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingMomentum',
        T_max=12,
        eta_min=1,
        begin=8,
        end=20,
        by_epoch=True,
        convert_to_iter_based=True)
]

# runtime settings
train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=21)

# runtime settings
val_cfg = dict()
test_cfg = dict()

# Default setting for scaling LR automatically
#   - `enable` means enable scaling LR automatically
#       or not by default.
#   - `base_batch_size` = (8 GPUs) x (1 samples per GPU).
# auto_scale_lr = dict(enable=False, base_batch_size=8)

default_hooks = dict(
    logger=dict(type='LoggerHook', interval=50),
    checkpoint=dict(type='CheckpointHook', interval=1))
custom_hooks = [
    dict(
        type='DisableAugHook',
        disable_after_epoch=15,
        disable_aug_list=[
            'GlobalRotScaleTrans', 'RandomFlip3D', 'ObjectSample'
        ])
]
@bramton commented Jan 11, 2024

I am running into something similar when I try to use my own dataset. I suspect the issue is related to the definition of the 3D bounding boxes. In the NuScenesDataset this conversion is made:

gt_bboxes_3d = LiDARInstance3DBoxes(
    ann_info['gt_bboxes_3d'],
    box_dim=ann_info['gt_bboxes_3d'].shape[-1],
    origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d)

Whereas in the Waymo dataset this conversion is made:

gt_bboxes_3d = LiDARInstance3DBoxes(ann_info['gt_bboxes_3d'])

Could you try to replace the conversion of the Nuscenes with the conversion defined in the Waymo dataset?
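For context, a minimal sketch of what the differing origin argument does, assuming mmdet3d's box structures (the box values are made up for illustration):

import torch
from mmdet3d.structures import LiDARInstance3DBoxes

# One box: (x, y, z, dx, dy, dz, yaw); z = 1.0 in the raw annotation.
raw = torch.tensor([[0.0, 0.0, 1.0, 2.0, 2.0, 2.0, 0.0]])

# nuScenes-style: origin=(0.5, 0.5, 0.5) declares the raw z to be the box
# center, so it is shifted to the internal bottom-center convention.
nus_box = LiDARInstance3DBoxes(raw, origin=(0.5, 0.5, 0.5))

# Waymo-style: the default origin (0.5, 0.5, 0) treats the raw z as the
# box bottom already, so no shift is applied.
waymo_box = LiDARInstance3DBoxes(raw)

print(nus_box.tensor[0, 2])    # 0.0 -> shifted down by dz / 2
print(waymo_box.tensor[0, 2])  # 1.0 -> unchanged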

@YangChen1234567 (Author)

Yes, it's different. But I visualized the detection results with test.py and found that the orientation and location are correct, so it seems the low mAP is unrelated to the definition of gt_bboxes_3d.
Anyway, I will try it. Thank you very much~~~
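(For reference, a visualization run like the one described usually looks as follows in mmdetection3d; the checkpoint path and output directory here are hypothetical:)

python tools/test.py projects/DSVT/configs/dsvt_voxel03_res-second_secfpn_centerheadDSVT_8xb1-cyclic-20e_nuscenes.py \
    work_dirs/dsvt_nuscenes/epoch_20.pth \
    --show --show-dir work_dirs/dsvt_nuscenes/vis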


@bramton commented Jan 15, 2024

Please ignore my suggestion; it seems I had some mix-up regarding bbox definitions in my own code.

Any progress on the nuScenes dataset?

@YangChen1234567 (Author)

No progress so far...

@MarvinKlemp

@YangChen1234567 Is it only nuScenes, or did you also test it on Waymo with bad results?

@YangChen1234567 (Author)

I just trained DSVT on nuScenes; I didn't train or test it on Waymo.

@MarvinKlemp

I will train it on Waymo in the next couple of days and report whether I am able to reproduce the results.

However, it would be quite nice to have it trained on nuScenes.

@YangChen1234567 (Author)
ok~~~~

@YangChen1234567 (Author)

@MarvinKlemp Hi~ Have you reproduced DSVT on the Waymo dataset?

@MarvinKlemp

@YangChen1234567
No, I'm still struggling with creating the data; the official Docker doesn't work.
But I'm on it. Unfortunately the process dies after 5 or 6 hours, so it might take some time until I've fixed all the errors.

@MarvinKlemp commented Feb 2, 2024

Training is running now. I will report the results in ~6 h.

@MarvinKlemp commented Feb 3, 2024

@YangChen1234567 On Waymo I can reproduce the results.

I could try to recreate your NUS training, but I need some time to generate the GT for NUS (a data-preparation sketch follows the results below).

OBJECT_TYPE_TYPE_VEHICLE_LEVEL_1: [mAP 0.751092] [mAPH 0.746194]
OBJECT_TYPE_TYPE_VEHICLE_LEVEL_2: [mAP 0.666377] [mAPH 0.661924]
OBJECT_TYPE_TYPE_PEDESTRIAN_LEVEL_1: [mAP 0.79377] [mAPH 0.718119]
OBJECT_TYPE_TYPE_PEDESTRIAN_LEVEL_2: [mAP 0.716397] [mAPH 0.64581]
OBJECT_TYPE_TYPE_SIGN_LEVEL_1: [mAP 0] [mAPH 0]
OBJECT_TYPE_TYPE_SIGN_LEVEL_2: [mAP 0] [mAPH 0]
OBJECT_TYPE_TYPE_CYCLIST_LEVEL_1: [mAP 0.716416] [mAPH 0.705113]
OBJECT_TYPE_TYPE_CYCLIST_LEVEL_2: [mAP 0.689657] [mAPH 0.678767]
RANGE_TYPE_VEHICLE_[0, 30)_LEVEL_1: [mAP 0.916769] [mAPH 0.912312]
RANGE_TYPE_VEHICLE_[0, 30)_LEVEL_2: [mAP 0.904136] [mAPH 0.899727]
RANGE_TYPE_VEHICLE_[30, 50)_LEVEL_1: [mAP 0.735015] [mAPH 0.729594]
RANGE_TYPE_VEHICLE_[30, 50)_LEVEL_2: [mAP 0.669861] [mAPH 0.664837]
RANGE_TYPE_VEHICLE_[50, +inf)_LEVEL_1: [mAP 0.511822] [mAPH 0.505538]
RANGE_TYPE_VEHICLE_[50, +inf)_LEVEL_2: [mAP 0.394505] [mAPH 0.389454]
RANGE_TYPE_PEDESTRIAN_[0, 30)_LEVEL_1: [mAP 0.842599] [mAPH 0.776197]
RANGE_TYPE_PEDESTRIAN_[0, 30)_LEVEL_2: [mAP 0.805566] [mAPH 0.740443]
RANGE_TYPE_PEDESTRIAN_[30, 50)_LEVEL_1: [mAP 0.78144] [mAPH 0.700494]
RANGE_TYPE_PEDESTRIAN_[30, 50)_LEVEL_2: [mAP 0.714023] [mAPH 0.63854]
RANGE_TYPE_PEDESTRIAN_[50, +inf)_LEVEL_1: [mAP 0.694307] [mAPH 0.595241]
RANGE_TYPE_PEDESTRIAN_[50, +inf)_LEVEL_2: [mAP 0.559821] [mAPH 0.476233]
RANGE_TYPE_SIGN_[0, 30)_LEVEL_1: [mAP 0] [mAPH 0]
RANGE_TYPE_SIGN_[0, 30)_LEVEL_2: [mAP 0] [mAPH 0]
RANGE_TYPE_SIGN_[30, 50)_LEVEL_1: [mAP 0] [mAPH 0]
RANGE_TYPE_SIGN_[30, 50)_LEVEL_2: [mAP 0] [mAPH 0]
RANGE_TYPE_SIGN_[50, +inf)_LEVEL_1: [mAP 0] [mAPH 0]
RANGE_TYPE_SIGN_[50, +inf)_LEVEL_2: [mAP 0] [mAPH 0]
RANGE_TYPE_CYCLIST_[0, 30)_LEVEL_1: [mAP 0.805891] [mAPH 0.795459]
RANGE_TYPE_CYCLIST_[0, 30)_LEVEL_2: [mAP 0.800108] [mAPH 0.789751]
RANGE_TYPE_CYCLIST_[30, 50)_LEVEL_1: [mAP 0.677743] [mAPH 0.666282]
RANGE_TYPE_CYCLIST_[30, 50)_LEVEL_2: [mAP 0.639797] [mAPH 0.628965]
RANGE_TYPE_CYCLIST_[50, +inf)_LEVEL_1: [mAP 0.54664] [mAPH 0.530787]
RANGE_TYPE_CYCLIST_[50, +inf)_LEVEL_2: [mAP 0.508892] [mAPH 0.494123]
Eval Using 216s
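(For anyone reproducing this: the nuScenes GT/info files mentioned above are typically generated with mmdetection3d's data converter; the ./data/nuscenes root below is an assumption:)

python tools/create_data.py nuscenes \
    --root-path ./data/nuscenes \
    --out-dir ./data/nuscenes \
    --extra-tag nuscenes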

@MarvinKlemp commented Feb 7, 2024

@YangChen1234567

I trained it using your config. However, I am unable to run the evaluation with it:

Formating bboxes of pred_instances_3d
Start to convert detection format...
[                                                  ] 0/6019, elapsed: 0s, ETA:Traceback (most recent call last):
  File "tools/test.py", line 149, in <module>
    main()
  File "tools/test.py", line 145, in main
    runner.test()
  File "/usr/local/lib/python3.8/dist-packages/mmengine/runner/runner.py", line 1823, in test
    metrics = self.test_loop.run()  # type: ignore
  File "/usr/local/lib/python3.8/dist-packages/mmengine/runner/loops.py", line 446, in run
    metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
  File "/usr/local/lib/python3.8/dist-packages/mmengine/evaluator/evaluator.py", line 79, in evaluate
    _results = metric.evaluate(size)
  File "/usr/local/lib/python3.8/dist-packages/mmengine/evaluator/metric.py", line 133, in evaluate
    _metrics = self.compute_metrics(results)  # type: ignore
  File "/workspace/mmdetection3d/mmdet3d/evaluation/metrics/nuscenes_metric.py", line 166, in compute_metrics
    result_dict, tmp_dir = self.format_results(results, classes,
  File "/workspace/mmdetection3d/mmdet3d/evaluation/metrics/nuscenes_metric.py", line 307, in format_results
    result_dict[name] = self._format_lidar_bbox(
  File "/workspace/mmdetection3d/mmdet3d/evaluation/metrics/nuscenes_metric.py", line 512, in _format_lidar_bbox
    boxes = lidar_nusc_box_to_global(self.data_infos[sample_idx],
  File "/workspace/mmdetection3d/mmdet3d/evaluation/metrics/nuscenes_metric.py", line 653, in lidar_nusc_box_to_global
    box.rotate(
  File "/usr/local/lib/python3.8/dist-packages/nuscenes/utils/data_classes.py", line 601, in rotate
    self.velocity = np.dot(quaternion.rotation_matrix, self.velocity)
  File "<__array_function__ internals>", line 180, in dot
ValueError: shapes (3,3) and (1,) not aligned: 3 (dim 1) != 1 (dim 0)

Did you encounter something similar? I didn't notice it until the last epoch, since validation is disabled during training in your config, and I can evaluate other methods on NUS without such errors.

EDIT:
Fixed it by adding code_size = 9 and a head for vel (see the sketch below).
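The crash happens because the posted config decodes 7-dim boxes (no vx, vy), so NuScenesMetric hands nuscenes-devkit's Box a velocity that is not a 3-vector. A sketch of the fix described above, patched into the posted config (the 0.2 velocity code_weights are an assumption borrowed from the standard CenterPoint nuScenes configs):

    # bbox_head: add a velocity branch next to the existing heads.
    common_heads=dict(
        reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2),
        vel=(2, 2), iou=(1, 2)),
    # bbox_coder: decode 9-dim boxes (x, y, z, dx, dy, dz, yaw, vx, vy).
    bbox_coder=dict(
        type='DSVTBBoxCoder',
        pc_range=point_cloud_range,
        max_num=500,
        post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
        score_threshold=0.1,
        out_size_factor=1,
        voxel_size=voxel_size[:2],
        code_size=9),
    # train_cfg: two extra weights for the velocity targets.
    code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2],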

@MarvinKlemp

BTW: I can confirm that with your config I get quite bad results.

Evaluating bboxes of pred_instances_3d
mAP: 0.3820
mATE: 0.3414
mASE: 0.2582
mAOE: 0.3073
mAVE: 1.3762
mAAE: 0.4810
NDS: 0.4522
Eval time: 132.7s

Per-class results:
Object Class            AP      ATE     ASE     AOE     AVE     AAE
car                     0.753   0.181   0.152   0.098   1.878   0.479
truck                   0.402   0.292   0.173   0.125   1.198   0.391
bus                     0.474   0.327   0.177   0.064   2.334   0.655
trailer                 0.131   0.479   0.224   0.561   0.813   0.312
construction_vehicle    0.086   0.817   0.422   0.967   0.191   0.383
pedestrian              0.717   0.152   0.270   0.331   1.047   0.821
motorcycle              0.433   0.203   0.245   0.219   2.374   0.550
bicycle                 0.207   0.177   0.273   0.296   1.175   0.258
traffic_cone            0.406   0.179   0.337   nan     nan     nan
barrier                 0.210   0.606   0.308   0.105   nan     nan

@YangChen1234567 (Author)

@YangChen1234567 On Waymo I can reproduce the results. […]

Great!

@MarvinKlemp commented Feb 8, 2024

@JingweiZhang12 @Tai-Wang Any ideas on the NUS results?

(I've seen you worked extensively on DSVT support.)

@YangChen1234567 (Author)

BTW: I can confirm that with your config I get quite bad results. […]

I'm here. Your training results are similar to mine. In DSVT, the CenterHead doesn't predict velocity and can't use multiple task heads to predict objects (if multiple tasks are set, errors occur), so I used the official CenterHead to train DSVT on the nuScenes dataset, but the mAP is still low and I don't know why... (a sketch of that head swap follows below).
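For reference, a sketch of that kind of swap, using the stock mmdet3d CenterHead and CenterPointBBoxCoder with the standard nuScenes task grouping; this is a hypothetical fragment patterned on the official CenterPoint configs, not a verified DSVT setup:

    bbox_head=dict(
        type='CenterHead',
        in_channels=sum([128, 128, 128]),
        # Standard nuScenes class-group tasks from the CenterPoint configs.
        tasks=[
            dict(num_class=1, class_names=['car']),
            dict(num_class=2, class_names=['truck', 'construction_vehicle']),
            dict(num_class=2, class_names=['bus', 'trailer']),
            dict(num_class=1, class_names=['barrier']),
            dict(num_class=2, class_names=['motorcycle', 'bicycle']),
            dict(num_class=2, class_names=['pedestrian', 'traffic_cone']),
        ],
        # vel branch included, so boxes decode to 9 dims.
        common_heads=dict(
            reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), vel=(2, 2)),
        share_conv_channel=64,
        bbox_coder=dict(
            type='CenterPointBBoxCoder',
            pc_range=point_cloud_range[:2],
            post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
            max_num=500,
            score_threshold=0.1,
            out_size_factor=1,
            voxel_size=voxel_size[:2],
            code_size=9),
        separate_head=dict(
            type='SeparateHead', init_bias=-2.19, final_kernel=3),
        loss_cls=dict(type='mmdet.GaussianFocalLoss', reduction='mean'),
        loss_bbox=dict(
            type='mmdet.L1Loss', reduction='mean', loss_weight=0.25),
        norm_bbox=True),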

@shuaizg commented Oct 18, 2024

May I ask if there are any follow-up results? I've recently encountered this issue as well.

@MarvinKlemp

@shuaizg I stopped working on it; on my custom dataset I get good results.
