diff --git a/projects/NeRF-Det/README.md b/projects/NeRF-Det/README.md new file mode 100644 index 0000000000..93119895e9 --- /dev/null +++ b/projects/NeRF-Det/README.md @@ -0,0 +1,115 @@ +# NeRF-Det: Learning Geometry-Aware Volumetric Representation for Multi-View 3D Object Detection + +> [NeRF-Det: Learning Geometry-Aware Volumetric Representation for Multi-View 3D Object Detection](https://arxiv.org/abs/2307.14620) + + + +## Abstract + +NeRF-Det is a novel method for indoor 3D detection with posed RGB images as input. Unlike existing indoor 3D detection methods that struggle to model scene geometry, NeRF-Det makes novel use of NeRF in an end-to-end manner to explicitly estimate 3D geometry, thereby improving 3D detection performance. Specifically, to avoid the significant extra latency associated with per-scene optimization of NeRF, NeRF-Det introduce sufficient geometry priors to enhance the generalizability of NeRF-MLP. Furthermore, it subtly connect the detection and NeRF branches through a shared MLP, enabling an efficient adaptation of NeRF to detection and yielding geometry-aware volumetric representations for 3D detection. NeRF-Det outperforms state-of-the-arts by 3.9 mAP and 3.1 mAP on the ScanNet and ARKITScenes benchmarks, respectively. The author provide extensive analysis to shed light on how NeRF-Det works. As a result of joint-training design, NeRF-Det is able to generalize well to unseen scenes for object detection, view synthesis, and depth estimation tasks without requiring per-scene optimization. Code will be available at https://github.com/facebookresearch/NeRF-Det + +
+ +
+ +## Introduction + +This directory contains the implementations of NeRF-Det (https://arxiv.org/abs/2307.14620). Our implementations are built on top of MMdetection3D. We have updated NeRF-Det to be compatible with latest mmdet3d version. The codebase and config files have all changed to adapt to the new mmdet3d version. All previous pretrained models are verified with the result listed below. However, newly trained models are yet to be uploaded. + + + +## Dataset + +The format of the scannet dataset in the latest version of mmdet3d only supports the lidar tasks. For NeRF-Det, we need to create the new format of ScanNet Dataset. + +Please following the files in mmdet3d to prepare the raw data of ScanNet. After that, please use this command to generate the pkls used in nerfdet. + +```bash +python projects/NeRF-Det/prepare_infos.py --root-path ./data/scannet --out-dir ./data/scannet +``` + +The new format of the pkl is organized as below: + +- scannet_infos_train.pkl: The train data infos, the detailed info of each scan is as follows: + - info\['instances'\]:A list of dict contains all annotations, each dict contains all annotation information of single instance.For the i-th instance: + - info\['instances'\]\[i\]\['bbox_3d'\]: List of 6 numbers representing the axis_aligned in depth coordinate system, in (x,y,z,l,w,h) order. + - info\['instances'\]\[i\]\['bbox_label_3d'\]: The label of each 3d bounding boxes. + - info\['cam2img'\]: The intrinsic matrix.Every scene has one matrix. + - info\['lidar2cam'\]: The extrinsic matrixes.Every scene has 300 matrixes. + - info\['img_paths'\]: The paths of the 300 rgb pictures. + - info\['axis_align_matrix'\]: The align matrix.Every scene has one matrix. + +After preparing your scannet dataset pkls,please change the paths in configs to fit your project. + +## Train + +In MMDet3D's root directory, run the following command to train the model: + +```bash +python tools/train.py projects/NeRF-Det/configs/nerfdet_res50_2x_low_res.py ${WORK_DIR} +``` + +## Results and Models + +### NeRF-Det + +| Backbone | mAP@25 | mAP@50 | Log | +| :-------------------------------------------------------------: | :----: | :----: | :-------: | +| [NeRF-Det-R50](./configs/nerfdet_res50_2x_low_res.py) | 53.0 | 26.8 | [log](<>) | +| [NeRF-Det-R50\*](./configs/nerfdet_res50_2x_low_res_depth.py) | 52.2 | 28.5 | [log](<>) | +| [NeRF-Det-R101\*](./configs/nerfdet_res101_2x_low_res_depth.py) | 52.3 | 28.5 | [log](<>) | + +(Here NeRF-Det-R50\* means this model uses depth information in the training step) + +### Notes + +- The values showed in the chart all represents the best mAP in the training. + +- Since there is a lot of randomness in the behavior of the model, we conducted three experiments on each config and took the average. The mAP showed on the above chart are all average values. + +- We also conducted the same experiments in the original code, the results are showed below. + + | Backbone | mAP@25 | mAP@50 | + | :-------------: | :----: | :----: | + | NeRF-Det-R50 | 52.8 | 26.8 | + | NeRF-Det-R50\* | 52.4 | 27.5 | + | NeRF-Det-R101\* | 52.8 | 28.6 | + +- Attention: Because of the randomness in the construction of the ScanNet dataset itself and the behavior of the model, the training results will fluctuate considerably. According to experimental results and experience, the experimental results will fluctuate by plus or minus 1.5 points. + +## Evaluation using pretrained models + +1. Download the pretrained checkpoints through the linkings in the above chart. + +2. Testing + + To test, use: + + ```bash + python tools/test.py projects/NeRF-Det/configs/nerfdet_res50_2x_low_res.py ${CHECKPOINT_PATH} + ``` + +## Citation + + + +```latex +@inproceedings{ + xu2023nerfdet, + title={NeRF-Det: Learning Geometry-Aware Volumetric Representation for Multi-View 3D Object Detection}, + author={Xu, Chenfeng and Wu, Bichen and Hou, Ji and Tsai, Sam and Li, Ruilong and Wang, Jialiang and Zhan, Wei and He, Zijian and Vajda, Peter and Keutzer, Kurt and Tomizuka, Masayoshi}, + booktitle={ICCV}, + year={2023}, +} + +@inproceedings{ +park2023time, +title={Time Will Tell: New Outlooks and A Baseline for Temporal Multi-View 3D Object Detection}, +author={Jinhyung Park and Chenfeng Xu and Shijia Yang and Kurt Keutzer and Kris M. Kitani and Masayoshi Tomizuka and Wei Zhan}, +booktitle={The Eleventh International Conference on Learning Representations }, +year={2023}, +url={https://openreview.net/forum?id=H3HcEJA2Um} +} +``` diff --git a/projects/NeRF-Det/configs/nerfdet_res101_2x_low_res_depth.py b/projects/NeRF-Det/configs/nerfdet_res101_2x_low_res_depth.py new file mode 100644 index 0000000000..b3c639f19e --- /dev/null +++ b/projects/NeRF-Det/configs/nerfdet_res101_2x_low_res_depth.py @@ -0,0 +1,198 @@ +_base_ = ['../../../configs/_base_/default_runtime.py'] + +custom_imports = dict(imports=['projects.NeRF-Det.nerfdet']) +prior_generator = dict( + type='AlignedAnchor3DRangeGenerator', + ranges=[[-3.2, -3.2, -1.28, 3.2, 3.2, 1.28]], + rotations=[.0]) + +model = dict( + type='NerfDet', + data_preprocessor=dict( + type='NeRFDetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=10), + backbone=dict( + type='mmdet.ResNet', + depth=101, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=False), + norm_eval=True, + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet101'), + style='pytorch'), + neck=dict( + type='mmdet.FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=4), + neck_3d=dict( + type='IndoorImVoxelNeck', + in_channels=256, + out_channels=128, + n_blocks=[1, 1, 1]), + bbox_head=dict( + type='NerfDetHead', + bbox_loss=dict(type='AxisAlignedIoULoss', loss_weight=1.0), + n_classes=18, + n_levels=3, + n_channels=128, + n_reg_outs=6, + pts_assign_threshold=27, + pts_center_threshold=18, + prior_generator=prior_generator), + prior_generator=prior_generator, + voxel_size=[.16, .16, .2], + n_voxels=[40, 40, 16], + aabb=([-2.7, -2.7, -0.78], [3.7, 3.7, 1.78]), + near_far_range=[0.2, 8.0], + N_samples=64, + N_rand=2048, + nerf_mode='image', + depth_supervise=True, + use_nerf_mask=True, + nerf_sample_view=20, + squeeze_scale=4, + nerf_density=True, + train_cfg=dict(), + test_cfg=dict(nms_pre=1000, iou_thr=.25, score_thr=.01)) + +dataset_type = 'MultiViewScanNetDataset' +data_root = 'data/scannet/' +class_names = [ + 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', + 'picture', 'counter', 'desk', 'curtain', 'refrigerator', 'showercurtrain', + 'toilet', 'sink', 'bathtub', 'garbagebin' +] +metainfo = dict(CLASSES=class_names) +file_client_args = dict(backend='disk') + +input_modality = dict( + use_camera=True, + use_depth=True, + use_lidar=False, + use_neuralrecon_depth=False, + use_ray=True) +backend_args = None + +train_collect_keys = [ + 'img', 'gt_bboxes_3d', 'gt_labels_3d', 'depth', 'lightpos', 'nerf_sizes', + 'raydirs', 'gt_images', 'gt_depths', 'denorm_images' +] + +test_collect_keys = [ + 'img', + 'depth', + 'lightpos', + 'nerf_sizes', + 'raydirs', + 'gt_images', + 'gt_depths', + 'denorm_images', +] + +train_pipeline = [ + dict(type='LoadAnnotations3D'), + dict( + type='MultiViewPipeline', + n_images=48, + transforms=[ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict(type='Resize', scale=(320, 240), keep_ratio=True), + ], + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + margin=10, + depth_range=[0.5, 5.5], + loading='random', + nerf_target_views=10), + dict(type='RandomShiftOrigin', std=(.7, .7, .0)), + dict(type='PackNeRFDetInputs', keys=train_collect_keys) +] + +test_pipeline = [ + dict(type='LoadAnnotations3D'), + dict( + type='MultiViewPipeline', + n_images=101, + transforms=[ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict(type='Resize', scale=(320, 240), keep_ratio=True), + ], + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + margin=10, + depth_range=[0.5, 5.5], + loading='random', + nerf_target_views=1), + dict(type='PackNeRFDetInputs', keys=test_collect_keys) +] + +train_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type='RepeatDataset', + times=6, + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='scannet_infos_train_new.pkl', + pipeline=train_pipeline, + modality=input_modality, + test_mode=False, + filter_empty_gt=True, + box_type_3d='Depth', + metainfo=metainfo))) +val_dataloader = dict( + batch_size=1, + num_workers=5, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='scannet_infos_val_new.pkl', + pipeline=test_pipeline, + modality=input_modality, + test_mode=True, + filter_empty_gt=True, + box_type_3d='Depth', + metainfo=metainfo)) +test_dataloader = val_dataloader + +val_evaluator = dict(type='IndoorMetric') +test_evaluator = val_evaluator + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1) +test_cfg = dict() +val_cfg = dict() + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.0001), + paramwise_cfg=dict( + custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}), + clip_grad=dict(max_norm=35., norm_type=2)) +param_scheduler = [ + dict( + type='MultiStepLR', + begin=0, + end=12, + by_epoch=True, + milestones=[8, 11], + gamma=0.1) +] + +# hooks +default_hooks = dict( + checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=12)) + +# runtime +find_unused_parameters = True # only 1 of 4 FPN outputs is used diff --git a/projects/NeRF-Det/configs/nerfdet_res50_2x_low_res.py b/projects/NeRF-Det/configs/nerfdet_res50_2x_low_res.py new file mode 100644 index 0000000000..0321d54bba --- /dev/null +++ b/projects/NeRF-Det/configs/nerfdet_res50_2x_low_res.py @@ -0,0 +1,104 @@ +_base_ = ['./nerfdet_res50_2x_low_res_depth.py'] + +model = dict(depth_supervise=False) + +dataset_type = 'MultiViewScanNetDataset' +data_root = 'data/scannet/' +class_names = [ + 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', + 'picture', 'counter', 'desk', 'curtain', 'refrigerator', 'showercurtrain', + 'toilet', 'sink', 'bathtub', 'garbagebin' +] +metainfo = dict(CLASSES=class_names) +file_client_args = dict(backend='disk') + +input_modality = dict(use_depth=False) +backend_args = None + +train_collect_keys = [ + 'img', 'gt_bboxes_3d', 'gt_labels_3d', 'lightpos', 'nerf_sizes', 'raydirs', + 'gt_images', 'gt_depths', 'denorm_images' +] + +test_collect_keys = [ + 'img', + 'lightpos', + 'nerf_sizes', + 'raydirs', + 'gt_images', + 'gt_depths', + 'denorm_images', +] + +train_pipeline = [ + dict(type='LoadAnnotations3D'), + dict( + type='MultiViewPipeline', + n_images=50, + transforms=[ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict(type='Resize', scale=(320, 240), keep_ratio=True), + ], + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + margin=10, + depth_range=[0.5, 5.5], + loading='random', + nerf_target_views=10), + dict(type='RandomShiftOrigin', std=(.7, .7, .0)), + dict(type='PackNeRFDetInputs', keys=train_collect_keys) +] + +test_pipeline = [ + dict(type='LoadAnnotations3D'), + dict( + type='MultiViewPipeline', + n_images=101, + transforms=[ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict(type='Resize', scale=(320, 240), keep_ratio=True), + ], + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + margin=10, + depth_range=[0.5, 5.5], + loading='random', + nerf_target_views=1), + dict(type='PackNeRFDetInputs', keys=test_collect_keys) +] + +train_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type='RepeatDataset', + times=6, + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='scannet_infos_train_new.pkl', + pipeline=train_pipeline, + modality=input_modality, + test_mode=False, + filter_empty_gt=True, + box_type_3d='Depth', + metainfo=metainfo))) +val_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='scannet_infos_val_new.pkl', + pipeline=test_pipeline, + modality=input_modality, + test_mode=True, + filter_empty_gt=True, + box_type_3d='Depth', + metainfo=metainfo)) +test_dataloader = val_dataloader diff --git a/projects/NeRF-Det/configs/nerfdet_res50_2x_low_res_depth.py b/projects/NeRF-Det/configs/nerfdet_res50_2x_low_res_depth.py new file mode 100644 index 0000000000..0143a8084a --- /dev/null +++ b/projects/NeRF-Det/configs/nerfdet_res50_2x_low_res_depth.py @@ -0,0 +1,198 @@ +_base_ = ['../../../configs/_base_/default_runtime.py'] + +custom_imports = dict(imports=['projects.NeRF-Det.nerfdet']) +prior_generator = dict( + type='AlignedAnchor3DRangeGenerator', + ranges=[[-3.2, -3.2, -1.28, 3.2, 3.2, 1.28]], + rotations=[.0]) + +model = dict( + type='NerfDet', + data_preprocessor=dict( + type='NeRFDetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=10), + backbone=dict( + type='mmdet.ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=False), + norm_eval=True, + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'), + style='pytorch'), + neck=dict( + type='mmdet.FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=4), + neck_3d=dict( + type='IndoorImVoxelNeck', + in_channels=256, + out_channels=128, + n_blocks=[1, 1, 1]), + bbox_head=dict( + type='NerfDetHead', + bbox_loss=dict(type='AxisAlignedIoULoss', loss_weight=1.0), + n_classes=18, + n_levels=3, + n_channels=128, + n_reg_outs=6, + pts_assign_threshold=27, + pts_center_threshold=18, + prior_generator=prior_generator), + prior_generator=prior_generator, + voxel_size=[.16, .16, .2], + n_voxels=[40, 40, 16], + aabb=([-2.7, -2.7, -0.78], [3.7, 3.7, 1.78]), + near_far_range=[0.2, 8.0], + N_samples=64, + N_rand=2048, + nerf_mode='image', + depth_supervise=True, + use_nerf_mask=True, + nerf_sample_view=20, + squeeze_scale=4, + nerf_density=True, + train_cfg=dict(), + test_cfg=dict(nms_pre=1000, iou_thr=.25, score_thr=.01)) + +dataset_type = 'MultiViewScanNetDataset' +data_root = 'data/scannet/' +class_names = [ + 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', + 'picture', 'counter', 'desk', 'curtain', 'refrigerator', 'showercurtrain', + 'toilet', 'sink', 'bathtub', 'garbagebin' +] +metainfo = dict(CLASSES=class_names) +file_client_args = dict(backend='disk') + +input_modality = dict( + use_camera=True, + use_depth=True, + use_lidar=False, + use_neuralrecon_depth=False, + use_ray=True) +backend_args = None + +train_collect_keys = [ + 'img', 'gt_bboxes_3d', 'gt_labels_3d', 'depth', 'lightpos', 'nerf_sizes', + 'raydirs', 'gt_images', 'gt_depths', 'denorm_images' +] + +test_collect_keys = [ + 'img', + 'depth', + 'lightpos', + 'nerf_sizes', + 'raydirs', + 'gt_images', + 'gt_depths', + 'denorm_images', +] + +train_pipeline = [ + dict(type='LoadAnnotations3D'), + dict( + type='MultiViewPipeline', + n_images=50, + transforms=[ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict(type='Resize', scale=(320, 240), keep_ratio=True), + ], + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + margin=10, + depth_range=[0.5, 5.5], + loading='random', + nerf_target_views=10), + dict(type='RandomShiftOrigin', std=(.7, .7, .0)), + dict(type='PackNeRFDetInputs', keys=train_collect_keys) +] + +test_pipeline = [ + dict(type='LoadAnnotations3D'), + dict( + type='MultiViewPipeline', + n_images=101, + transforms=[ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict(type='Resize', scale=(320, 240), keep_ratio=True), + ], + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + margin=10, + depth_range=[0.5, 5.5], + loading='random', + nerf_target_views=1), + dict(type='PackNeRFDetInputs', keys=test_collect_keys) +] + +train_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type='RepeatDataset', + times=6, + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='scannet_infos_train_new.pkl', + pipeline=train_pipeline, + modality=input_modality, + test_mode=False, + filter_empty_gt=True, + box_type_3d='Depth', + metainfo=metainfo))) +val_dataloader = dict( + batch_size=1, + num_workers=5, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='scannet_infos_val_new.pkl', + pipeline=test_pipeline, + modality=input_modality, + test_mode=True, + filter_empty_gt=True, + box_type_3d='Depth', + metainfo=metainfo)) +test_dataloader = val_dataloader + +val_evaluator = dict(type='IndoorMetric') +test_evaluator = val_evaluator + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1) +test_cfg = dict() +val_cfg = dict() + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.0001), + paramwise_cfg=dict( + custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}), + clip_grad=dict(max_norm=35., norm_type=2)) +param_scheduler = [ + dict( + type='MultiStepLR', + begin=0, + end=12, + by_epoch=True, + milestones=[8, 11], + gamma=0.1) +] + +# hooks +default_hooks = dict( + checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=12)) + +# runtime +find_unused_parameters = True # only 1 of 4 FPN outputs is used diff --git a/projects/NeRF-Det/nerfdet/__init__.py b/projects/NeRF-Det/nerfdet/__init__.py new file mode 100644 index 0000000000..5ddef2f7be --- /dev/null +++ b/projects/NeRF-Det/nerfdet/__init__.py @@ -0,0 +1,11 @@ +from .data_preprocessor import NeRFDetDataPreprocessor +from .formating import PackNeRFDetInputs +from .multiview_pipeline import MultiViewPipeline, RandomShiftOrigin +from .nerfdet import NerfDet +from .nerfdet_head import NerfDetHead +from .scannet_multiview_dataset import MultiViewScanNetDataset + +__all__ = [ + 'MultiViewScanNetDataset', 'MultiViewPipeline', 'RandomShiftOrigin', + 'PackNeRFDetInputs', 'NeRFDetDataPreprocessor', 'NerfDetHead', 'NerfDet' +] diff --git a/projects/NeRF-Det/nerfdet/data_preprocessor.py b/projects/NeRF-Det/nerfdet/data_preprocessor.py new file mode 100644 index 0000000000..582a09f63c --- /dev/null +++ b/projects/NeRF-Det/nerfdet/data_preprocessor.py @@ -0,0 +1,583 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from numbers import Number +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +from mmdet.models import DetDataPreprocessor +from mmdet.models.utils.misc import samplelist_boxtype2tensor +from mmengine.model import stack_batch +from mmengine.utils import is_seq_of +from torch import Tensor +from torch.nn import functional as F + +from mmdet3d.models.data_preprocessors.utils import multiview_img_stack_batch +from mmdet3d.models.data_preprocessors.voxelize import ( + VoxelizationByGridShape, dynamic_scatter_3d) +from mmdet3d.registry import MODELS +from mmdet3d.structures.det3d_data_sample import SampleList +from mmdet3d.utils import OptConfigType + + +@MODELS.register_module() +class NeRFDetDataPreprocessor(DetDataPreprocessor): + """In NeRF-Det, some extra information is needed in NeRF branch. We put the + datapreprocessor operations of these new information such as stack and pack + operations in this class. You can find the stack operations in subfuction + 'collate_data' and the pack operations in 'simple_process'. Other codes are + the same as the default class 'DetDataPreprocessor'. + + Points / Image pre-processor for point clouds / vision-only / multi- + modality 3D detection tasks. + + It provides the data pre-processing as follows + + - Collate and move image and point cloud data to the target device. + + - 1) For image data: + + - Pad images in inputs to the maximum size of current batch with defined + ``pad_value``. The padding size can be divisible by a defined + ``pad_size_divisor``. + - Stack images in inputs to batch_imgs. + - Convert images in inputs from bgr to rgb if the shape of input is + (3, H, W). + - Normalize images in inputs with defined std and mean. + - Do batch augmentations during training. + + - 2) For point cloud data: + + - If no voxelization, directly return list of point cloud data. + - If voxelization is applied, voxelize point cloud according to + ``voxel_type`` and obtain ``voxels``. + + Args: + voxel (bool): Whether to apply voxelization to point cloud. + Defaults to False. + voxel_type (str): Voxelization type. Two voxelization types are + provided: 'hard' and 'dynamic', respectively for hard voxelization + and dynamic voxelization. Defaults to 'hard'. + voxel_layer (dict or :obj:`ConfigDict`, optional): Voxelization layer + config. Defaults to None. + batch_first (bool): Whether to put the batch dimension to the first + dimension when getting voxel coordinates. Defaults to True. + max_voxels (int, optional): Maximum number of voxels in each voxel + grid. Defaults to None. + mean (Sequence[Number], optional): The pixel mean of R, G, B channels. + Defaults to None. + std (Sequence[Number], optional): The pixel standard deviation of + R, G, B channels. Defaults to None. + pad_size_divisor (int): The size of padded image should be divisible by + ``pad_size_divisor``. Defaults to 1. + pad_value (float or int): The padded pixel value. Defaults to 0. + pad_mask (bool): Whether to pad instance masks. Defaults to False. + mask_pad_value (int): The padded pixel value for instance masks. + Defaults to 0. + pad_seg (bool): Whether to pad semantic segmentation maps. + Defaults to False. + seg_pad_value (int): The padded pixel value for semantic segmentation + maps. Defaults to 255. + bgr_to_rgb (bool): Whether to convert image from BGR to RGB. + Defaults to False. + rgb_to_bgr (bool): Whether to convert image from RGB to BGR. + Defaults to False. + boxtype2tensor (bool): Whether to convert the ``BaseBoxes`` type of + bboxes data to ``Tensor`` type. Defaults to True. + non_blocking (bool): Whether to block current process when transferring + data to device. Defaults to False. + batch_augments (List[dict], optional): Batch-level augmentations. + Defaults to None. + """ + + def __init__(self, + voxel: bool = False, + voxel_type: str = 'hard', + voxel_layer: OptConfigType = None, + batch_first: bool = True, + max_voxels: Optional[int] = None, + mean: Sequence[Number] = None, + std: Sequence[Number] = None, + pad_size_divisor: int = 1, + pad_value: Union[float, int] = 0, + pad_mask: bool = False, + mask_pad_value: int = 0, + pad_seg: bool = False, + seg_pad_value: int = 255, + bgr_to_rgb: bool = False, + rgb_to_bgr: bool = False, + boxtype2tensor: bool = True, + non_blocking: bool = False, + batch_augments: Optional[List[dict]] = None) -> None: + super(NeRFDetDataPreprocessor, self).__init__( + mean=mean, + std=std, + pad_size_divisor=pad_size_divisor, + pad_value=pad_value, + pad_mask=pad_mask, + mask_pad_value=mask_pad_value, + pad_seg=pad_seg, + seg_pad_value=seg_pad_value, + bgr_to_rgb=bgr_to_rgb, + rgb_to_bgr=rgb_to_bgr, + boxtype2tensor=boxtype2tensor, + non_blocking=non_blocking, + batch_augments=batch_augments) + self.voxel = voxel + self.voxel_type = voxel_type + self.batch_first = batch_first + self.max_voxels = max_voxels + if voxel: + self.voxel_layer = VoxelizationByGridShape(**voxel_layer) + + def forward(self, + data: Union[dict, List[dict]], + training: bool = False) -> Union[dict, List[dict]]: + """Perform normalization, padding and bgr2rgb conversion based on + ``BaseDataPreprocessor``. + + Args: + data (dict or List[dict]): Data from dataloader. The dict contains + the whole batch data, when it is a list[dict], the list + indicates test time augmentation. + training (bool): Whether to enable training time augmentation. + Defaults to False. + + Returns: + dict or List[dict]: Data in the same format as the model input. + """ + if isinstance(data, list): + num_augs = len(data) + aug_batch_data = [] + for aug_id in range(num_augs): + single_aug_batch_data = self.simple_process( + data[aug_id], training) + aug_batch_data.append(single_aug_batch_data) + return aug_batch_data + + else: + return self.simple_process(data, training) + + def simple_process(self, data: dict, training: bool = False) -> dict: + """Perform normalization, padding and bgr2rgb conversion for img data + based on ``BaseDataPreprocessor``, and voxelize point cloud if `voxel` + is set to be True. + + Args: + data (dict): Data sampled from dataloader. + training (bool): Whether to enable training time augmentation. + Defaults to False. + + Returns: + dict: Data in the same format as the model input. + """ + if 'img' in data['inputs']: + batch_pad_shape = self._get_pad_shape(data) + + data = self.collate_data(data) + inputs, data_samples = data['inputs'], data['data_samples'] + batch_inputs = dict() + + if 'points' in inputs: + batch_inputs['points'] = inputs['points'] + + if self.voxel: + voxel_dict = self.voxelize(inputs['points'], data_samples) + batch_inputs['voxels'] = voxel_dict + + if 'imgs' in inputs: + imgs = inputs['imgs'] + + if data_samples is not None: + # NOTE the batched image size information may be useful, e.g. + # in DETR, this is needed for the construction of masks, which + # is then used for the transformer_head. + batch_input_shape = tuple(imgs[0].size()[-2:]) + for data_sample, pad_shape in zip(data_samples, + batch_pad_shape): + data_sample.set_metainfo({ + 'batch_input_shape': batch_input_shape, + 'pad_shape': pad_shape + }) + + if self.boxtype2tensor: + samplelist_boxtype2tensor(data_samples) + if self.pad_mask: + self.pad_gt_masks(data_samples) + if self.pad_seg: + self.pad_gt_sem_seg(data_samples) + + if training and self.batch_augments is not None: + for batch_aug in self.batch_augments: + imgs, data_samples = batch_aug(imgs, data_samples) + batch_inputs['imgs'] = imgs + # Hard code here, will be changed later. + # if len(inputs['depth']) != 0: + if 'depth' in inputs.keys(): + batch_inputs['depth'] = inputs['depth'] + batch_inputs['lightpos'] = inputs['lightpos'] + batch_inputs['nerf_sizes'] = inputs['nerf_sizes'] + batch_inputs['denorm_images'] = inputs['denorm_images'] + batch_inputs['raydirs'] = inputs['raydirs'] + + return {'inputs': batch_inputs, 'data_samples': data_samples} + + def preprocess_img(self, _batch_img: Tensor) -> Tensor: + # channel transform + if self._channel_conversion: + _batch_img = _batch_img[[2, 1, 0], ...] + # Convert to float after channel conversion to ensure + # efficiency + _batch_img = _batch_img.float() + # Normalization. + if self._enable_normalize: + if self.mean.shape[0] == 3: + assert _batch_img.dim() == 3 and _batch_img.shape[0] == 3, ( + 'If the mean has 3 values, the input tensor ' + 'should in shape of (3, H, W), but got the ' + f'tensor with shape {_batch_img.shape}') + _batch_img = (_batch_img - self.mean) / self.std + return _batch_img + + def collate_data(self, data: dict) -> dict: + """Copy data to the target device and perform normalization, padding + and bgr2rgb conversion and stack based on ``BaseDataPreprocessor``. + + Collates the data sampled from dataloader into a list of dict and list + of labels, and then copies tensor to the target device. + + Args: + data (dict): Data sampled from dataloader. + + Returns: + dict: Data in the same format as the model input. + """ + data = self.cast_data(data) # type: ignore + + if 'img' in data['inputs']: + _batch_imgs = data['inputs']['img'] + # Process data with `pseudo_collate`. + if is_seq_of(_batch_imgs, torch.Tensor): + batch_imgs = [] + img_dim = _batch_imgs[0].dim() + for _batch_img in _batch_imgs: + if img_dim == 3: # standard img + _batch_img = self.preprocess_img(_batch_img) + elif img_dim == 4: + _batch_img = [ + self.preprocess_img(_img) for _img in _batch_img + ] + + _batch_img = torch.stack(_batch_img, dim=0) + + batch_imgs.append(_batch_img) + + # Pad and stack Tensor. + if img_dim == 3: + batch_imgs = stack_batch(batch_imgs, self.pad_size_divisor, + self.pad_value) + elif img_dim == 4: + batch_imgs = multiview_img_stack_batch( + batch_imgs, self.pad_size_divisor, self.pad_value) + + # Process data with `default_collate`. + elif isinstance(_batch_imgs, torch.Tensor): + assert _batch_imgs.dim() == 4, ( + 'The input of `ImgDataPreprocessor` should be a NCHW ' + 'tensor or a list of tensor, but got a tensor with ' + f'shape: {_batch_imgs.shape}') + if self._channel_conversion: + _batch_imgs = _batch_imgs[:, [2, 1, 0], ...] + # Convert to float after channel conversion to ensure + # efficiency + _batch_imgs = _batch_imgs.float() + if self._enable_normalize: + _batch_imgs = (_batch_imgs - self.mean) / self.std + h, w = _batch_imgs.shape[2:] + target_h = math.ceil( + h / self.pad_size_divisor) * self.pad_size_divisor + target_w = math.ceil( + w / self.pad_size_divisor) * self.pad_size_divisor + pad_h = target_h - h + pad_w = target_w - w + batch_imgs = F.pad(_batch_imgs, (0, pad_w, 0, pad_h), + 'constant', self.pad_value) + else: + raise TypeError( + 'Output of `cast_data` should be a list of dict ' + 'or a tuple with inputs and data_samples, but got ' + f'{type(data)}: {data}') + + data['inputs']['imgs'] = batch_imgs + if 'raydirs' in data['inputs']: + _batch_dirs = data['inputs']['raydirs'] + batch_dirs = stack_batch(_batch_dirs) + data['inputs']['raydirs'] = batch_dirs + + if 'lightpos' in data['inputs']: + _batch_poses = data['inputs']['lightpos'] + batch_poses = stack_batch(_batch_poses) + data['inputs']['lightpos'] = batch_poses + + if 'denorm_images' in data['inputs']: + _batch_denorm_imgs = data['inputs']['denorm_images'] + # Process data with `pseudo_collate`. + if is_seq_of(_batch_denorm_imgs, torch.Tensor): + denorm_img_dim = _batch_denorm_imgs[0].dim() + # Pad and stack Tensor. + if denorm_img_dim == 3: + batch_denorm_imgs = stack_batch(_batch_denorm_imgs, + self.pad_size_divisor, + self.pad_value) + elif denorm_img_dim == 4: + batch_denorm_imgs = multiview_img_stack_batch( + _batch_denorm_imgs, self.pad_size_divisor, + self.pad_value) + data['inputs']['denorm_images'] = batch_denorm_imgs + + data.setdefault('data_samples', None) + + return data + + def _get_pad_shape(self, data: dict) -> List[Tuple[int, int]]: + """Get the pad_shape of each image based on data and + pad_size_divisor.""" + # rewrite `_get_pad_shape` for obtaining image inputs. + _batch_inputs = data['inputs']['img'] + # Process data with `pseudo_collate`. + if is_seq_of(_batch_inputs, torch.Tensor): + batch_pad_shape = [] + for ori_input in _batch_inputs: + if ori_input.dim() == 4: + # mean multiview input, select one of the + # image to calculate the pad shape + ori_input = ori_input[0] + pad_h = int( + np.ceil(ori_input.shape[1] / + self.pad_size_divisor)) * self.pad_size_divisor + pad_w = int( + np.ceil(ori_input.shape[2] / + self.pad_size_divisor)) * self.pad_size_divisor + batch_pad_shape.append((pad_h, pad_w)) + # Process data with `default_collate`. + elif isinstance(_batch_inputs, torch.Tensor): + assert _batch_inputs.dim() == 4, ( + 'The input of `ImgDataPreprocessor` should be a NCHW tensor ' + 'or a list of tensor, but got a tensor with shape: ' + f'{_batch_inputs.shape}') + pad_h = int( + np.ceil(_batch_inputs.shape[1] / + self.pad_size_divisor)) * self.pad_size_divisor + pad_w = int( + np.ceil(_batch_inputs.shape[2] / + self.pad_size_divisor)) * self.pad_size_divisor + batch_pad_shape = [(pad_h, pad_w)] * _batch_inputs.shape[0] + else: + raise TypeError('Output of `cast_data` should be a list of dict ' + 'or a tuple with inputs and data_samples, but got ' + f'{type(data)}: {data}') + return batch_pad_shape + + @torch.no_grad() + def voxelize(self, points: List[Tensor], + data_samples: SampleList) -> Dict[str, Tensor]: + """Apply voxelization to point cloud. + + Args: + points (List[Tensor]): Point cloud in one data batch. + data_samples: (list[:obj:`NeRFDet3DDataSample`]): The annotation + data of every samples. Add voxel-wise annotation for + segmentation. + + Returns: + Dict[str, Tensor]: Voxelization information. + + - voxels (Tensor): Features of voxels, shape is MxNxC for hard + voxelization, NxC for dynamic voxelization. + - coors (Tensor): Coordinates of voxels, shape is Nx(1+NDim), + where 1 represents the batch index. + - num_points (Tensor, optional): Number of points in each voxel. + - voxel_centers (Tensor, optional): Centers of voxels. + """ + + voxel_dict = dict() + + if self.voxel_type == 'hard': + voxels, coors, num_points, voxel_centers = [], [], [], [] + for i, res in enumerate(points): + res_voxels, res_coors, res_num_points = self.voxel_layer(res) + res_voxel_centers = ( + res_coors[:, [2, 1, 0]] + 0.5) * res_voxels.new_tensor( + self.voxel_layer.voxel_size) + res_voxels.new_tensor( + self.voxel_layer.point_cloud_range[0:3]) + res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i) + voxels.append(res_voxels) + coors.append(res_coors) + num_points.append(res_num_points) + voxel_centers.append(res_voxel_centers) + + voxels = torch.cat(voxels, dim=0) + coors = torch.cat(coors, dim=0) + num_points = torch.cat(num_points, dim=0) + voxel_centers = torch.cat(voxel_centers, dim=0) + + voxel_dict['num_points'] = num_points + voxel_dict['voxel_centers'] = voxel_centers + elif self.voxel_type == 'dynamic': + coors = [] + # dynamic voxelization only provide a coors mapping + for i, res in enumerate(points): + res_coors = self.voxel_layer(res) + res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i) + coors.append(res_coors) + voxels = torch.cat(points, dim=0) + coors = torch.cat(coors, dim=0) + elif self.voxel_type == 'cylindrical': + voxels, coors = [], [] + for i, (res, data_sample) in enumerate(zip(points, data_samples)): + rho = torch.sqrt(res[:, 0]**2 + res[:, 1]**2) + phi = torch.atan2(res[:, 1], res[:, 0]) + polar_res = torch.stack((rho, phi, res[:, 2]), dim=-1) + min_bound = polar_res.new_tensor( + self.voxel_layer.point_cloud_range[:3]) + max_bound = polar_res.new_tensor( + self.voxel_layer.point_cloud_range[3:]) + try: # only support PyTorch >= 1.9.0 + polar_res_clamp = torch.clamp(polar_res, min_bound, + max_bound) + except TypeError: + polar_res_clamp = polar_res.clone() + for coor_idx in range(3): + polar_res_clamp[:, coor_idx][ + polar_res[:, coor_idx] > + max_bound[coor_idx]] = max_bound[coor_idx] + polar_res_clamp[:, coor_idx][ + polar_res[:, coor_idx] < + min_bound[coor_idx]] = min_bound[coor_idx] + res_coors = torch.floor( + (polar_res_clamp - min_bound) / polar_res_clamp.new_tensor( + self.voxel_layer.voxel_size)).int() + self.get_voxel_seg(res_coors, data_sample) + res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i) + res_voxels = torch.cat((polar_res, res[:, :2], res[:, 3:]), + dim=-1) + voxels.append(res_voxels) + coors.append(res_coors) + voxels = torch.cat(voxels, dim=0) + coors = torch.cat(coors, dim=0) + elif self.voxel_type == 'minkunet': + voxels, coors = [], [] + voxel_size = points[0].new_tensor(self.voxel_layer.voxel_size) + for i, (res, data_sample) in enumerate(zip(points, data_samples)): + res_coors = torch.round(res[:, :3] / voxel_size).int() + res_coors -= res_coors.min(0)[0] + + res_coors_numpy = res_coors.cpu().numpy() + inds, point2voxel_map = self.sparse_quantize( + res_coors_numpy, return_index=True, return_inverse=True) + point2voxel_map = torch.from_numpy(point2voxel_map).cuda() + if self.training and self.max_voxels is not None: + if len(inds) > self.max_voxels: + inds = np.random.choice( + inds, self.max_voxels, replace=False) + inds = torch.from_numpy(inds).cuda() + if hasattr(data_sample.gt_pts_seg, 'pts_semantic_mask'): + data_sample.gt_pts_seg.voxel_semantic_mask \ + = data_sample.gt_pts_seg.pts_semantic_mask[inds] + res_voxel_coors = res_coors[inds] + res_voxels = res[inds] + if self.batch_first: + res_voxel_coors = F.pad( + res_voxel_coors, (1, 0), mode='constant', value=i) + data_sample.batch_idx = res_voxel_coors[:, 0] + else: + res_voxel_coors = F.pad( + res_voxel_coors, (0, 1), mode='constant', value=i) + data_sample.batch_idx = res_voxel_coors[:, -1] + data_sample.point2voxel_map = point2voxel_map.long() + voxels.append(res_voxels) + coors.append(res_voxel_coors) + voxels = torch.cat(voxels, dim=0) + coors = torch.cat(coors, dim=0) + + else: + raise ValueError(f'Invalid voxelization type {self.voxel_type}') + + voxel_dict['voxels'] = voxels + voxel_dict['coors'] = coors + + return voxel_dict + + def get_voxel_seg(self, res_coors: Tensor, + data_sample: SampleList) -> None: + """Get voxel-wise segmentation label and point2voxel map. + + Args: + res_coors (Tensor): The voxel coordinates of points, Nx3. + data_sample: (:obj:`NeRFDet3DDataSample`): The annotation data of + every samples. Add voxel-wise annotation forsegmentation. + """ + + if self.training: + pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask + voxel_semantic_mask, _, point2voxel_map = dynamic_scatter_3d( + F.one_hot(pts_semantic_mask.long()).float(), res_coors, 'mean', + True) + voxel_semantic_mask = torch.argmax(voxel_semantic_mask, dim=-1) + data_sample.gt_pts_seg.voxel_semantic_mask = voxel_semantic_mask + data_sample.point2voxel_map = point2voxel_map + else: + pseudo_tensor = res_coors.new_ones([res_coors.shape[0], 1]).float() + _, _, point2voxel_map = dynamic_scatter_3d(pseudo_tensor, + res_coors, 'mean', True) + data_sample.point2voxel_map = point2voxel_map + + def ravel_hash(self, x: np.ndarray) -> np.ndarray: + """Get voxel coordinates hash for np.unique. + + Args: + x (np.ndarray): The voxel coordinates of points, Nx3. + + Returns: + np.ndarray: Voxels coordinates hash. + """ + assert x.ndim == 2, x.shape + + x = x - np.min(x, axis=0) + x = x.astype(np.uint64, copy=False) + xmax = np.max(x, axis=0).astype(np.uint64) + 1 + + h = np.zeros(x.shape[0], dtype=np.uint64) + for k in range(x.shape[1] - 1): + h += x[:, k] + h *= xmax[k + 1] + h += x[:, -1] + return h + + def sparse_quantize(self, + coords: np.ndarray, + return_index: bool = False, + return_inverse: bool = False) -> List[np.ndarray]: + """Sparse Quantization for voxel coordinates used in Minkunet. + + Args: + coords (np.ndarray): The voxel coordinates of points, Nx3. + return_index (bool): Whether to return the indices of the unique + coords, shape (M,). + return_inverse (bool): Whether to return the indices of the + original coords, shape (N,). + + Returns: + List[np.ndarray]: Return index and inverse map if return_index and + return_inverse is True. + """ + _, indices, inverse_indices = np.unique( + self.ravel_hash(coords), return_index=True, return_inverse=True) + coords = coords[indices] + + outputs = [] + if return_index: + outputs += [indices] + if return_inverse: + outputs += [inverse_indices] + return outputs diff --git a/projects/NeRF-Det/nerfdet/formating.py b/projects/NeRF-Det/nerfdet/formating.py new file mode 100644 index 0000000000..6063d634cf --- /dev/null +++ b/projects/NeRF-Det/nerfdet/formating.py @@ -0,0 +1,350 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Sequence, Union + +import mmengine +import numpy as np +import torch +from mmcv import BaseTransform +from mmengine.structures import InstanceData +from numpy import dtype + +from mmdet3d.registry import TRANSFORMS +from mmdet3d.structures import BaseInstance3DBoxes, PointData +from mmdet3d.structures.points import BasePoints +# from .det3d_data_sample import Det3DDataSample +from .nerf_det3d_data_sample import NeRFDet3DDataSample + + +def to_tensor( + data: Union[torch.Tensor, np.ndarray, Sequence, int, + float]) -> torch.Tensor: + """Convert objects of various python types to :obj:`torch.Tensor`. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int` and :class:`float`. + + Args: + data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to + be converted. + + Returns: + torch.Tensor: the converted data. + """ + + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray): + if data.dtype is dtype('float64'): + data = data.astype(np.float32) + return torch.from_numpy(data) + elif isinstance(data, Sequence) and not mmengine.is_str(data): + return torch.tensor(data) + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + else: + raise TypeError(f'type {type(data)} cannot be converted to tensor.') + + +@TRANSFORMS.register_module() +class PackNeRFDetInputs(BaseTransform): + INPUTS_KEYS = ['points', 'img'] + NERF_INPUT_KEYS = [ + 'img', 'denorm_images', 'depth', 'lightpos', 'nerf_sizes', 'raydirs' + ] + + INSTANCEDATA_3D_KEYS = [ + 'gt_bboxes_3d', 'gt_labels_3d', 'attr_labels', 'depths', 'centers_2d' + ] + INSTANCEDATA_2D_KEYS = [ + 'gt_bboxes', + 'gt_bboxes_labels', + ] + NERF_3D_KEYS = ['gt_images', 'gt_depths'] + + SEG_KEYS = [ + 'gt_seg_map', 'pts_instance_mask', 'pts_semantic_mask', + 'gt_semantic_seg' + ] + + def __init__( + self, + keys: tuple, + meta_keys: tuple = ('img_path', 'ori_shape', 'img_shape', 'lidar2img', + 'depth2img', 'cam2img', 'pad_shape', + 'scale_factor', 'flip', 'pcd_horizontal_flip', + 'pcd_vertical_flip', 'box_mode_3d', 'box_type_3d', + 'img_norm_cfg', 'num_pts_feats', 'pcd_trans', + 'sample_idx', 'pcd_scale_factor', 'pcd_rotation', + 'pcd_rotation_angle', 'lidar_path', + 'transformation_3d_flow', 'trans_mat', + 'affine_aug', 'sweep_img_metas', 'ori_cam2img', + 'cam2global', 'crop_offset', 'img_crop_offset', + 'resize_img_shape', 'lidar2cam', 'ori_lidar2img', + 'num_ref_frames', 'num_views', 'ego2global', + 'axis_align_matrix') + ) -> None: + self.keys = keys + self.meta_keys = meta_keys + + def _remove_prefix(self, key: str) -> str: + if key.startswith('gt_'): + key = key[3:] + return key + + def transform(self, results: Union[dict, + List[dict]]) -> Union[dict, List[dict]]: + """Method to pack the input data. when the value in this dict is a + list, it usually is in Augmentations Testing. + + Args: + results (dict | list[dict]): Result dict from the data pipeline. + + Returns: + dict | List[dict]: + + - 'inputs' (dict): The forward data of models. It usually contains + following keys: + + - points + - img + + - 'data_samples' (:obj:`NeRFDet3DDataSample`): The annotation info + of the sample. + """ + # augtest + if isinstance(results, list): + if len(results) == 1: + # simple test + return self.pack_single_results(results[0]) + pack_results = [] + for single_result in results: + pack_results.append(self.pack_single_results(single_result)) + return pack_results + # norm training and simple testing + elif isinstance(results, dict): + return self.pack_single_results(results) + else: + raise NotImplementedError + + def pack_single_results(self, results: dict) -> dict: + """Method to pack the single input data. when the value in this dict is + a list, it usually is in Augmentations Testing. + + Args: + results (dict): Result dict from the data pipeline. + + Returns: + dict: A dict contains + + - 'inputs' (dict): The forward data of models. It usually contains + following keys: + + - points + - img + + - 'data_samples' (:obj:`NeRFDet3DDataSample`): The annotation info + of the sample. + """ + # Format 3D data + if 'points' in results: + if isinstance(results['points'], BasePoints): + results['points'] = results['points'].tensor + + if 'img' in results: + if isinstance(results['img'], list): + # process multiple imgs in single frame + imgs = np.stack(results['img'], axis=0) + if imgs.flags.c_contiguous: + imgs = to_tensor(imgs).permute(0, 3, 1, 2).contiguous() + else: + imgs = to_tensor( + np.ascontiguousarray(imgs.transpose(0, 3, 1, 2))) + results['img'] = imgs + else: + img = results['img'] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + # To improve the computational speed by by 3-5 times, apply: + # `torch.permute()` rather than `np.transpose()`. + # Refer to https://github.com/open-mmlab/mmdetection/pull/9533 + # for more details + if img.flags.c_contiguous: + img = to_tensor(img).permute(2, 0, 1).contiguous() + else: + img = to_tensor( + np.ascontiguousarray(img.transpose(2, 0, 1))) + results['img'] = img + + if 'depth' in results: + if isinstance(results['depth'], list): + # process multiple depth imgs in single frame + depth_imgs = np.stack(results['depth'], axis=0) + if depth_imgs.flags.c_contiguous: + depth_imgs = to_tensor(depth_imgs).contiguous() + else: + depth_imgs = to_tensor(np.ascontiguousarray(depth_imgs)) + results['depth'] = depth_imgs + else: + depth_img = results['depth'] + if len(depth_img.shape) < 3: + depth_img = np.expand_dims(depth_img, -1) + if depth_img.flags.c_contiguous: + depth_img = to_tensor(depth_img).contiguous() + else: + depth_img = to_tensor(np.ascontiguousarray(depth_img)) + results['depth'] = depth_img + + if 'ray_info' in results: + if isinstance(results['raydirs'], list): + raydirs = np.stack(results['raydirs'], axis=0) + if raydirs.flags.c_contiguous: + raydirs = to_tensor(raydirs).contiguous() + else: + raydirs = to_tensor(np.ascontiguousarray(raydirs)) + results['raydirs'] = raydirs + + if isinstance(results['lightpos'], list): + lightposes = np.stack(results['lightpos'], axis=0) + if lightposes.flags.c_contiguous: + lightposes = to_tensor(lightposes).contiguous() + else: + lightposes = to_tensor(np.ascontiguousarray(lightposes)) + lightposes = lightposes.unsqueeze(1).repeat( + 1, raydirs.shape[1], 1) + results['lightpos'] = lightposes + + if isinstance(results['gt_images'], list): + gt_images = np.stack(results['gt_images'], axis=0) + if gt_images.flags.c_contiguous: + gt_images = to_tensor(gt_images).contiguous() + else: + gt_images = to_tensor(np.ascontiguousarray(gt_images)) + results['gt_images'] = gt_images + + if isinstance(results['gt_depths'], + list) and len(results['gt_depths']) != 0: + gt_depths = np.stack(results['gt_depths'], axis=0) + if gt_depths.flags.c_contiguous: + gt_depths = to_tensor(gt_depths).contiguous() + else: + gt_depths = to_tensor(np.ascontiguousarray(gt_depths)) + results['gt_depths'] = gt_depths + + if isinstance(results['denorm_images'], list): + denorm_imgs = np.stack(results['denorm_images'], axis=0) + if denorm_imgs.flags.c_contiguous: + denorm_imgs = to_tensor(denorm_imgs).permute( + 0, 3, 1, 2).contiguous() + else: + denorm_imgs = to_tensor( + np.ascontiguousarray( + denorm_imgs.transpose(0, 3, 1, 2))) + results['denorm_images'] = denorm_imgs + + for key in [ + 'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels', + 'gt_bboxes_labels', 'attr_labels', 'pts_instance_mask', + 'pts_semantic_mask', 'centers_2d', 'depths', 'gt_labels_3d' + ]: + if key not in results: + continue + if isinstance(results[key], list): + results[key] = [to_tensor(res) for res in results[key]] + else: + results[key] = to_tensor(results[key]) + if 'gt_bboxes_3d' in results: + if not isinstance(results['gt_bboxes_3d'], BaseInstance3DBoxes): + results['gt_bboxes_3d'] = to_tensor(results['gt_bboxes_3d']) + + if 'gt_semantic_seg' in results: + results['gt_semantic_seg'] = to_tensor( + results['gt_semantic_seg'][None]) + if 'gt_seg_map' in results: + results['gt_seg_map'] = results['gt_seg_map'][None, ...] + + if 'gt_images' in results: + results['gt_images'] = to_tensor(results['gt_images']) + if 'gt_depths' in results: + results['gt_depths'] = to_tensor(results['gt_depths']) + + data_sample = NeRFDet3DDataSample() + gt_instances_3d = InstanceData() + gt_instances = InstanceData() + gt_pts_seg = PointData() + gt_nerf_images = InstanceData() + gt_nerf_depths = InstanceData() + + data_metas = {} + for key in self.meta_keys: + if key in results: + data_metas[key] = results[key] + elif 'images' in results: + if len(results['images'].keys()) == 1: + cam_type = list(results['images'].keys())[0] + # single-view image + if key in results['images'][cam_type]: + data_metas[key] = results['images'][cam_type][key] + else: + # multi-view image + img_metas = [] + cam_types = list(results['images'].keys()) + for cam_type in cam_types: + if key in results['images'][cam_type]: + img_metas.append(results['images'][cam_type][key]) + if len(img_metas) > 0: + data_metas[key] = img_metas + elif 'lidar_points' in results: + if key in results['lidar_points']: + data_metas[key] = results['lidar_points'][key] + data_sample.set_metainfo(data_metas) + + inputs = {} + for key in self.keys: + if key in results: + # if key in self.INPUTS_KEYS: + if key in self.NERF_INPUT_KEYS: + inputs[key] = results[key] + elif key in self.NERF_3D_KEYS: + if key == 'gt_images': + gt_nerf_images[self._remove_prefix(key)] = results[key] + else: + gt_nerf_depths[self._remove_prefix(key)] = results[key] + elif key in self.INSTANCEDATA_3D_KEYS: + gt_instances_3d[self._remove_prefix(key)] = results[key] + elif key in self.INSTANCEDATA_2D_KEYS: + if key == 'gt_bboxes_labels': + gt_instances['labels'] = results[key] + else: + gt_instances[self._remove_prefix(key)] = results[key] + elif key in self.SEG_KEYS: + gt_pts_seg[self._remove_prefix(key)] = results[key] + else: + raise NotImplementedError(f'Please modified ' + f'`Pack3DDetInputs` ' + f'to put {key} to ' + f'corresponding field') + + data_sample.gt_instances_3d = gt_instances_3d + data_sample.gt_instances = gt_instances + data_sample.gt_pts_seg = gt_pts_seg + data_sample.gt_nerf_images = gt_nerf_images + data_sample.gt_nerf_depths = gt_nerf_depths + if 'eval_ann_info' in results: + data_sample.eval_ann_info = results['eval_ann_info'] + else: + data_sample.eval_ann_info = None + + packed_results = dict() + packed_results['data_samples'] = data_sample + packed_results['inputs'] = inputs + + return packed_results + + def __repr__(self) -> str: + """str: Return a string that describes the module.""" + repr_str = self.__class__.__name__ + repr_str += f'(keys={self.keys})' + repr_str += f'(meta_keys={self.meta_keys})' + return repr_str diff --git a/projects/NeRF-Det/nerfdet/multiview_pipeline.py b/projects/NeRF-Det/nerfdet/multiview_pipeline.py new file mode 100644 index 0000000000..23e84ed71f --- /dev/null +++ b/projects/NeRF-Det/nerfdet/multiview_pipeline.py @@ -0,0 +1,297 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +import numpy as np +from mmcv.transforms import BaseTransform, Compose +from PIL import Image + +from mmdet3d.registry import TRANSFORMS + + +def get_dtu_raydir(pixelcoords, intrinsic, rot, dir_norm=None): + # rot is c2w + # pixelcoords: H x W x 2 + x = (pixelcoords[..., 0] + 0.5 - intrinsic[0, 2]) / intrinsic[0, 0] + y = (pixelcoords[..., 1] + 0.5 - intrinsic[1, 2]) / intrinsic[1, 1] + z = np.ones_like(x) + dirs = np.stack([x, y, z], axis=-1) + # dirs = np.sum(dirs[...,None,:] * rot[:,:], axis=-1) # h*w*1*3 x 3*3 + dirs = dirs @ rot[:, :].T # + if dir_norm: + dirs = dirs / (np.linalg.norm(dirs, axis=-1, keepdims=True) + 1e-5) + + return dirs + + +@TRANSFORMS.register_module() +class MultiViewPipeline(BaseTransform): + """MultiViewPipeline used in nerfdet. + + Required Keys: + + - depth_info + - img_prefix + - img_info + - lidar2img + - c2w + - cammrotc2w + - lightpos + - ray_info + + Modified Keys: + + - lidar2img + + Added Keys: + + - img + - denorm_images + - depth + - c2w + - camrotc2w + - lightpos + - pixels + - raydirs + - gt_images + - gt_depths + - nerf_sizes + - depth_range + + Args: + transforms (list[dict]): The transform pipeline + used to process the imgs. + n_images (int): The number of sampled views. + mean (array): The mean values used in normalization. + std (array): The variance values used in normalization. + margin (int): The margin value. Defaults to 10. + depth_range (array): The range of the depth. + Defaults to [0.5, 5.5]. + loading (str): The mode of loading. Defaults to 'random'. + nerf_target_views (int): The number of novel views. + sample_freq (int): The frequency of sampling. + """ + + def __init__(self, + transforms: dict, + n_images: int, + mean: tuple = [123.675, 116.28, 103.53], + std: tuple = [58.395, 57.12, 57.375], + margin: int = 10, + depth_range: tuple = [0.5, 5.5], + loading: str = 'random', + nerf_target_views: int = 0, + sample_freq: int = 3): + self.transforms = Compose(transforms) + self.depth_transforms = Compose(transforms[1]) + self.n_images = n_images + self.mean = np.array(mean, dtype=np.float32) + self.std = np.array(std, dtype=np.float32) + self.margin = margin + self.depth_range = depth_range + self.loading = loading + self.sample_freq = sample_freq + self.nerf_target_views = nerf_target_views + + def transform(self, results: dict) -> dict: + """Nerfdet transform function. + + Args: + results (dict): Result dict from loading pipeline + + Returns: + dict: The result dict containing the processed results. + Updated key and value are described below. + + - img (list): The loaded origin image. + - denorm_images (list): The denormalized image. + - depth (list): The origin depth image. + - c2w (list): The c2w matrixes. + - camrotc2w (list): The rotation matrixes. + - lightpos (list): The transform parameters of the camera. + - pixels (list): Some pixel information. + - raydirs (list): The ray-directions. + - gt_images (list): The groundtruth images. + - gt_depths (list): The groundtruth depth images. + - nerf_sizes (array): The size of the groundtruth images. + - depth_range (array): The range of the depth. + + Here we give a detailed explanation of some keys mentioned above. + Let P_c be the coordinate of camera, P_w be the coordinate of world. + There is such a conversion relationship: P_c = R @ P_w + T. + The 'camrotc2w' mentioned above corresponds to the R matrix here. + The 'lightpos' corresponds to the T matrix here. And if you put + R and T together, you can get the camera extrinsics matrix. It + corresponds to the 'c2w' mentioned above. + """ + imgs = [] + depths = [] + extrinsics = [] + c2ws = [] + camrotc2ws = [] + lightposes = [] + pixels = [] + raydirs = [] + gt_images = [] + gt_depths = [] + denorm_imgs_list = [] + nerf_sizes = [] + + if self.loading == 'random': + ids = np.arange(len(results['img_info'])) + replace = True if self.n_images > len(ids) else False + ids = np.random.choice(ids, self.n_images, replace=replace) + if self.nerf_target_views != 0: + target_id = np.random.choice( + ids, self.nerf_target_views, replace=False) + ids = np.setdiff1d(ids, target_id) + ids = ids.tolist() + target_id = target_id.tolist() + + else: + ids = np.arange(len(results['img_info'])) + begin_id = 0 + ids = np.arange(begin_id, + begin_id + self.n_images * self.sample_freq, + self.sample_freq) + if self.nerf_target_views != 0: + target_id = ids + + ratio = 0 + size = (240, 320) + for i in ids: + _results = dict() + _results['img_path'] = results['img_info'][i]['filename'] + _results = self.transforms(_results) + imgs.append(_results['img']) + # normalize + for key in _results.get('img_fields', ['img']): + _results[key] = mmcv.imnormalize(_results[key], self.mean, + self.std, True) + _results['img_norm_cfg'] = dict( + mean=self.mean, std=self.std, to_rgb=True) + # pad + for key in _results.get('img_fields', ['img']): + padded_img = mmcv.impad(_results[key], shape=size, pad_val=0) + _results[key] = padded_img + _results['pad_shape'] = padded_img.shape + _results['pad_fixed_size'] = size + ori_shape = _results['ori_shape'] + aft_shape = _results['img_shape'] + ratio = ori_shape[0] / aft_shape[0] + # prepare the depth information + if 'depth_info' in results.keys(): + if '.npy' in results['depth_info'][i]['filename']: + _results['depth'] = np.load( + results['depth_info'][i]['filename']) + else: + _results['depth'] = np.asarray((Image.open( + results['depth_info'][i]['filename']))) / 1000 + _results['depth'] = mmcv.imresize( + _results['depth'], (aft_shape[1], aft_shape[0])) + depths.append(_results['depth']) + + denorm_img = mmcv.imdenormalize( + _results['img'], self.mean, self.std, to_bgr=True).astype( + np.uint8) / 255.0 + denorm_imgs_list.append(denorm_img) + height, width = padded_img.shape[:2] + extrinsics.append(results['lidar2img']['extrinsic'][i]) + + # prepare the nerf information + if 'ray_info' in results.keys(): + intrinsics_nerf = results['lidar2img']['intrinsic'].copy() + intrinsics_nerf[:2] = intrinsics_nerf[:2] / ratio + assert self.nerf_target_views > 0 + for i in target_id: + c2ws.append(results['c2w'][i]) + camrotc2ws.append(results['camrotc2w'][i]) + lightposes.append(results['lightpos'][i]) + px, py = np.meshgrid( + np.arange(self.margin, + width - self.margin).astype(np.float32), + np.arange(self.margin, + height - self.margin).astype(np.float32)) + pixelcoords = np.stack((px, py), + axis=-1).astype(np.float32) # H x W x 2 + pixels.append(pixelcoords) + raydir = get_dtu_raydir(pixelcoords, intrinsics_nerf, + results['camrotc2w'][i]) + raydirs.append(np.reshape(raydir.astype(np.float32), (-1, 3))) + # read target images + temp_results = dict() + temp_results['img_path'] = results['img_info'][i]['filename'] + + temp_results_ = self.transforms(temp_results) + # normalize + for key in temp_results.get('img_fields', ['img']): + temp_results[key] = mmcv.imnormalize( + temp_results[key], self.mean, self.std, True) + temp_results['img_norm_cfg'] = dict( + mean=self.mean, std=self.std, to_rgb=True) + # pad + for key in temp_results.get('img_fields', ['img']): + padded_img = mmcv.impad( + temp_results[key], shape=size, pad_val=0) + temp_results[key] = padded_img + temp_results['pad_shape'] = padded_img.shape + temp_results['pad_fixed_size'] = size + # denormalize target_images. + denorm_imgs = mmcv.imdenormalize( + temp_results_['img'], self.mean, self.std, + to_bgr=True).astype(np.uint8) + gt_rgb_shape = denorm_imgs.shape + + gt_image = denorm_imgs[py.astype(np.int32), + px.astype(np.int32), :] + nerf_sizes.append(np.array(gt_image.shape)) + gt_image = np.reshape(gt_image, (-1, 3)) + gt_images.append(gt_image / 255.0) + if 'depth_info' in results.keys(): + if '.npy' in results['depth_info'][i]['filename']: + _results['depth'] = np.load( + results['depth_info'][i]['filename']) + else: + depth_image = Image.open( + results['depth_info'][i]['filename']) + _results['depth'] = np.asarray(depth_image) / 1000 + _results['depth'] = mmcv.imresize( + _results['depth'], + (gt_rgb_shape[1], gt_rgb_shape[0])) + + _results['depth'] = _results['depth'] + gt_depth = _results['depth'][py.astype(np.int32), + px.astype(np.int32)] + gt_depths.append(gt_depth) + + for key in _results.keys(): + if key not in ['img', 'img_info']: + results[key] = _results[key] + results['img'] = imgs + + if 'ray_info' in results.keys(): + results['c2w'] = c2ws + results['camrotc2w'] = camrotc2ws + results['lightpos'] = lightposes + results['pixels'] = pixels + results['raydirs'] = raydirs + results['gt_images'] = gt_images + results['gt_depths'] = gt_depths + results['nerf_sizes'] = nerf_sizes + results['denorm_images'] = denorm_imgs_list + results['depth_range'] = np.array([self.depth_range]) + + if len(depths) != 0: + results['depth'] = depths + results['lidar2img']['extrinsic'] = extrinsics + return results + + +@TRANSFORMS.register_module() +class RandomShiftOrigin(BaseTransform): + + def __init__(self, std): + self.std = std + + def transform(self, results): + shift = np.random.normal(.0, self.std, 3) + results['lidar2img']['origin'] += shift + return results diff --git a/projects/NeRF-Det/nerfdet/nerf_det3d_data_sample.py b/projects/NeRF-Det/nerfdet/nerf_det3d_data_sample.py new file mode 100644 index 0000000000..439e9a69ba --- /dev/null +++ b/projects/NeRF-Det/nerfdet/nerf_det3d_data_sample.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple, Union + +import torch +from mmengine.structures import InstanceData + +from mmdet3d.structures import Det3DDataSample + + +class NeRFDet3DDataSample(Det3DDataSample): + """A data structure interface inheirted from Det3DDataSample. Some new + attributes are added to match the NeRF-Det project. + + The attributes added in ``NeRFDet3DDataSample`` are divided into two parts: + + - ``gt_nerf_images`` (InstanceData): Ground truth of the images which + will be used in the NeRF branch. + - ``gt_nerf_depths`` (InstanceData): Ground truth of the depth images + which will be used in the NeRF branch if needed. + + For more details and examples, please refer to the 'Det3DDataSample' file. + """ + + @property + def gt_nerf_images(self) -> InstanceData: + return self._gt_nerf_images + + @gt_nerf_images.setter + def gt_nerf_images(self, value: InstanceData) -> None: + self.set_field(value, '_gt_nerf_images', dtype=InstanceData) + + @gt_nerf_images.deleter + def gt_nerf_images(self) -> None: + del self._gt_nerf_images + + @property + def gt_nerf_depths(self) -> InstanceData: + return self._gt_nerf_depths + + @gt_nerf_depths.setter + def gt_nerf_depths(self, value: InstanceData) -> None: + self.set_field(value, '_gt_nerf_depths', dtype=InstanceData) + + @gt_nerf_depths.deleter + def gt_nerf_depths(self) -> None: + del self._gt_nerf_depths + + +SampleList = List[NeRFDet3DDataSample] +OptSampleList = Optional[SampleList] +ForwardResults = Union[Dict[str, torch.Tensor], List[NeRFDet3DDataSample], + Tuple[torch.Tensor], torch.Tensor] diff --git a/projects/NeRF-Det/nerfdet/nerf_utils/nerf_mlp.py b/projects/NeRF-Det/nerfdet/nerf_utils/nerf_mlp.py new file mode 100644 index 0000000000..cc579ea23b --- /dev/null +++ b/projects/NeRF-Det/nerfdet/nerf_utils/nerf_mlp.py @@ -0,0 +1,277 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Callable, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MLP(nn.Module): + """The MLP module used in NerfDet. + + Args: + input_dim (int): The number of input tensor channels. + output_dim (int): The number of output tensor channels. + net_depth (int): The depth of the MLP. Defaults to 8. + net_width (int): The width of the MLP. Defaults to 256. + skip_layer (int): The layer to add skip layers to. Defaults to 4. + + hidden_init (Callable): The initialize method of the hidden layers. + hidden_activation (Callable): The activation function of hidden + layers, defaults to ReLU. + output_enabled (bool): If true, the output layers will be used. + Defaults to True. + output_init (Optional): The initialize method of the output layer. + output_activation(Optional): The activation function of output layers. + bias_enabled (Bool): If true, the bias will be used. + bias_init (Callable): The initialize method of the bias. + Defaults to True. + """ + + def __init__( + self, + input_dim: int, + output_dim: int = None, + net_depth: int = 8, + net_width: int = 256, + skip_layer: int = 4, + hidden_init: Callable = nn.init.xavier_uniform_, + hidden_activation: Callable = nn.ReLU(), + output_enabled: bool = True, + output_init: Optional[Callable] = nn.init.xavier_uniform_, + output_activation: Optional[Callable] = nn.Identity(), + bias_enabled: bool = True, + bias_init: Callable = nn.init.zeros_, + ): + super().__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.net_depth = net_depth + self.net_width = net_width + self.skip_layer = skip_layer + self.hidden_init = hidden_init + self.hidden_activation = hidden_activation + self.output_enabled = output_enabled + self.output_init = output_init + self.output_activation = output_activation + self.bias_enabled = bias_enabled + self.bias_init = bias_init + + self.hidden_layers = nn.ModuleList() + in_features = self.input_dim + for i in range(self.net_depth): + self.hidden_layers.append( + nn.Linear(in_features, self.net_width, bias=bias_enabled)) + if (self.skip_layer is not None) and (i % self.skip_layer + == 0) and (i > 0): + in_features = self.net_width + self.input_dim + else: + in_features = self.net_width + if self.output_enabled: + self.output_layer = nn.Linear( + in_features, self.output_dim, bias=bias_enabled) + else: + self.output_dim = in_features + + self.initialize() + + def initialize(self): + + def init_func_hidden(m): + if isinstance(m, nn.Linear): + if self.hidden_init is not None: + self.hidden_init(m.weight) + if self.bias_enabled and self.bias_init is not None: + self.bias_init(m.bias) + + self.hidden_layers.apply(init_func_hidden) + if self.output_enabled: + + def init_func_output(m): + if isinstance(m, nn.Linear): + if self.output_init is not None: + self.output_init(m.weight) + if self.bias_enabled and self.bias_init is not None: + self.bias_init(m.bias) + + self.output_layer.apply(init_func_output) + + def forward(self, x): + inputs = x + for i in range(self.net_depth): + x = self.hidden_layers[i](x) + x = self.hidden_activation(x) + if (self.skip_layer is not None) and (i % self.skip_layer + == 0) and (i > 0): + x = torch.cat([x, inputs], dim=-1) + if self.output_enabled: + x = self.output_layer(x) + x = self.output_activation(x) + return x + + +class DenseLayer(MLP): + + def __init__(self, input_dim, output_dim, **kwargs): + super().__init__( + input_dim=input_dim, + output_dim=output_dim, + net_depth=0, # no hidden layers + **kwargs, + ) + + +class NerfMLP(nn.Module): + """The Nerf-MLP Module. + + Args: + input_dim (int): The number of input tensor channels. + condition_dim (int): The number of condition tensor channels. + feature_dim (int): The number of feature channels. Defaults to 0. + net_depth (int): The depth of the MLP. Defaults to 8. + net_width (int): The width of the MLP. Defaults to 256. + skip_layer (int): The layer to add skip layers to. Defaults to 4. + net_depth_condition (int): The depth of the second part of MLP. + Defaults to 1. + net_width_condition (int): The width of the second part of MLP. + Defaults to 128. + """ + + def __init__( + self, + input_dim: int, + condition_dim: int, + feature_dim: int = 0, + net_depth: int = 8, + net_width: int = 256, + skip_layer: int = 4, + net_depth_condition: int = 1, + net_width_condition: int = 128, + ): + super().__init__() + self.base = MLP( + input_dim=input_dim + feature_dim, + net_depth=net_depth, + net_width=net_width, + skip_layer=skip_layer, + output_enabled=False, + ) + hidden_features = self.base.output_dim + self.sigma_layer = DenseLayer(hidden_features, 1) + + if condition_dim > 0: + self.bottleneck_layer = DenseLayer(hidden_features, net_width) + self.rgb_layer = MLP( + input_dim=net_width + condition_dim, + output_dim=3, + net_depth=net_depth_condition, + net_width=net_width_condition, + skip_layer=None, + ) + else: + self.rgb_layer = DenseLayer(hidden_features, 3) + + def query_density(self, x, features=None): + """Calculate the raw sigma.""" + if features is not None: + x = self.base(torch.cat([x, features], dim=-1)) + else: + x = self.base(x) + raw_sigma = self.sigma_layer(x) + return raw_sigma + + def forward(self, x, condition=None, features=None): + if features is not None: + x = self.base(torch.cat([x, features], dim=-1)) + else: + x = self.base(x) + raw_sigma = self.sigma_layer(x) + if condition is not None: + if condition.shape[:-1] != x.shape[:-1]: + num_rays, n_dim = condition.shape + condition = condition.view( + [num_rays] + [1] * (x.dim() - condition.dim()) + + [n_dim]).expand(list(x.shape[:-1]) + [n_dim]) + bottleneck = self.bottleneck_layer(x) + x = torch.cat([bottleneck, condition], dim=-1) + raw_rgb = self.rgb_layer(x) + return raw_rgb, raw_sigma + + +class SinusoidalEncoder(nn.Module): + """Sinusodial Positional Encoder used in NeRF.""" + + def __init__(self, x_dim, min_deg, max_deg, use_identity: bool = True): + super().__init__() + self.x_dim = x_dim + self.min_deg = min_deg + self.max_deg = max_deg + self.use_identity = use_identity + self.register_buffer( + 'scales', torch.tensor([2**i for i in range(min_deg, max_deg)])) + + @property + def latent_dim(self) -> int: + return (int(self.use_identity) + + (self.max_deg - self.min_deg) * 2) * self.x_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.max_deg == self.min_deg: + return x + xb = torch.reshape( + (x[Ellipsis, None, :] * self.scales[:, None]), + list(x.shape[:-1]) + [(self.max_deg - self.min_deg) * self.x_dim], + ) + latent = torch.sin(torch.cat([xb, xb + 0.5 * math.pi], dim=-1)) + if self.use_identity: + latent = torch.cat([x] + [latent], dim=-1) + return latent + + +class VanillaNeRF(nn.Module): + """The Nerf-MLP with the positional encoder. + + Args: + net_depth (int): The depth of the MLP. Defaults to 8. + net_width (int): The width of the MLP. Defaults to 256. + skip_layer (int): The layer to add skip layers to. Defaults to 4. + feature_dim (int): The number of feature channels. Defaults to 0. + net_depth_condition (int): The depth of the second part of MLP. + Defaults to 1. + net_width_condition (int): The width of the second part of MLP. + Defaults to 128. + """ + + def __init__(self, + net_depth: int = 8, + net_width: int = 256, + skip_layer: int = 4, + feature_dim: int = 0, + net_depth_condition: int = 1, + net_width_condition: int = 128): + super().__init__() + self.posi_encoder = SinusoidalEncoder(3, 0, 10, True) + self.view_encoder = SinusoidalEncoder(3, 0, 4, True) + self.mlp = NerfMLP( + input_dim=self.posi_encoder.latent_dim, + condition_dim=self.view_encoder.latent_dim, + feature_dim=feature_dim, + net_depth=net_depth, + net_width=net_width, + skip_layer=skip_layer, + net_depth_condition=net_depth_condition, + net_width_condition=net_width_condition, + ) + + def query_density(self, x, features=None): + x = self.posi_encoder(x) + sigma = self.mlp.query_density(x, features) + return F.relu(sigma) + + def forward(self, x, condition=None, features=None): + x = self.posi_encoder(x) + if condition is not None: + condition = self.view_encoder(condition) + rgb, sigma = self.mlp(x, condition=condition, features=features) + return torch.sigmoid(rgb), F.relu(sigma) diff --git a/projects/NeRF-Det/nerfdet/nerf_utils/projection.py b/projects/NeRF-Det/nerfdet/nerf_utils/projection.py new file mode 100644 index 0000000000..d88e281420 --- /dev/null +++ b/projects/NeRF-Det/nerfdet/nerf_utils/projection.py @@ -0,0 +1,140 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Attention: This file is mainly modified based on the file with the same +# name in the original project. For more details, please refer to the +# origin project. +import torch +import torch.nn.functional as F + + +class Projector(): + + def __init__(self, device='cuda'): + self.device = device + + def inbound(self, pixel_locations, h, w): + """check if the pixel locations are in valid range.""" + return (pixel_locations[..., 0] <= w - 1.) & \ + (pixel_locations[..., 0] >= 0) & \ + (pixel_locations[..., 1] <= h - 1.) &\ + (pixel_locations[..., 1] >= 0) + + def normalize(self, pixel_locations, h, w): + resize_factor = torch.tensor([w - 1., h - 1. + ]).to(pixel_locations.device)[None, + None, :] + normalized_pixel_locations = 2 * pixel_locations / resize_factor - 1. + return normalized_pixel_locations + + def compute_projections(self, xyz, train_cameras): + """project 3D points into cameras.""" + + original_shape = xyz.shape[:2] + xyz = xyz.reshape(-1, 3) + num_views = len(train_cameras) + train_intrinsics = train_cameras[:, 2:18].reshape(-1, 4, 4) + train_poses = train_cameras[:, -16:].reshape(-1, 4, 4) + xyz_h = torch.cat([xyz, torch.ones_like(xyz[..., :1])], dim=-1) + # projections = train_intrinsics.bmm(torch.inverse(train_poses)) + # we have inverse the pose in dataloader so + # do not need to inverse here. + projections = train_intrinsics.bmm(train_poses) \ + .bmm(xyz_h.t()[None, ...].repeat(num_views, 1, 1)) + projections = projections.permute(0, 2, 1) + pixel_locations = projections[..., :2] / torch.clamp( + projections[..., 2:3], min=1e-8) + pixel_locations = torch.clamp(pixel_locations, min=-1e6, max=1e6) + mask = projections[..., 2] > 0 + return pixel_locations.reshape((num_views, ) + original_shape + (2, )), \ + mask.reshape((num_views, ) + original_shape) # noqa + + def compute_angle(self, xyz, query_camera, train_cameras): + + original_shape = xyz.shape[:2] + xyz = xyz.reshape(-1, 3) + train_poses = train_cameras[:, -16:].reshape(-1, 4, 4) + num_views = len(train_poses) + query_pose = query_camera[-16:].reshape(-1, 4, + 4).repeat(num_views, 1, 1) + ray2tar_pose = (query_pose[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0)) + ray2tar_pose /= (torch.norm(ray2tar_pose, dim=-1, keepdim=True) + 1e-6) + ray2train_pose = ( + train_poses[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0)) + ray2train_pose /= ( + torch.norm(ray2train_pose, dim=-1, keepdim=True) + 1e-6) + ray_diff = ray2tar_pose - ray2train_pose + ray_diff_norm = torch.norm(ray_diff, dim=-1, keepdim=True) + ray_diff_dot = torch.sum( + ray2tar_pose * ray2train_pose, dim=-1, keepdim=True) + ray_diff_direction = ray_diff / torch.clamp(ray_diff_norm, min=1e-6) + ray_diff = torch.cat([ray_diff_direction, ray_diff_dot], dim=-1) + ray_diff = ray_diff.reshape((num_views, ) + original_shape + (4, )) + return ray_diff + + def compute(self, + xyz, + train_imgs, + train_cameras, + featmaps=None, + grid_sample=True): + + assert (train_imgs.shape[0] == 1) \ + and (train_cameras.shape[0] == 1) + # only support batch_size=1 for now + + train_imgs = train_imgs.squeeze(0) + train_cameras = train_cameras.squeeze(0) + + train_imgs = train_imgs.permute(0, 3, 1, 2) + h, w = train_cameras[0][:2] + + # compute the projection of the query points to each reference image + pixel_locations, mask_in_front = self.compute_projections( + xyz, train_cameras) + normalized_pixel_locations = self.normalize(pixel_locations, h, w) + # rgb sampling + rgbs_sampled = F.grid_sample( + train_imgs, normalized_pixel_locations, align_corners=True) + rgb_sampled = rgbs_sampled.permute(2, 3, 0, 1) + + # deep feature sampling + if featmaps is not None: + if grid_sample: + feat_sampled = F.grid_sample( + featmaps, normalized_pixel_locations, align_corners=True) + feat_sampled = feat_sampled.permute( + 2, 3, 0, 1) # [n_rays, n_samples, n_views, d] + rgb_feat_sampled = torch.cat( + [rgb_sampled, feat_sampled], + dim=-1) # [n_rays, n_samples, n_views, d+3] + # rgb_feat_sampled = feat_sampled + else: + n_images, n_channels, f_h, f_w = featmaps.shape + resize_factor = torch.tensor([f_w / w - 1., f_h / h - 1.]).to( + pixel_locations.device)[None, None, :] + sample_location = (pixel_locations * + resize_factor).round().long() + n_images, n_ray, n_sample, _ = sample_location.shape + sample_x = sample_location[..., 0].view(n_images, -1) + sample_y = sample_location[..., 1].view(n_images, -1) + valid = (sample_x >= 0) & (sample_y >= + 0) & (sample_x < f_w) & ( + sample_y < f_h) + valid = valid * mask_in_front.view(n_images, -1) + feat_sampled = torch.zeros( + (n_images, n_channels, sample_x.shape[-1]), + device=featmaps.device) + for i in range(n_images): + feat_sampled[i, :, + valid[i]] = featmaps[i, :, sample_y[i, + valid[i]], + sample_y[i, valid[i]]] + feat_sampled = feat_sampled.view(n_images, n_channels, n_ray, + n_sample) + rgb_feat_sampled = feat_sampled.permute(2, 3, 0, 1) + + else: + rgb_feat_sampled = None + inbound = self.inbound(pixel_locations, h, w) + mask = (inbound * mask_in_front).float().permute( + 1, 2, 0)[..., None] # [n_rays, n_samples, n_views, 1] + return rgb_feat_sampled, mask diff --git a/projects/NeRF-Det/nerfdet/nerf_utils/render_ray.py b/projects/NeRF-Det/nerfdet/nerf_utils/render_ray.py new file mode 100644 index 0000000000..76582c5773 --- /dev/null +++ b/projects/NeRF-Det/nerfdet/nerf_utils/render_ray.py @@ -0,0 +1,431 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Attention: This file is mainly modified based on the file with the same +# name in the original project. For more details, please refer to the +# origin project. +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn.functional as F + +rng = np.random.RandomState(234) + + +# helper functions for nerf ray rendering +def volume_sampling(sample_pts, features, aabb): + B, C, D, W, H = features.shape + assert B == 1 + aabb = torch.Tensor(aabb).to(sample_pts.device) + N_rays, N_samples, coords = sample_pts.shape + sample_pts = sample_pts.view(1, N_rays * N_samples, 1, 1, + 3).repeat(B, 1, 1, 1, 1) + aabbSize = aabb[1] - aabb[0] + invgridSize = 1.0 / aabbSize * 2 + norm_pts = (sample_pts - aabb[0]) * invgridSize - 1 + sample_features = F.grid_sample( + features, norm_pts, align_corners=True, padding_mode='border') + masks = ((norm_pts < 1) & (norm_pts > -1)).float().sum(dim=-1) + masks = (masks.view(N_rays, N_samples) == 3) + return sample_features.view(C, N_rays, + N_samples).permute(1, 2, 0).contiguous(), masks + + +def _compute_projection(img_meta): + views = len(img_meta['lidar2img']['extrinsic']) + intrinsic = torch.tensor(img_meta['lidar2img']['intrinsic'][:4, :4]) + ratio = img_meta['ori_shape'][0] / img_meta['img_shape'][0] + intrinsic[:2] /= ratio + intrinsic = intrinsic.unsqueeze(0).view(1, 16).repeat(views, 1) + img_size = torch.Tensor(img_meta['img_shape'][:2]).to(intrinsic.device) + img_size = img_size.unsqueeze(0).repeat(views, 1) + extrinsics = [] + for v in range(views): + extrinsics.append( + torch.Tensor(img_meta['lidar2img']['extrinsic'][v]).to( + intrinsic.device)) + extrinsic = torch.stack(extrinsics).view(views, 16) + train_cameras = torch.cat([img_size, intrinsic, extrinsic], dim=-1) + return train_cameras.unsqueeze(0) + + +def compute_mask_points(feature, mask): + weight = mask / (torch.sum(mask, dim=2, keepdim=True) + 1e-8) + mean = torch.sum(feature * weight, dim=2, keepdim=True) + var = torch.sum((feature - mean)**2, dim=2, keepdim=True) + var = var / (torch.sum(mask, dim=2, keepdim=True) + 1e-8) + var = torch.exp(-var) + return mean, var + + +def sample_pdf(bins, weights, N_samples, det=False): + """Helper function used for sampling. + + Args: + bins (tensor):Tensor of shape [N_rays, M+1], M is the number of bins + weights (tensor):Tensor of shape [N_rays, M+1], M is the number of bins + N_samples (int):Number of samples along each ray + det (bool):If True, will perform deterministic sampling + + Returns: + samples (tuple): [N_rays, N_samples] + """ + + M = weights.shape[1] + weights += 1e-5 + # Get pdf + pdf = weights / torch.sum(weights, dim=-1, keepdim=True) + cdf = torch.cumsum(pdf, dim=-1) + cdf = torch.cat([torch.zeros_like(cdf[:, 0:1]), cdf], dim=-1) + + # Take uniform samples + if det: + u = torch.linspace(0., 1., N_samples, device=bins.device) + u = u.unsqueeze(0).repeat(bins.shape[0], 1) + else: + u = torch.rand(bins.shape[0], N_samples, device=bins.device) + + # Invert CDF + above_inds = torch.zeros_like(u, dtype=torch.long) + for i in range(M): + above_inds += (u >= cdf[:, i:i + 1]).long() + + # random sample inside each bin + below_inds = torch.clamp(above_inds - 1, min=0) + inds_g = torch.stack((below_inds, above_inds), dim=2) + + cdf = cdf.unsqueeze(1).repeat(1, N_samples, 1) + cdf_g = torch.gather(input=cdf, dim=-1, index=inds_g) + + bins = bins.unsqueeze(1).repeat(1, N_samples, 1) + bins_g = torch.gather(input=bins, dim=-1, index=inds_g) + + denom = cdf_g[:, :, 1] - cdf_g[:, :, 0] + denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) + t = (u - cdf_g[:, :, 0]) / denom + + samples = bins_g[:, :, 0] + t * (bins_g[:, :, 1] - bins_g[:, :, 0]) + + return samples + + +def sample_along_camera_ray(ray_o, + ray_d, + depth_range, + N_samples, + inv_uniform=False, + det=False): + """Sampling along the camera ray. + + Args: + ray_o (tensor): Origin of the ray in scene coordinate system; + tensor of shape [N_rays, 3] + ray_d (tensor): Homogeneous ray direction vectors in + scene coordinate system; tensor of shape [N_rays, 3] + depth_range (tuple): [near_depth, far_depth] + inv_uniform (bool): If True,uniformly sampling inverse depth. + det (bool): If True, will perform deterministic sampling. + Returns: + pts (tensor): Tensor of shape [N_rays, N_samples, 3] + z_vals (tensor): Tensor of shape [N_rays, N_samples] + """ + # will sample inside [near_depth, far_depth] + # assume the nearest possible depth is at least (min_ratio * depth) + near_depth_value = depth_range[0] + far_depth_value = depth_range[1] + assert near_depth_value > 0 and far_depth_value > 0 \ + and far_depth_value > near_depth_value + + near_depth = near_depth_value * torch.ones_like(ray_d[..., 0]) + + far_depth = far_depth_value * torch.ones_like(ray_d[..., 0]) + + if inv_uniform: + start = 1. / near_depth + step = (1. / far_depth - start) / (N_samples - 1) + inv_z_vals = torch.stack([start + i * step for i in range(N_samples)], + dim=1) + z_vals = 1. / inv_z_vals + else: + start = near_depth + step = (far_depth - near_depth) / (N_samples - 1) + z_vals = torch.stack([start + i * step for i in range(N_samples)], + dim=1) + + if not det: + # get intervals between samples + mids = .5 * (z_vals[:, 1:] + z_vals[:, :-1]) + upper = torch.cat([mids, z_vals[:, -1:]], dim=-1) + lower = torch.cat([z_vals[:, 0:1], mids], dim=-1) + # uniform samples in those intervals + t_rand = torch.rand_like(z_vals) + z_vals = lower + (upper - lower) * t_rand + + ray_d = ray_d.unsqueeze(1).repeat(1, N_samples, 1) + ray_o = ray_o.unsqueeze(1).repeat(1, N_samples, 1) + pts = z_vals.unsqueeze(2) * ray_d + ray_o # [N_rays, N_samples, 3] + return pts, z_vals + + +# ray rendering of nerf +def raw2outputs(raw, z_vals, mask, white_bkgd=False): + """Transform raw data to outputs: + + Args: + raw(tensor):Raw network output.Tensor of shape [N_rays, N_samples, 4] + z_vals(tensor):Depth of point samples along rays. + Tensor of shape [N_rays, N_samples] + ray_d(tensor):[N_rays, 3] + + Returns: + ret(dict): + -rgb(tensor):[N_rays, 3] + -depth(tensor):[N_rays,] + -weights(tensor):[N_rays,] + -depth_std(tensor):[N_rays,] + """ + rgb = raw[:, :, :3] # [N_rays, N_samples, 3] + sigma = raw[:, :, 3] # [N_rays, N_samples] + + # note: we did not use the intervals here, + # because in practice different scenes from COLMAP can have + # very different scales, and using interval can affect + # the model's generalization ability. + # Therefore we don't use the intervals for both training and evaluation. + sigma2alpha = lambda sigma, dists: 1. - torch.exp(-sigma) # noqa + + # point samples are ordered with increasing depth + # interval between samples + dists = z_vals[:, 1:] - z_vals[:, :-1] + dists = torch.cat((dists, dists[:, -1:]), dim=-1) + + alpha = sigma2alpha(sigma, dists) + + T = torch.cumprod(1. - alpha + 1e-10, dim=-1)[:, :-1] + T = torch.cat((torch.ones_like(T[:, 0:1]), T), dim=-1) + + # maths show weights, and summation of weights along a ray, + # are always inside [0, 1] + weights = alpha * T + rgb_map = torch.sum(weights.unsqueeze(2) * rgb, dim=1) + + if white_bkgd: + rgb_map = rgb_map + (1. - torch.sum(weights, dim=-1, keepdim=True)) + + if mask is not None: + mask = mask.float().sum(dim=1) > 8 + + depth_map = torch.sum( + weights * z_vals, dim=-1) / ( + torch.sum(weights, dim=-1) + 1e-8) + depth_map = torch.clamp(depth_map, z_vals.min(), z_vals.max()) + + ret = OrderedDict([('rgb', rgb_map), ('depth', depth_map), + ('weights', weights), ('mask', mask), ('alpha', alpha), + ('z_vals', z_vals), ('transparency', T)]) + + return ret + + +def render_rays_func( + ray_o, + ray_d, + mean_volume, + cov_volume, + features_2D, + img, + aabb, + near_far_range, + N_samples, + N_rand=4096, + nerf_mlp=None, + img_meta=None, + projector=None, + mode='volume', # volume and image + nerf_sample_view=3, + inv_uniform=False, + N_importance=0, + det=False, + is_train=True, + white_bkgd=False, + gt_rgb=None, + gt_depth=None): + + ret = { + 'outputs_coarse': None, + 'outputs_fine': None, + 'gt_rgb': gt_rgb, + 'gt_depth': gt_depth + } + + # pts: [N_rays, N_samples, 3] + # z_vals: [N_rays, N_samples] + pts, z_vals = sample_along_camera_ray( + ray_o=ray_o, + ray_d=ray_d, + depth_range=near_far_range, + N_samples=N_samples, + inv_uniform=inv_uniform, + det=det) + N_rays, N_samples = pts.shape[:2] + + if mode == 'image': + img = img.permute(0, 2, 3, 1).unsqueeze(0) + train_camera = _compute_projection(img_meta).to(img.device) + rgb_feat, mask = projector.compute( + pts, img, train_camera, features_2D, grid_sample=True) + pixel_mask = mask[..., 0].sum(dim=2) > 1 + mean, var = compute_mask_points(rgb_feat, mask) + globalfeat = torch.cat([mean, var], dim=-1).squeeze(2) + rgb_pts, density_pts = nerf_mlp(pts, ray_d, globalfeat) + raw_coarse = torch.cat([rgb_pts, density_pts], dim=-1) + ret['sigma'] = density_pts + + elif mode == 'volume': + mean_pts, inbound_masks = volume_sampling(pts, mean_volume, aabb) + cov_pts, inbound_masks = volume_sampling(pts, cov_volume, aabb) + # This masks is for indicating which points outside of aabb + img = img.permute(0, 2, 3, 1).unsqueeze(0) + train_camera = _compute_projection(img_meta).to(img.device) + _, view_mask = projector.compute(pts, img, train_camera, None) + pixel_mask = view_mask[..., 0].sum(dim=2) > 1 + # plot_3D_vis(pts, aabb, img, train_camera) + # [N_rays, N_samples], should at least have 2 observations + # This mask is for indicating which points do not have projected point + globalpts = torch.cat([mean_pts, cov_pts], dim=-1) + rgb_pts, density_pts = nerf_mlp(pts, ray_d, globalpts) + density_pts = density_pts * inbound_masks.unsqueeze(dim=-1) + + raw_coarse = torch.cat([rgb_pts, density_pts], dim=-1) + + outputs_coarse = raw2outputs( + raw_coarse, z_vals, pixel_mask, white_bkgd=white_bkgd) + ret['outputs_coarse'] = outputs_coarse + + return ret + + +def render_rays( + ray_batch, + mean_volume, + cov_volume, + features_2D, + img, + aabb, + near_far_range, + N_samples, + N_rand=4096, + nerf_mlp=None, + img_meta=None, + projector=None, + mode='volume', # volume and image + nerf_sample_view=3, + inv_uniform=False, + N_importance=0, + det=False, + is_train=True, + white_bkgd=False, + render_testing=False): + """The function of the nerf rendering.""" + + ray_o = ray_batch['ray_o'] + ray_d = ray_batch['ray_d'] + gt_rgb = ray_batch['gt_rgb'] + gt_depth = ray_batch['gt_depth'] + nerf_sizes = ray_batch['nerf_sizes'] + if is_train: + ray_o = ray_o.view(-1, 3) + ray_d = ray_d.view(-1, 3) + gt_rgb = gt_rgb.view(-1, 3) + if gt_depth.shape[1] != 0: + gt_depth = gt_depth.view(-1, 1) + non_zero_depth = (gt_depth > 0).squeeze(-1) + ray_o = ray_o[non_zero_depth] + ray_d = ray_d[non_zero_depth] + gt_rgb = gt_rgb[non_zero_depth] + gt_depth = gt_depth[non_zero_depth] + else: + gt_depth = None + total_rays = ray_d.shape[0] + select_inds = rng.choice(total_rays, size=(N_rand, ), replace=False) + ray_o = ray_o[select_inds] + ray_d = ray_d[select_inds] + gt_rgb = gt_rgb[select_inds] + if gt_depth is not None: + gt_depth = gt_depth[select_inds] + + rets = render_rays_func( + ray_o, + ray_d, + mean_volume, + cov_volume, + features_2D, + img, + aabb, + near_far_range, + N_samples, + N_rand, + nerf_mlp, + img_meta, + projector, + mode, # volume and image + nerf_sample_view, + inv_uniform, + N_importance, + det, + is_train, + white_bkgd, + gt_rgb, + gt_depth) + + elif render_testing: + nerf_size = nerf_sizes[0] + view_num = ray_o.shape[1] + H = nerf_size[0][0] + W = nerf_size[0][1] + ray_o = ray_o.view(-1, 3) + ray_d = ray_d.view(-1, 3) + gt_rgb = gt_rgb.view(-1, 3) + print(gt_rgb.shape) + if len(gt_depth) != 0: + gt_depth = gt_depth.view(-1, 1) + else: + gt_depth = None + assert view_num * H * W == ray_o.shape[0] + num_rays = ray_o.shape[0] + results = [] + rgbs = [] + for i in range(0, num_rays, N_rand): + ray_o_chunck = ray_o[i:i + N_rand, :] + ray_d_chunck = ray_d[i:i + N_rand, :] + + ret = render_rays_func(ray_o_chunck, ray_d_chunck, mean_volume, + cov_volume, features_2D, img, aabb, + near_far_range, N_samples, N_rand, nerf_mlp, + img_meta, projector, mode, nerf_sample_view, + inv_uniform, N_importance, True, is_train, + white_bkgd, gt_rgb, gt_depth) + results.append(ret) + + rgbs = [] + depths = [] + + if results[0]['outputs_coarse'] is not None: + for i in range(len(results)): + rgb = results[i]['outputs_coarse']['rgb'] + rgbs.append(rgb) + depth = results[i]['outputs_coarse']['depth'] + depths.append(depth) + + rets = { + 'outputs_coarse': { + 'rgb': torch.cat(rgbs, dim=0).view(view_num, H, W, 3), + 'depth': torch.cat(depths, dim=0).view(view_num, H, W, 1), + }, + 'gt_rgb': + gt_rgb.view(view_num, H, W, 3), + 'gt_depth': + gt_depth.view(view_num, H, W, 1) if gt_depth is not None else None, + } + else: + rets = None + return rets diff --git a/projects/NeRF-Det/nerfdet/nerf_utils/save_rendered_img.py b/projects/NeRF-Det/nerfdet/nerf_utils/save_rendered_img.py new file mode 100644 index 0000000000..f9de3e3107 --- /dev/null +++ b/projects/NeRF-Det/nerfdet/nerf_utils/save_rendered_img.py @@ -0,0 +1,79 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os + +import cv2 +import numpy as np +import torch +from skimage.metrics import structural_similarity + + +def compute_psnr_from_mse(mse): + return -10.0 * torch.log(mse) / np.log(10.0) + + +def compute_psnr(pred, target, mask=None): + """Compute psnr value (we assume the maximum pixel value is 1).""" + if mask is not None: + pred, target = pred[mask], target[mask] + mse = ((pred - target)**2).mean() + return compute_psnr_from_mse(mse).cpu().numpy() + + +def compute_ssim(pred, target, mask=None): + """Computes Masked SSIM following the neuralbody paper.""" + assert pred.shape == target.shape and pred.shape[-1] == 3 + if mask is not None: + x, y, w, h = cv2.boundingRect(mask.cpu().numpy().astype(np.uint8)) + pred = pred[y:y + h, x:x + w] + target = target[y:y + h, x:x + w] + try: + ssim = structural_similarity( + pred.cpu().numpy(), target.cpu().numpy(), channel_axis=-1) + except ValueError: + ssim = structural_similarity( + pred.cpu().numpy(), target.cpu().numpy(), multichannel=True) + return ssim + + +def save_rendered_img(img_meta, rendered_results): + filename = img_meta[0]['filename'] + scenes = filename.split('/')[-2] + + for ret in rendered_results: + depth = ret['outputs_coarse']['depth'] + rgb = ret['outputs_coarse']['rgb'] + gt = ret['gt_rgb'] + gt_depth = ret['gt_depth'] + + # save images + psnr_total = 0 + ssim_total = 0 + rsme = 0 + for v in range(gt.shape[0]): + rsme += ((depth[v] - gt_depth[v])**2).cpu().numpy() + depth_ = ((depth[v] - depth[v].min()) / + (depth[v].max() - depth[v].min() + 1e-8)).repeat(1, 1, 3) + img_to_save = torch.cat([rgb[v], gt[v], depth_], dim=1) + image_path = os.path.join('nerf_vs_rebuttal', scenes) + if not os.path.exists(image_path): + os.makedirs(image_path) + save_dir = os.path.join(image_path, 'view_' + str(v) + '.png') + + font = cv2.FONT_HERSHEY_SIMPLEX + org = (50, 50) + fontScale = 1 + color = (255, 0, 0) + thickness = 2 + image = np.uint8(img_to_save.cpu().numpy() * 255.0) + psnr = compute_psnr(rgb[v], gt[v], mask=None) + psnr_total += psnr + ssim = compute_ssim(rgb[v], gt[v], mask=None) + ssim_total += ssim + image = cv2.putText( + image, 'PSNR: ' + '%.2f' % compute_psnr(rgb[v], gt[v], mask=None), + org, font, fontScale, color, thickness, cv2.LINE_AA) + + cv2.imwrite(save_dir, image) + + return psnr_total / gt.shape[0], ssim_total / gt.shape[0], rsme / gt.shape[ + 0] diff --git a/projects/NeRF-Det/nerfdet/nerfdet.py b/projects/NeRF-Det/nerfdet/nerfdet.py new file mode 100644 index 0000000000..ee66387cb5 --- /dev/null +++ b/projects/NeRF-Det/nerfdet/nerfdet.py @@ -0,0 +1,632 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmdet3d.models.detectors import Base3DDetector +from mmdet3d.registry import MODELS, TASK_UTILS +from mmdet3d.structures.det3d_data_sample import SampleList +from mmdet3d.utils import ConfigType, OptConfigType +from .nerf_utils.nerf_mlp import VanillaNeRF +from .nerf_utils.projection import Projector +from .nerf_utils.render_ray import render_rays + +# from ..utils.nerf_utils.save_rendered_img import save_rendered_img + + +@MODELS.register_module() +class NerfDet(Base3DDetector): + r"""`ImVoxelNet `_. + + Args: + backbone (:obj:`ConfigDict` or dict): The backbone config. + neck (:obj:`ConfigDict` or dict): The neck config. + neck_3d(:obj:`ConfigDict` or dict): The 3D neck config. + bbox_head(:obj:`ConfigDict` or dict): The bbox head config. + prior_generator (:obj:`ConfigDict` or dict): The prior generator + config. + n_voxels (list): Number of voxels along x, y, z axis. + voxel_size (list): The size of voxels.Each voxel represents + a cube of `voxel_size[0]` meters, `voxel_size[1]` meters, + `` + train_cfg (:obj:`ConfigDict` or dict, optional): Config dict of + training hyper-parameters. Defaults to None. + test_cfg (:obj:`ConfigDict` or dict, optional): Config dict of test + hyper-parameters. Defaults to None. + init_cfg (:obj:`ConfigDict` or dict, optional): The initialization + config. Defaults to None. + render_testing (bool): If you want to render novel view, please set + "render_testing = True" in config + The other args are the parameters of NeRF, you can just use the + default values. + """ + + def __init__( + self, + backbone: ConfigType, + neck: ConfigType, + neck_3d: ConfigType, + bbox_head: ConfigType, + prior_generator: ConfigType, + n_voxels: List, + voxel_size: List, + head_2d: ConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptConfigType = None, + # pretrained, + aabb: Tuple = None, + near_far_range: List = None, + N_samples: int = 64, + N_rand: int = 2048, + depth_supervise: bool = False, + use_nerf_mask: bool = True, + nerf_sample_view: int = 3, + nerf_mode: str = 'volume', + squeeze_scale: int = 4, + rgb_supervision: bool = True, + nerf_density: bool = False, + render_testing: bool = False): + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + self.backbone = MODELS.build(backbone) + self.neck = MODELS.build(neck) + self.neck_3d = MODELS.build(neck_3d) + bbox_head.update(train_cfg=train_cfg) + bbox_head.update(test_cfg=test_cfg) + self.bbox_head = MODELS.build(bbox_head) + self.head_2d = MODELS.build(head_2d) if head_2d is not None else None + self.n_voxels = n_voxels + self.prior_generator = TASK_UTILS.build(prior_generator) + self.voxel_size = voxel_size + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.aabb = aabb + self.near_far_range = near_far_range + self.N_samples = N_samples + self.N_rand = N_rand + self.depth_supervise = depth_supervise + self.projector = Projector() + self.squeeze_scale = squeeze_scale + self.use_nerf_mask = use_nerf_mask + self.rgb_supervision = rgb_supervision + nerf_feature_dim = neck['out_channels'] // squeeze_scale + self.nerf_mlp = VanillaNeRF( + net_depth=4, # The depth of the MLP + net_width=256, # The width of the MLP + skip_layer=3, # The layer to add skip layers to. + feature_dim=nerf_feature_dim + 6, # + RGB original imgs + net_depth_condition=1, # The depth of the second part of MLP + net_width_condition=128) + self.nerf_mode = nerf_mode + self.nerf_density = nerf_density + self.nerf_sample_view = nerf_sample_view + self.render_testing = render_testing + + # hard code here, will deal with batch issue later. + self.cov = nn.Sequential( + nn.Conv3d( + neck['out_channels'], + neck['out_channels'], + kernel_size=3, + padding=1), nn.ReLU(inplace=True), + nn.Conv3d( + neck['out_channels'], + neck['out_channels'], + kernel_size=3, + padding=1), nn.ReLU(inplace=True), + nn.Conv3d(neck['out_channels'], 1, kernel_size=1)) + + self.mean_mapping = nn.Sequential( + nn.Conv3d( + neck['out_channels'], nerf_feature_dim // 2, kernel_size=1)) + + self.cov_mapping = nn.Sequential( + nn.Conv3d( + neck['out_channels'], nerf_feature_dim // 2, kernel_size=1)) + + self.mapping = nn.Sequential( + nn.Linear(neck['out_channels'], nerf_feature_dim // 2)) + + self.mapping_2d = nn.Sequential( + nn.Conv2d( + neck['out_channels'], nerf_feature_dim // 2, kernel_size=1)) + # self.overfit_nerfmlp = overfit_nerfmlp + # if self.overfit_nerfmlp: + # self. _finetuning_NeRF_MLP() + self.render_testing = render_testing + + def extract_feat(self, + batch_inputs_dict: dict, + batch_data_samples: SampleList, + mode, + depth=None, + ray_batch=None): + """Extract 3d features from the backbone -> fpn -> 3d projection. + + -> 3d neck -> bbox_head. + + Args: + batch_inputs_dict (dict): The model input dict which include + the 'imgs' key. + + - imgs (torch.Tensor, optional): Image of each sample. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instances` of `gt_panoptic_seg` or `gt_sem_seg` + + Returns: + Tuple: + - torch.Tensor: Features of shape (N, C_out, N_x, N_y, N_z). + - torch.Tensor: Valid mask of shape (N, 1, N_x, N_y, N_z). + - torch.Tensor: 2D features if needed. + - dict: The nerf rendered information including the + 'output_coarse', 'gt_rgb' and 'gt_depth' keys. + """ + img = batch_inputs_dict['imgs'] + img = img.float() + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + batch_size = img.shape[0] + + if len(img.shape) > 4: + img = img.reshape([-1] + list(img.shape)[2:]) + x = self.backbone(img) + x = self.neck(x)[0] + x = x.reshape([batch_size, -1] + list(x.shape[1:])) + else: + x = self.backbone(img) + x = self.neck(x)[0] + + if depth is not None: + depth_bs = depth.shape[0] + assert depth_bs == batch_size + depth = batch_inputs_dict['depth'] + depth = depth.reshape([-1] + list(depth.shape)[2:]) + + features_2d = self.head_2d.forward(x[-1], batch_img_metas) \ + if self.head_2d is not None else None + + stride = img.shape[-1] / x.shape[-1] + assert stride == 4 + stride = int(stride) + + volumes, valids = [], [] + rgb_preds = [] + + for feature, img_meta in zip(x, batch_img_metas): + angles = features_2d[ + 0] if features_2d is not None and mode == 'test' else None + projection = self._compute_projection(img_meta, stride, + angles).to(x.device) + points = get_points( + n_voxels=torch.tensor(self.n_voxels), + voxel_size=torch.tensor(self.voxel_size), + origin=torch.tensor(img_meta['lidar2img']['origin'])).to( + x.device) + + height = img_meta['img_shape'][0] // stride + width = img_meta['img_shape'][1] // stride + # Construct the volume space + # volume together with valid is the constructed scene + # volume represents V_i and valid represents M_p + volume, valid = backproject(feature[:, :, :height, :width], points, + projection, depth, self.voxel_size) + density = None + volume_sum = volume.sum(dim=0) + # cov_valid = valid.clone().detach() + valid = valid.sum(dim=0) + volume_mean = volume_sum / (valid + 1e-8) + volume_mean[:, valid[0] == 0] = .0 + # volume_cov = (volume - volume_mean.unsqueeze(0)) ** 2 * cov_valid + # volume_cov = torch.sum(volume_cov, dim=0) / (valid + 1e-8) + volume_cov = torch.sum( + (volume - volume_mean.unsqueeze(0))**2, dim=0) / ( + valid + 1e-8) + volume_cov[:, valid[0] == 0] = 1e6 + volume_cov = torch.exp(-volume_cov) # default setting + # be careful here, the smaller the cov, the larger the weight. + n_channels, n_x_voxels, n_y_voxels, n_z_voxels = volume_mean.shape + if ray_batch is not None: + if self.nerf_mode == 'volume': + mean_volume = self.mean_mapping(volume_mean.unsqueeze(0)) + cov_volume = self.cov_mapping(volume_cov.unsqueeze(0)) + feature_2d = feature[:, :, :height, :width] + + elif self.nerf_mode == 'image': + mean_volume = None + cov_volume = None + feature_2d = feature[:, :, :height, :width] + n_v, C, height, width = feature_2d.shape + feature_2d = feature_2d.view(n_v, C, + -1).permute(0, 2, + 1).contiguous() + feature_2d = self.mapping(feature_2d).permute( + 0, 2, 1).contiguous().view(n_v, -1, height, width) + + denorm_images = ray_batch['denorm_images'] + denorm_images = denorm_images.reshape( + [-1] + list(denorm_images.shape)[2:]) + rgb_projection = self._compute_projection( + img_meta, stride=1, angles=None).to(x.device) + + rgb_volume, _ = backproject( + denorm_images[:, :, :img_meta['img_shape'][0], : + img_meta['img_shape'][1]], points, + rgb_projection, depth, self.voxel_size) + + ret = render_rays( + ray_batch, + mean_volume, + cov_volume, + feature_2d, + denorm_images, + self.aabb, + self.near_far_range, + self.N_samples, + self.N_rand, + self.nerf_mlp, + img_meta, + self.projector, + self.nerf_mode, + self.nerf_sample_view, + is_train=True if mode == 'train' else False, + render_testing=self.render_testing) + rgb_preds.append(ret) + + if self.nerf_density: + # would have 0 bias issue for mean_mapping. + n_v, C, n_x_voxels, n_y_voxels, n_z_voxels = volume.shape + volume = volume.view(n_v, C, -1).permute(0, 2, + 1).contiguous() + mapping_volume = self.mapping(volume).permute( + 0, 2, 1).contiguous().view(n_v, -1, n_x_voxels, + n_y_voxels, n_z_voxels) + + mapping_volume = torch.cat([rgb_volume, mapping_volume], + dim=1) + mapping_volume_sum = mapping_volume.sum(dim=0) + mapping_volume_mean = mapping_volume_sum / (valid + 1e-8) + # mapping_volume_cov = ( + # mapping_volume - mapping_volume_mean.unsqueeze(0) + # ) ** 2 * cov_valid + mapping_volume_cov = (mapping_volume - + mapping_volume_mean.unsqueeze(0))**2 + mapping_volume_cov = torch.sum( + mapping_volume_cov, dim=0) / ( + valid + 1e-8) + mapping_volume_cov[:, valid[0] == 0] = 1e6 + mapping_volume_cov = torch.exp( + -mapping_volume_cov) # default setting + global_volume = torch.cat( + [mapping_volume_mean, mapping_volume_cov], dim=1) + global_volume = global_volume.view( + -1, n_x_voxels * n_y_voxels * n_z_voxels).permute( + 1, 0).contiguous() + points = points.view(3, -1).permute(1, 0).contiguous() + density = self.nerf_mlp.query_density( + points, global_volume) + alpha = 1 - torch.exp(-density) + # density -> alpha + # (1, n_x_voxels, n_y_voxels, n_z_voxels) + volume = alpha.view(1, n_x_voxels, n_y_voxels, + n_z_voxels) * volume_mean + volume[:, valid[0] == 0] = .0 + + volumes.append(volume) + valids.append(valid) + x = torch.stack(volumes) + x = self.neck_3d(x) + + return x, torch.stack(valids).float(), features_2d, rgb_preds + + def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList, + **kwargs) -> Union[dict, list]: + """Calculate losses from a batch of inputs and data samples. + + Args: + batch_inputs_dict (dict): The model input dict which include + the 'imgs' key. + + - imgs (torch.Tensor, optional): Image of each sample. + batch_data_samples (list[:obj: `DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict: A dictionary of loss components. + """ + ray_batchs = {} + batch_images = [] + batch_depths = [] + if 'images' in batch_data_samples[0].gt_nerf_images: + for data_samples in batch_data_samples: + image = data_samples.gt_nerf_images['images'] + batch_images.append(image) + batch_images = torch.stack(batch_images) + + if 'depths' in batch_data_samples[0].gt_nerf_depths: + for data_samples in batch_data_samples: + depth = data_samples.gt_nerf_depths['depths'] + batch_depths.append(depth) + batch_depths = torch.stack(batch_depths) + + if 'raydirs' in batch_inputs_dict.keys(): + ray_batchs['ray_o'] = batch_inputs_dict['lightpos'] + ray_batchs['ray_d'] = batch_inputs_dict['raydirs'] + ray_batchs['gt_rgb'] = batch_images + ray_batchs['gt_depth'] = batch_depths + ray_batchs['nerf_sizes'] = batch_inputs_dict['nerf_sizes'] + ray_batchs['denorm_images'] = batch_inputs_dict['denorm_images'] + x, valids, features_2d, rgb_preds = self.extract_feat( + batch_inputs_dict, + batch_data_samples, + 'train', + depth=None, + ray_batch=ray_batchs) + else: + x, valids, features_2d, rgb_preds = self.extract_feat( + batch_inputs_dict, batch_data_samples, 'train') + x += (valids, ) + losses = self.bbox_head.loss(x, batch_data_samples, **kwargs) + + # if self.head_2d is not None: + # losses.update( + # self.head_2d.loss(*features_2d, batch_data_samples) + # ) + if len(ray_batchs) != 0 and self.rgb_supervision: + losses.update(self.nvs_loss_func(rgb_preds)) + if self.depth_supervise: + losses.update(self.depth_loss_func(rgb_preds)) + return losses + + def nvs_loss_func(self, rgb_pred): + loss = 0 + for ret in rgb_pred: + rgb = ret['outputs_coarse']['rgb'] + gt = ret['gt_rgb'] + masks = ret['outputs_coarse']['mask'] + if self.use_nerf_mask: + loss += torch.sum(masks.unsqueeze(-1) * (rgb - gt)**2) / ( + masks.sum() + 1e-6) + else: + loss += torch.mean((rgb - gt)**2) + return dict(loss_nvs=loss) + + def depth_loss_func(self, rgb_pred): + loss = 0 + for ret in rgb_pred: + depth = ret['outputs_coarse']['depth'] + gt = ret['gt_depth'].squeeze(-1) + masks = ret['outputs_coarse']['mask'] + if self.use_nerf_mask: + loss += torch.sum(masks * torch.abs(depth - gt)) / ( + masks.sum() + 1e-6) + else: + loss += torch.mean(torch.abs(depth - gt)) + + return dict(loss_depth=loss) + + def predict(self, batch_inputs_dict: dict, batch_data_samples: SampleList, + **kwargs) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + batch_inputs_dict (dict): The model input dict which include + the 'imgs' key. + + - imgs (torch.Tensor, optional): Image of each sample. + + batch_data_samples (List[:obj:`NeRFDet3DDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`. + + Returns: + list[:obj:`NeRFDet3DDataSample`]: Detection results of the + input images. Each NeRFDet3DDataSample usually contain + 'pred_instances_3d'. And the ``pred_instances_3d`` usually + contains following keys. + + - scores_3d (Tensor): Classification scores, has a shape + (num_instance, ) + - labels_3d (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes_3d (Tensor): Contains a tensor with shape + (num_instances, C) where C = 6. + """ + ray_batchs = {} + batch_images = [] + batch_depths = [] + if 'images' in batch_data_samples[0].gt_nerf_images: + for data_samples in batch_data_samples: + image = data_samples.gt_nerf_images['images'] + batch_images.append(image) + batch_images = torch.stack(batch_images) + + if 'depths' in batch_data_samples[0].gt_nerf_depths: + for data_samples in batch_data_samples: + depth = data_samples.gt_nerf_depths['depths'] + batch_depths.append(depth) + batch_depths = torch.stack(batch_depths) + + if 'raydirs' in batch_inputs_dict.keys(): + ray_batchs['ray_o'] = batch_inputs_dict['lightpos'] + ray_batchs['ray_d'] = batch_inputs_dict['raydirs'] + ray_batchs['gt_rgb'] = batch_images + ray_batchs['gt_depth'] = batch_depths + ray_batchs['nerf_sizes'] = batch_inputs_dict['nerf_sizes'] + ray_batchs['denorm_images'] = batch_inputs_dict['denorm_images'] + x, valids, features_2d, rgb_preds = self.extract_feat( + batch_inputs_dict, + batch_data_samples, + 'test', + depth=None, + ray_batch=ray_batchs) + else: + x, valids, features_2d, rgb_preds = self.extract_feat( + batch_inputs_dict, batch_data_samples, 'test') + x += (valids, ) + results_list = self.bbox_head.predict(x, batch_data_samples, **kwargs) + predictions = self.add_pred_to_datasample(batch_data_samples, + results_list) + return predictions + + def _forward(self, batch_inputs_dict: dict, batch_data_samples: SampleList, + *args, **kwargs) -> Tuple[List[torch.Tensor]]: + """Network forward process. Usually includes backbone, neck and head + forward without any post-processing. + + Args: + batch_inputs_dict (dict): The model input dict which include + the 'imgs' key. + + - imgs (torch.Tensor, optional): Image of each sample. + + batch_data_samples (List[:obj:`NeRFDet3DDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d` + + Returns: + tuple[list]: A tuple of features from ``bbox_head`` forward + """ + ray_batchs = {} + batch_images = [] + batch_depths = [] + if 'images' in batch_data_samples[0].gt_nerf_images: + for data_samples in batch_data_samples: + image = data_samples.gt_nerf_images['images'] + batch_images.append(image) + batch_images = torch.stack(batch_images) + + if 'depths' in batch_data_samples[0].gt_nerf_depths: + for data_samples in batch_data_samples: + depth = data_samples.gt_nerf_depths['depths'] + batch_depths.append(depth) + batch_depths = torch.stack(batch_depths) + if 'raydirs' in batch_inputs_dict.keys(): + ray_batchs['ray_o'] = batch_inputs_dict['lightpos'] + ray_batchs['ray_d'] = batch_inputs_dict['raydirs'] + ray_batchs['gt_rgb'] = batch_images + ray_batchs['gt_depth'] = batch_depths + ray_batchs['nerf_sizes'] = batch_inputs_dict['nerf_sizes'] + ray_batchs['denorm_images'] = batch_inputs_dict['denorm_images'] + x, valids, features_2d, rgb_preds = self.extract_feat( + batch_inputs_dict, + batch_data_samples, + 'train', + depth=None, + ray_batch=ray_batchs) + else: + x, valids, features_2d, rgb_preds = self.extract_feat( + batch_inputs_dict, batch_data_samples, 'train') + x += (valids, ) + results = self.bbox_head.forward(x) + return results + + def aug_test(self, batch_inputs_dict, batch_data_samples): + pass + + def show_results(self, *args, **kwargs): + pass + + @staticmethod + def _compute_projection(img_meta, stride, angles): + projection = [] + intrinsic = torch.tensor(img_meta['lidar2img']['intrinsic'][:3, :3]) + ratio = img_meta['ori_shape'][0] / (img_meta['img_shape'][0] / stride) + intrinsic[:2] /= ratio + # use predict pitch and roll for SUNRGBDTotal test + if angles is not None: + extrinsics = [] + for angle in angles: + extrinsics.append(get_extrinsics(angle).to(intrinsic.device)) + else: + extrinsics = map(torch.tensor, img_meta['lidar2img']['extrinsic']) + for extrinsic in extrinsics: + projection.append(intrinsic @ extrinsic[:3]) + return torch.stack(projection) + + +@torch.no_grad() +def get_points(n_voxels, voxel_size, origin): + # origin: point-cloud center. + points = torch.stack( + torch.meshgrid([ + torch.arange(n_voxels[0]), # 40 W width, x + torch.arange(n_voxels[1]), # 40 D depth, y + torch.arange(n_voxels[2]) # 16 H Height, z + ])) + new_origin = origin - n_voxels / 2. * voxel_size + points = points * voxel_size.view(3, 1, 1, 1) + new_origin.view(3, 1, 1, 1) + return points + + +# modify from https://github.com/magicleap/Atlas/blob/master/atlas/model.py +def backproject(features, points, projection, depth, voxel_size): + n_images, n_channels, height, width = features.shape + n_x_voxels, n_y_voxels, n_z_voxels = points.shape[-3:] + points = points.view(1, 3, -1).expand(n_images, 3, -1) + points = torch.cat((points, torch.ones_like(points[:, :1])), dim=1) + points_2d_3 = torch.bmm(projection, points) + + x = (points_2d_3[:, 0] / points_2d_3[:, 2]).round().long() + y = (points_2d_3[:, 1] / points_2d_3[:, 2]).round().long() + z = points_2d_3[:, 2] + valid = (x >= 0) & (y >= 0) & (x < width) & (y < height) & (z > 0) + # below is using depth to sample feature + if depth is not None: + depth = F.interpolate( + depth.unsqueeze(1), size=(height, width), + mode='bilinear').squeeze(1) + for i in range(n_images): + z_mask = z.clone() > 0 + z_mask[i, valid[i]] = \ + (z[i, valid[i]] > depth[i, y[i, valid[i]], x[i, valid[i]]] - voxel_size[-1]) & \ + (z[i, valid[i]] < depth[i, y[i, valid[i]], x[i, valid[i]]] + voxel_size[-1]) # noqa + valid = valid & z_mask + + volume = torch.zeros((n_images, n_channels, points.shape[-1]), + device=features.device) + for i in range(n_images): + volume[i, :, valid[i]] = features[i, :, y[i, valid[i]], x[i, valid[i]]] + volume = volume.view(n_images, n_channels, n_x_voxels, n_y_voxels, + n_z_voxels) + valid = valid.view(n_images, 1, n_x_voxels, n_y_voxels, n_z_voxels) + + return volume, valid + + +# for SUNRGBDTotal test +def get_extrinsics(angles): + yaw = angles.new_zeros(()) + pitch, roll = angles + r = angles.new_zeros((3, 3)) + r[0, 0] = torch.cos(yaw) * torch.cos(pitch) + r[0, 1] = torch.sin(yaw) * torch.sin(roll) - torch.cos(yaw) * torch.cos( + roll) * torch.sin(pitch) + r[0, 2] = torch.cos(roll) * torch.sin(yaw) + torch.cos(yaw) * torch.sin( + pitch) * torch.sin(roll) + r[1, 0] = torch.sin(pitch) + r[1, 1] = torch.cos(pitch) * torch.cos(roll) + r[1, 2] = -torch.cos(pitch) * torch.sin(roll) + r[2, 0] = -torch.cos(pitch) * torch.sin(yaw) + r[2, 1] = torch.cos(yaw) * torch.sin(roll) + torch.cos(roll) * torch.sin( + yaw) * torch.sin(pitch) + r[2, 2] = torch.cos(yaw) * torch.cos(roll) - torch.sin(yaw) * torch.sin( + pitch) * torch.sin(roll) + + # follow Total3DUnderstanding + t = angles.new_tensor([[0., 0., 1.], [0., -1., 0.], [-1., 0., 0.]]) + r = t @ r.T + # follow DepthInstance3DBoxes + r = r[:, [2, 0, 1]] + r[2] *= -1 + extrinsic = angles.new_zeros((4, 4)) + extrinsic[:3, :3] = r + extrinsic[3, 3] = 1. + return extrinsic diff --git a/projects/NeRF-Det/nerfdet/nerfdet_head.py b/projects/NeRF-Det/nerfdet/nerfdet_head.py new file mode 100644 index 0000000000..d5faa0adc1 --- /dev/null +++ b/projects/NeRF-Det/nerfdet/nerfdet_head.py @@ -0,0 +1,629 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +from mmcv.cnn import Scale +# from mmcv.ops import nms3d, nms3d_normal +from mmdet.models.utils import multi_apply +from mmdet.utils import reduce_mean +# from mmengine.config import ConfigDict +from mmengine.model import BaseModule, bias_init_with_prob, normal_init +from mmengine.structures import InstanceData +from torch import Tensor, nn + +from mmdet3d.registry import MODELS, TASK_UTILS +# from mmdet3d.structures.bbox_3d.utils import rotation_3d_in_axis +from mmdet3d.structures.det3d_data_sample import SampleList +from mmdet3d.utils.typing_utils import (ConfigType, InstanceList, + OptConfigType, OptInstanceList) + + +@torch.no_grad() +def get_points(n_voxels, voxel_size, origin): + # origin: point-cloud center. + points = torch.stack( + torch.meshgrid([ + torch.arange(n_voxels[0]), # 40 W width, x + torch.arange(n_voxels[1]), # 40 D depth, y + torch.arange(n_voxels[2]) # 16 H Height, z + ])) + new_origin = origin - n_voxels / 2. * voxel_size + points = points * voxel_size.view(3, 1, 1, 1) + new_origin.view(3, 1, 1, 1) + return points + + +@MODELS.register_module() +class NerfDetHead(BaseModule): + r"""`ImVoxelNet`_ head for indoor + datasets. + + Args: + n_classes (int): Number of classes. + n_levels (int): Number of feature levels. + n_channels (int): Number of channels in input tensors. + n_reg_outs (int): Number of regression layer channels. + pts_assign_threshold (int): Min number of location per box to + be assigned with. + pts_center_threshold (int): Max number of locations per box to + be assigned with. + center_loss (dict, optional): Config of centerness loss. + Default: dict(type='CrossEntropyLoss', use_sigmoid=True). + bbox_loss (dict, optional): Config of bbox loss. + Default: dict(type='RotatedIoU3DLoss'). + cls_loss (dict, optional): Config of classification loss. + Default: dict(type='FocalLoss'). + train_cfg (dict, optional): Config for train stage. Defaults to None. + test_cfg (dict, optional): Config for test stage. Defaults to None. + init_cfg (dict, optional): Config for weight initialization. + Defaults to None. + """ + + def __init__(self, + n_classes: int, + n_levels: int, + n_channels: int, + n_reg_outs: int, + pts_assign_threshold: int, + pts_center_threshold: int, + prior_generator: ConfigType, + center_loss: ConfigType = dict( + type='mmdet.CrossEntropyLoss', use_sigmoid=True), + bbox_loss: ConfigType = dict(type='RotatedIoU3DLoss'), + cls_loss: ConfigType = dict(type='mmdet.FocalLoss'), + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + init_cfg: OptConfigType = None): + super(NerfDetHead, self).__init__(init_cfg) + self.n_classes = n_classes + self.n_levels = n_levels + self.n_reg_outs = n_reg_outs + self.pts_assign_threshold = pts_assign_threshold + self.pts_center_threshold = pts_center_threshold + self.prior_generator = TASK_UTILS.build(prior_generator) + self.center_loss = MODELS.build(center_loss) + self.bbox_loss = MODELS.build(bbox_loss) + self.cls_loss = MODELS.build(cls_loss) + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self._init_layers(n_channels, n_reg_outs, n_classes, n_levels) + + def _init_layers(self, n_channels, n_reg_outs, n_classes, n_levels): + """Initialize neural network layers of the head.""" + self.conv_center = nn.Conv3d(n_channels, 1, 3, padding=1, bias=False) + self.conv_reg = nn.Conv3d( + n_channels, n_reg_outs, 3, padding=1, bias=False) + self.conv_cls = nn.Conv3d(n_channels, n_classes, 3, padding=1) + self.scales = nn.ModuleList([Scale(1.) for _ in range(n_levels)]) + + def init_weights(self): + """Initialize all layer weights.""" + normal_init(self.conv_center, std=.01) + normal_init(self.conv_reg, std=.01) + normal_init(self.conv_cls, std=.01, bias=bias_init_with_prob(.01)) + + def _forward_single(self, x: Tensor, scale: Scale): + """Forward pass per level. + + Args: + x (Tensor): Per level 3d neck output tensor. + scale (mmcv.cnn.Scale): Per level multiplication weight. + + Returns: + tuple[Tensor]: Centerness, bbox and classification predictions. + """ + return (self.conv_center(x), torch.exp(scale(self.conv_reg(x))), + self.conv_cls(x)) + + def forward(self, x): + return multi_apply(self._forward_single, x, self.scales) + + def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList, + **kwargs) -> dict: + """Perform forward propagation and loss calculation of the detection + head on the features of the upstream network. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + batch_data_samples (List[:obj:`NeRFDet3DDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + dict: A dictionary of loss components. + """ + valid_pred = x[-1] + outs = self(x[:-1]) + + batch_gt_instances_3d = [] + batch_gt_instances_ignore = [] + batch_input_metas = [] + for data_sample in batch_data_samples: + batch_input_metas.append(data_sample.metainfo) + batch_gt_instances_3d.append(data_sample.gt_instances_3d) + batch_gt_instances_ignore.append( + data_sample.get('ignored_instances', None)) + + loss_inputs = outs + (valid_pred, batch_gt_instances_3d, + batch_input_metas, batch_gt_instances_ignore) + losses = self.loss_by_feat(*loss_inputs) + return losses + + def loss_by_feat(self, + center_preds: List[List[Tensor]], + bbox_preds: List[List[Tensor]], + cls_preds: List[List[Tensor]], + valid_pred: Tensor, + batch_gt_instances_3d: InstanceList, + batch_input_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None, + **kwargs) -> dict: + """Per scene loss function. + + Args: + center_preds (list[list[Tensor]]): Centerness predictions for + all scenes. The first list contains predictions from different + levels. The second list contains predictions in a mini-batch. + bbox_preds (list[list[Tensor]]): Bbox predictions for all scenes. + The first list contains predictions from different + levels. The second list contains predictions in a mini-batch. + cls_preds (list[list[Tensor]]): Classification predictions for all + scenes. The first list contains predictions from different + levels. The second list contains predictions in a mini-batch. + valid_pred (Tensor): Valid mask prediction for all scenes. + batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of + gt_instance_3d. It usually includes ``bboxes_3d``、` + `labels_3d``、``depths``、``centers_2d`` and attributes. + batch_input_metas (list[dict]): Meta information of each image, + e.g., image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict: Centerness, bbox, and classification loss values. + """ + valid_preds = self._upsample_valid_preds(valid_pred, center_preds) + center_losses, bbox_losses, cls_losses = [], [], [] + for i in range(len(batch_input_metas)): + center_loss, bbox_loss, cls_loss = self._loss_by_feat_single( + center_preds=[x[i] for x in center_preds], + bbox_preds=[x[i] for x in bbox_preds], + cls_preds=[x[i] for x in cls_preds], + valid_preds=[x[i] for x in valid_preds], + input_meta=batch_input_metas[i], + gt_bboxes=batch_gt_instances_3d[i].bboxes_3d, + gt_labels=batch_gt_instances_3d[i].labels_3d) + center_losses.append(center_loss) + bbox_losses.append(bbox_loss) + cls_losses.append(cls_loss) + return dict( + center_loss=torch.mean(torch.stack(center_losses)), + bbox_loss=torch.mean(torch.stack(bbox_losses)), + cls_loss=torch.mean(torch.stack(cls_losses))) + + def _loss_by_feat_single(self, center_preds, bbox_preds, cls_preds, + valid_preds, input_meta, gt_bboxes, gt_labels): + featmap_sizes = [featmap.size()[-3:] for featmap in center_preds] + points = self._get_points( + featmap_sizes=featmap_sizes, + origin=input_meta['lidar2img']['origin'], + device=gt_bboxes.device) + center_targets, bbox_targets, cls_targets = self._get_targets( + points, gt_bboxes, gt_labels) + + center_preds = torch.cat( + [x.permute(1, 2, 3, 0).reshape(-1) for x in center_preds]) + bbox_preds = torch.cat([ + x.permute(1, 2, 3, 0).reshape(-1, x.shape[0]) for x in bbox_preds + ]) + cls_preds = torch.cat( + [x.permute(1, 2, 3, 0).reshape(-1, x.shape[0]) for x in cls_preds]) + valid_preds = torch.cat( + [x.permute(1, 2, 3, 0).reshape(-1) for x in valid_preds]) + points = torch.cat(points) + + # cls loss + pos_inds = torch.nonzero( + torch.logical_and(cls_targets >= 0, valid_preds)).squeeze(1) + n_pos = points.new_tensor(len(pos_inds)) + n_pos = max(reduce_mean(n_pos), 1.) + if torch.any(valid_preds): + cls_loss = self.cls_loss( + cls_preds[valid_preds], + cls_targets[valid_preds], + avg_factor=n_pos) + else: + cls_loss = cls_preds[valid_preds].sum() + + # bbox and centerness losses + pos_center_preds = center_preds[pos_inds] + pos_bbox_preds = bbox_preds[pos_inds] + if len(pos_inds) > 0: + pos_center_targets = center_targets[pos_inds] + pos_bbox_targets = bbox_targets[pos_inds] + pos_points = points[pos_inds] + center_loss = self.center_loss( + pos_center_preds, pos_center_targets, avg_factor=n_pos) + bbox_loss = self.bbox_loss( + self._bbox_pred_to_bbox(pos_points, pos_bbox_preds), + pos_bbox_targets, + weight=pos_center_targets, + avg_factor=pos_center_targets.sum()) + else: + center_loss = pos_center_preds.sum() + bbox_loss = pos_bbox_preds.sum() + return center_loss, bbox_loss, cls_loss + + def predict(self, + x: Tuple[Tensor], + batch_data_samples: SampleList, + rescale: bool = False) -> InstanceList: + """Perform forward propagation of the 3D detection head and predict + detection results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + batch_data_samples (List[:obj:`NeRFDet3DDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance_3d`, `gt_pts_panoptic_seg` and + `gt_pts_sem_seg`. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Detection results of each sample + after the post process. + Each item usually contains following keys. + + - scores_3d (Tensor): Classification scores, has a shape + (num_instances, ) + - labels_3d (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes, + contains a tensor with shape (num_instances, C), where + C >= 6. + """ + batch_input_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + valid_pred = x[-1] + outs = self(x[:-1]) + predictions = self.predict_by_feat( + *outs, + valid_pred=valid_pred, + batch_input_metas=batch_input_metas, + rescale=rescale) + return predictions + + def predict_by_feat(self, center_preds: List[List[Tensor]], + bbox_preds: List[List[Tensor]], + cls_preds: List[List[Tensor]], valid_pred: Tensor, + batch_input_metas: List[dict], + **kwargs) -> List[InstanceData]: + """Generate boxes for all scenes. + + Args: + center_preds (list[list[Tensor]]): Centerness predictions for + all scenes. + bbox_preds (list[list[Tensor]]): Bbox predictions for all scenes. + cls_preds (list[list[Tensor]]): Classification predictions for all + scenes. + valid_pred (Tensor): Valid mask prediction for all scenes. + batch_input_metas (list[dict]): Meta infos for all scenes. + + Returns: + list[tuple[Tensor]]: Predicted bboxes, scores, and labels for + all scenes. + """ + valid_preds = self._upsample_valid_preds(valid_pred, center_preds) + results = [] + for i in range(len(batch_input_metas)): + results.append( + self._predict_by_feat_single( + center_preds=[x[i] for x in center_preds], + bbox_preds=[x[i] for x in bbox_preds], + cls_preds=[x[i] for x in cls_preds], + valid_preds=[x[i] for x in valid_preds], + input_meta=batch_input_metas[i])) + return results + + def _predict_by_feat_single(self, center_preds: List[Tensor], + bbox_preds: List[Tensor], + cls_preds: List[Tensor], + valid_preds: List[Tensor], + input_meta: dict) -> InstanceData: + """Generate boxes for single sample. + + Args: + center_preds (list[Tensor]): Centerness predictions for all levels. + bbox_preds (list[Tensor]): Bbox predictions for all levels. + cls_preds (list[Tensor]): Classification predictions for all + levels. + valid_preds (tuple[Tensor]): Upsampled valid masks for all feature + levels. + input_meta (dict): Scene meta info. + + Returns: + tuple[Tensor]: Predicted bounding boxes, scores and labels. + """ + featmap_sizes = [featmap.size()[-3:] for featmap in center_preds] + points = self._get_points( + featmap_sizes=featmap_sizes, + origin=input_meta['lidar2img']['origin'], + device=center_preds[0].device) + mlvl_bboxes, mlvl_scores = [], [] + for center_pred, bbox_pred, cls_pred, valid_pred, point in zip( + center_preds, bbox_preds, cls_preds, valid_preds, points): + center_pred = center_pred.permute(1, 2, 3, 0).reshape(-1, 1) + bbox_pred = bbox_pred.permute(1, 2, 3, + 0).reshape(-1, bbox_pred.shape[0]) + cls_pred = cls_pred.permute(1, 2, 3, + 0).reshape(-1, cls_pred.shape[0]) + valid_pred = valid_pred.permute(1, 2, 3, 0).reshape(-1, 1) + scores = cls_pred.sigmoid() * center_pred.sigmoid() * valid_pred + max_scores, _ = scores.max(dim=1) + + if len(scores) > self.test_cfg.nms_pre > 0: + _, ids = max_scores.topk(self.test_cfg.nms_pre) + bbox_pred = bbox_pred[ids] + scores = scores[ids] + point = point[ids] + + bboxes = self._bbox_pred_to_bbox(point, bbox_pred) + mlvl_bboxes.append(bboxes) + mlvl_scores.append(scores) + + bboxes = torch.cat(mlvl_bboxes) + scores = torch.cat(mlvl_scores) + bboxes, scores, labels = self._nms(bboxes, scores, input_meta) + + bboxes = input_meta['box_type_3d']( + bboxes, box_dim=6, with_yaw=False, origin=(.5, .5, .5)) + + results = InstanceData() + results.bboxes_3d = bboxes + results.scores_3d = scores + results.labels_3d = labels + return results + + @staticmethod + def _upsample_valid_preds(valid_pred, features): + """Upsample valid mask predictions. + + Args: + valid_pred (Tensor): Valid mask prediction. + features (Tensor): Feature tensor. + + Returns: + tuple[Tensor]: Upsampled valid masks for all feature levels. + """ + return [ + nn.Upsample(size=x.shape[-3:], + mode='trilinear')(valid_pred).round().bool() + for x in features + ] + + @torch.no_grad() + def _get_points(self, featmap_sizes, origin, device): + mlvl_points = [] + tmp_voxel_size = [.16, .16, .2] + for i, featmap_size in enumerate(featmap_sizes): + mlvl_points.append( + get_points( + n_voxels=torch.tensor(featmap_size), + voxel_size=torch.tensor(tmp_voxel_size) * (2**i), + origin=torch.tensor(origin)).reshape(3, -1).transpose( + 0, 1).to(device)) + return mlvl_points + + def _bbox_pred_to_bbox(self, points, bbox_pred): + return torch.stack([ + points[:, 0] - bbox_pred[:, 0], points[:, 1] - bbox_pred[:, 2], + points[:, 2] - bbox_pred[:, 4], points[:, 0] + bbox_pred[:, 1], + points[:, 1] + bbox_pred[:, 3], points[:, 2] + bbox_pred[:, 5] + ], -1) + + def _bbox_pred_to_loss(self, points, bbox_preds): + return self._bbox_pred_to_bbox(points, bbox_preds) + + # The function is directly copied from FCAF3DHead. + @staticmethod + def _get_face_distances(points, boxes): + """Calculate distances from point to box faces. + + Args: + points (Tensor): Final locations of shape (N_points, N_boxes, 3). + boxes (Tensor): 3D boxes of shape (N_points, N_boxes, 7) + + Returns: + Tensor: Face distances of shape (N_points, N_boxes, 6), + (dx_min, dx_max, dy_min, dy_max, dz_min, dz_max). + """ + dx_min = points[..., 0] - boxes[..., 0] + boxes[..., 3] / 2 + dx_max = boxes[..., 0] + boxes[..., 3] / 2 - points[..., 0] + dy_min = points[..., 1] - boxes[..., 1] + boxes[..., 4] / 2 + dy_max = boxes[..., 1] + boxes[..., 4] / 2 - points[..., 1] + dz_min = points[..., 2] - boxes[..., 2] + boxes[..., 5] / 2 + dz_max = boxes[..., 2] + boxes[..., 5] / 2 - points[..., 2] + return torch.stack((dx_min, dx_max, dy_min, dy_max, dz_min, dz_max), + dim=-1) + + @staticmethod + def _get_centerness(face_distances): + """Compute point centerness w.r.t containing box. + + Args: + face_distances (Tensor): Face distances of shape (B, N, 6), + (dx_min, dx_max, dy_min, dy_max, dz_min, dz_max). + + Returns: + Tensor: Centerness of shape (B, N). + """ + x_dims = face_distances[..., [0, 1]] + y_dims = face_distances[..., [2, 3]] + z_dims = face_distances[..., [4, 5]] + centerness_targets = x_dims.min(dim=-1)[0] / x_dims.max(dim=-1)[0] * \ + y_dims.min(dim=-1)[0] / y_dims.max(dim=-1)[0] * \ + z_dims.min(dim=-1)[0] / z_dims.max(dim=-1)[0] + return torch.sqrt(centerness_targets) + + @torch.no_grad() + def _get_targets(self, points, gt_bboxes, gt_labels): + """Compute targets for final locations for a single scene. + + Args: + points (list[Tensor]): Final locations for all levels. + gt_bboxes (BaseInstance3DBoxes): Ground truth boxes. + gt_labels (Tensor): Ground truth labels. + + Returns: + tuple[Tensor]: Centerness, bbox and classification + targets for all locations. + """ + float_max = 1e8 + expanded_scales = [ + points[i].new_tensor(i).expand(len(points[i])).to(gt_labels.device) + for i in range(len(points)) + ] + points = torch.cat(points, dim=0).to(gt_labels.device) + scales = torch.cat(expanded_scales, dim=0) + + # below is based on FCOSHead._get_target_single + n_points = len(points) + n_boxes = len(gt_bboxes) + volumes = gt_bboxes.volume.to(points.device) + volumes = volumes.expand(n_points, n_boxes).contiguous() + gt_bboxes = torch.cat( + (gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:6]), dim=1) + gt_bboxes = gt_bboxes.to(points.device).expand(n_points, n_boxes, 6) + expanded_points = points.unsqueeze(1).expand(n_points, n_boxes, 3) + bbox_targets = self._get_face_distances(expanded_points, gt_bboxes) + + # condition1: inside a gt bbox + inside_gt_bbox_mask = bbox_targets[..., :6].min( + -1)[0] > 0 # skip angle + + # condition2: positive points per scale >= limit + # calculate positive points per scale + n_pos_points_per_scale = [] + for i in range(self.n_levels): + n_pos_points_per_scale.append( + torch.sum(inside_gt_bbox_mask[scales == i], dim=0)) + # find best scale + n_pos_points_per_scale = torch.stack(n_pos_points_per_scale, dim=0) + lower_limit_mask = n_pos_points_per_scale < self.pts_assign_threshold + # fix nondeterministic argmax for torch<1.7 + extra = torch.arange(self.n_levels, 0, -1).unsqueeze(1).expand( + self.n_levels, n_boxes).to(lower_limit_mask.device) + lower_index = torch.argmax(lower_limit_mask.int() * extra, dim=0) - 1 + lower_index = torch.where(lower_index < 0, + torch.zeros_like(lower_index), lower_index) + all_upper_limit_mask = torch.all( + torch.logical_not(lower_limit_mask), dim=0) + best_scale = torch.where( + all_upper_limit_mask, + torch.ones_like(all_upper_limit_mask) * self.n_levels - 1, + lower_index) + # keep only points with best scale + best_scale = torch.unsqueeze(best_scale, 0).expand(n_points, n_boxes) + scales = torch.unsqueeze(scales, 1).expand(n_points, n_boxes) + inside_best_scale_mask = best_scale == scales + + # condition3: limit topk locations per box by centerness + centerness = self._get_centerness(bbox_targets) + centerness = torch.where(inside_gt_bbox_mask, centerness, + torch.ones_like(centerness) * -1) + centerness = torch.where(inside_best_scale_mask, centerness, + torch.ones_like(centerness) * -1) + top_centerness = torch.topk( + centerness, self.pts_center_threshold + 1, dim=0).values[-1] + inside_top_centerness_mask = centerness > top_centerness.unsqueeze(0) + + # if there are still more than one objects for a location, + # we choose the one with minimal area + volumes = torch.where(inside_gt_bbox_mask, volumes, + torch.ones_like(volumes) * float_max) + volumes = torch.where(inside_best_scale_mask, volumes, + torch.ones_like(volumes) * float_max) + volumes = torch.where(inside_top_centerness_mask, volumes, + torch.ones_like(volumes) * float_max) + min_area, min_area_inds = volumes.min(dim=1) + + labels = gt_labels[min_area_inds] + labels = torch.where(min_area == float_max, + torch.ones_like(labels) * -1, labels) + bbox_targets = bbox_targets[range(n_points), min_area_inds] + centerness_targets = self._get_centerness(bbox_targets) + + return centerness_targets, self._bbox_pred_to_bbox( + points, bbox_targets), labels + + def _nms(self, bboxes, scores, img_meta): + scores, labels = scores.max(dim=1) + ids = scores > self.test_cfg.score_thr + bboxes = bboxes[ids] + scores = scores[ids] + labels = labels[ids] + ids = self.aligned_3d_nms(bboxes, scores, labels, + self.test_cfg.iou_thr) + bboxes = bboxes[ids] + bboxes = torch.stack( + ((bboxes[:, 0] + bboxes[:, 3]) / 2., + (bboxes[:, 1] + bboxes[:, 4]) / 2., + (bboxes[:, 2] + bboxes[:, 5]) / 2., bboxes[:, 3] - bboxes[:, 0], + bboxes[:, 4] - bboxes[:, 1], bboxes[:, 5] - bboxes[:, 2]), + dim=1) + return bboxes, scores[ids], labels[ids] + + @staticmethod + def aligned_3d_nms(boxes, scores, classes, thresh): + """3d nms for aligned boxes. + + Args: + boxes (torch.Tensor): Aligned box with shape [n, 6]. + scores (torch.Tensor): Scores of each box. + classes (torch.Tensor): Class of each box. + thresh (float): Iou threshold for nms. + + Returns: + torch.Tensor: Indices of selected boxes. + """ + x1 = boxes[:, 0] + y1 = boxes[:, 1] + z1 = boxes[:, 2] + x2 = boxes[:, 3] + y2 = boxes[:, 4] + z2 = boxes[:, 5] + area = (x2 - x1) * (y2 - y1) * (z2 - z1) + zero = boxes.new_zeros(1, ) + + score_sorted = torch.argsort(scores) + pick = [] + while (score_sorted.shape[0] != 0): + last = score_sorted.shape[0] + i = score_sorted[-1] + pick.append(i) + + xx1 = torch.max(x1[i], x1[score_sorted[:last - 1]]) + yy1 = torch.max(y1[i], y1[score_sorted[:last - 1]]) + zz1 = torch.max(z1[i], z1[score_sorted[:last - 1]]) + xx2 = torch.min(x2[i], x2[score_sorted[:last - 1]]) + yy2 = torch.min(y2[i], y2[score_sorted[:last - 1]]) + zz2 = torch.min(z2[i], z2[score_sorted[:last - 1]]) + classes1 = classes[i] + classes2 = classes[score_sorted[:last - 1]] + inter_l = torch.max(zero, xx2 - xx1) + inter_w = torch.max(zero, yy2 - yy1) + inter_h = torch.max(zero, zz2 - zz1) + + inter = inter_l * inter_w * inter_h + iou = inter / (area[i] + area[score_sorted[:last - 1]] - inter) + iou = iou * (classes1 == classes2).float() + score_sorted = score_sorted[torch.nonzero( + iou <= thresh, as_tuple=False).flatten()] + + indices = boxes.new_tensor(pick, dtype=torch.long) + return indices diff --git a/projects/NeRF-Det/nerfdet/scannet_multiview_dataset.py b/projects/NeRF-Det/nerfdet/scannet_multiview_dataset.py new file mode 100644 index 0000000000..a20bc3eec0 --- /dev/null +++ b/projects/NeRF-Det/nerfdet/scannet_multiview_dataset.py @@ -0,0 +1,202 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from os import path as osp +from typing import Callable, List, Optional, Union + +import numpy as np + +from mmdet3d.datasets import Det3DDataset +from mmdet3d.registry import DATASETS +from mmdet3d.structures import DepthInstance3DBoxes + + +@DATASETS.register_module() +class MultiViewScanNetDataset(Det3DDataset): + r"""Multi-View ScanNet Dataset for NeRF-detection Task + + This class serves as the API for experiments on the ScanNet Dataset. + + Please refer to the `github repo `_ + for data downloading. + + Args: + data_root (str): Path of dataset root. + ann_file (str): Path of annotation file. + metainfo (dict, optional): Meta information for dataset, such as class + information. Defaults to None. + pipeline (List[dict]): Pipeline used for data processing. + Defaults to []. + modality (dict): Modality to specify the sensor data used as input. + Defaults to dict(use_camera=True, use_lidar=False). + box_type_3d (str): Type of 3D box of this dataset. + Based on the `box_type_3d`, the dataset will encapsulate the box + to its original format then converted them to `box_type_3d`. + Defaults to 'Depth' in this dataset. Available options includes: + + - 'LiDAR': Box in LiDAR coordinates. + - 'Depth': Box in depth coordinates, usually for indoor dataset. + - 'Camera': Box in camera coordinates. + filter_empty_gt (bool): Whether to filter the data with empty GT. + If it's set to be True, the example with empty annotations after + data pipeline will be dropped and a random example will be chosen + in `__getitem__`. Defaults to True. + test_mode (bool): Whether the dataset is in test mode. + Defaults to False. + """ + METAINFO = { + 'classes': + ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', + 'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator', + 'showercurtrain', 'toilet', 'sink', 'bathtub', 'garbagebin') + } + + def __init__(self, + data_root: str, + ann_file: str, + metainfo: Optional[dict] = None, + pipeline: List[Union[dict, Callable]] = [], + modality: dict = dict(use_camera=True, use_lidar=False), + box_type_3d: str = 'Depth', + filter_empty_gt: bool = True, + remove_dontcare: bool = False, + test_mode: bool = False, + **kwargs) -> None: + + self.remove_dontcare = remove_dontcare + + super().__init__( + data_root=data_root, + ann_file=ann_file, + metainfo=metainfo, + pipeline=pipeline, + modality=modality, + box_type_3d=box_type_3d, + filter_empty_gt=filter_empty_gt, + test_mode=test_mode, + **kwargs) + + assert 'use_camera' in self.modality and \ + 'use_lidar' in self.modality + assert self.modality['use_camera'] or self.modality['use_lidar'] + + @staticmethod + def _get_axis_align_matrix(info: dict) -> np.ndarray: + """Get axis_align_matrix from info. If not exist, return identity mat. + + Args: + info (dict): Info of a single sample data. + + Returns: + np.ndarray: 4x4 transformation matrix. + """ + if 'axis_align_matrix' in info: + return np.array(info['axis_align_matrix']) + else: + warnings.warn( + 'axis_align_matrix is not found in ScanNet data info, please ' + 'use new pre-process scripts to re-generate ScanNet data') + return np.eye(4).astype(np.float32) + + def parse_data_info(self, info: dict) -> dict: + """Process the raw data info. + + Convert all relative path of needed modality data file to + the absolute path. + + Args: + info (dict): Raw info dict. + + Returns: + dict: Has `ann_info` in training stage. And + all path has been converted to absolute path. + """ + if self.modality['use_depth']: + info['depth_info'] = [] + if self.modality['use_neuralrecon_depth']: + info['depth_info'] = [] + + if self.modality['use_lidar']: + # implement lidar processing in the future + raise NotImplementedError( + 'Please modified ' + '`MultiViewPipeline` to support lidar processing') + + info['axis_align_matrix'] = self._get_axis_align_matrix(info) + info['img_info'] = [] + info['lidar2img'] = [] + info['c2w'] = [] + info['camrotc2w'] = [] + info['lightpos'] = [] + # load img and depth_img + for i in range(len(info['img_paths'])): + img_filename = osp.join(self.data_root, info['img_paths'][i]) + + info['img_info'].append(dict(filename=img_filename)) + if 'depth_info' in info.keys(): + if self.modality['use_neuralrecon_depth']: + info['depth_info'].append( + dict(filename=img_filename[:-4] + '.npy')) + else: + info['depth_info'].append( + dict(filename=img_filename[:-4] + '.png')) + # implement lidar_info in input.keys() in the future. + extrinsic = np.linalg.inv( + info['axis_align_matrix'] @ info['lidar2cam'][i]) + info['lidar2img'].append(extrinsic.astype(np.float32)) + if self.modality['use_ray']: + c2w = ( + info['axis_align_matrix'] @ info['lidar2cam'][i]).astype( + np.float32) # noqa + info['c2w'].append(c2w) + info['camrotc2w'].append(c2w[0:3, 0:3]) + info['lightpos'].append(c2w[0:3, 3]) + origin = np.array([.0, .0, .5]) + info['lidar2img'] = dict( + extrinsic=info['lidar2img'], + intrinsic=info['cam2img'].astype(np.float32), + origin=origin.astype(np.float32)) + + if self.modality['use_ray']: + info['ray_info'] = [] + + if not self.test_mode: + info['ann_info'] = self.parse_ann_info(info) + if self.test_mode and self.load_eval_anns: + info['ann_info'] = self.parse_ann_info(info) + info['eval_ann_info'] = self._remove_dontcare(info['ann_info']) + + return info + + def parse_ann_info(self, info: dict) -> dict: + """Process the `instances` in data info to `ann_info`. + + Args: + info (dict): Info dict. + + Returns: + dict: Processed `ann_info`. + """ + ann_info = super().parse_ann_info(info) + + if self.remove_dontcare: + ann_info = self._remove_dontcare(ann_info) + + # empty gt + if ann_info is None: + ann_info = dict() + ann_info['gt_bboxes_3d'] = np.zeros((0, 6), dtype=np.float32) + ann_info['gt_labels_3d'] = np.zeros((0, ), dtype=np.int64) + + ann_info['gt_bboxes_3d'] = DepthInstance3DBoxes( + ann_info['gt_bboxes_3d'], + box_dim=ann_info['gt_bboxes_3d'].shape[-1], + with_yaw=False, + origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d) + + # count the numbers + for label in ann_info['gt_labels_3d']: + if label != -1: + cat_name = self.metainfo['classes'][label] + self.num_ins_per_cat[cat_name] += 1 + + return ann_info diff --git a/projects/NeRF-Det/prepare_infos.py b/projects/NeRF-Det/prepare_infos.py new file mode 100644 index 0000000000..3e1a13516f --- /dev/null +++ b/projects/NeRF-Det/prepare_infos.py @@ -0,0 +1,151 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Prepare the dataset for NeRF-Det. + +Example: + python projects/NeRF-Det/prepare_infos.py + --root-path ./data/scannet + --out-dir ./data/scannet +""" +import argparse +import time +from os import path as osp +from pathlib import Path + +import mmengine + +from ...tools.dataset_converters import indoor_converter as indoor +from ...tools.dataset_converters.update_infos_to_v2 import ( + clear_data_info_unused_keys, clear_instance_unused_keys, + get_empty_instance, get_empty_standard_data_info) + + +def update_scannet_infos_nerfdet(pkl_path, out_dir): + """Update the origin pkl to the new format which will be used in nerf-det. + + Args: + pkl_path (str): Path of the origin pkl. + out_dir (str): Output directory of the generated info file. + + Returns: + The pkl will be overwritTen. + The new pkl is a dict containing two keys: + metainfo: Some base information of the pkl + data_list (list): A list containing all the information of the scenes. + """ + print('The new refactored process is running.') + print(f'{pkl_path} will be modified.') + if out_dir in pkl_path: + print(f'Warning, you may overwriting ' + f'the original data {pkl_path}.') + time.sleep(5) + METAINFO = { + 'classes': + ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', + 'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator', + 'showercurtrain', 'toilet', 'sink', 'bathtub', 'garbagebin') + } + print(f'Reading from input file: {pkl_path}.') + data_list = mmengine.load(pkl_path) + print('Start updating:') + converted_list = [] + for ori_info_dict in mmengine.track_iter_progress(data_list): + temp_data_info = get_empty_standard_data_info() + + # intrinsics, extrinsics and imgs + temp_data_info['cam2img'] = ori_info_dict['intrinsics'] + temp_data_info['lidar2cam'] = ori_info_dict['extrinsics'] + temp_data_info['img_paths'] = ori_info_dict['img_paths'] + + # annotation information + anns = ori_info_dict.get('annos', None) + ignore_class_name = set() + if anns is not None: + temp_data_info['axis_align_matrix'] = anns[ + 'axis_align_matrix'].tolist() + if anns['gt_num'] == 0: + instance_list = [] + else: + num_instances = len(anns['name']) + instance_list = [] + for instance_id in range(num_instances): + empty_instance = get_empty_instance() + empty_instance['bbox_3d'] = anns['gt_boxes_upright_depth'][ + instance_id].tolist() + + if anns['name'][instance_id] in METAINFO['classes']: + empty_instance['bbox_label_3d'] = METAINFO[ + 'classes'].index(anns['name'][instance_id]) + else: + ignore_class_name.add(anns['name'][instance_id]) + empty_instance['bbox_label_3d'] = -1 + + empty_instance = clear_instance_unused_keys(empty_instance) + instance_list.append(empty_instance) + temp_data_info['instances'] = instance_list + temp_data_info, _ = clear_data_info_unused_keys(temp_data_info) + converted_list.append(temp_data_info) + pkl_name = Path(pkl_path).name + out_path = osp.join(out_dir, pkl_name) + print(f'Writing to output file: {out_path}.') + print(f'ignore classes: {ignore_class_name}') + + # dataset metainfo + metainfo = dict() + metainfo['categories'] = {k: i for i, k in enumerate(METAINFO['classes'])} + if ignore_class_name: + for ignore_class in ignore_class_name: + metainfo['categories'][ignore_class] = -1 + metainfo['dataset'] = 'scannet' + metainfo['info_version'] = '1.1' + + converted_data_info = dict(metainfo=metainfo, data_list=converted_list) + + mmengine.dump(converted_data_info, out_path, 'pkl') + + +def scannet_data_prep(root_path, info_prefix, out_dir, workers): + """Prepare the info file for scannet dataset. + + Args: + root_path (str): Path of dataset root. + info_prefix (str): The prefix of info filenames. + out_dir (str): Output directory of the generated info file. + workers (int): Number of threads to be used. + version (str): Only used to generate the dataset of nerfdet now. + """ + indoor.create_indoor_info_file( + root_path, info_prefix, out_dir, workers=workers) + info_train_path = osp.join(out_dir, f'{info_prefix}_infos_train.pkl') + info_val_path = osp.join(out_dir, f'{info_prefix}_infos_val.pkl') + info_test_path = osp.join(out_dir, f'{info_prefix}_infos_test.pkl') + update_scannet_infos_nerfdet(out_dir=out_dir, pkl_path=info_train_path) + update_scannet_infos_nerfdet(out_dir=out_dir, pkl_path=info_val_path) + update_scannet_infos_nerfdet(out_dir=out_dir, pkl_path=info_test_path) + + +parser = argparse.ArgumentParser(description='Data converter arg parser') +parser.add_argument( + '--root-path', + type=str, + default='./data/scannet', + help='specify the root path of dataset') +parser.add_argument( + '--out-dir', + type=str, + default='./data/scannet', + required=False, + help='name of info pkl') +parser.add_argument('--extra-tag', type=str, default='scannet') +parser.add_argument( + '--workers', type=int, default=4, help='number of threads to be used') +args = parser.parse_args() + +if __name__ == '__main__': + from mmdet3d.utils import register_all_modules + register_all_modules() + + scannet_data_prep( + root_path=args.root_path, + info_prefix=args.extra_tag, + out_dir=args.out_dir, + workers=args.workers)