From 82d2a6e37f7bfff50676cdfcebca6df38ab7accb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Haian=20Huang=28=E6=B7=B1=E5=BA=A6=E7=9C=B8=29?= <1286304229@qq.com> Date: Mon, 11 Sep 2023 13:54:08 +0800 Subject: [PATCH] Support GLIP Funetune (#10866) --- configs/glip/README.md | 31 +- ...n-l_fpn_dyhead_16xb2_ms-2x_funtune_coco.py | 14 + ...t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco.py | 155 +++++++++ ...t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco.py | 9 + ...t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco.py | 3 + ...n-t_fpn_dyhead_16xb2_ms-2x_funtune_coco.py | 3 + configs/glip/metafile.yml | 45 +++ docs/en/user_guides/inference.md | 21 +- docs/zh_cn/user_guides/inference.md | 21 +- mmdet/datasets/transforms/__init__.py | 4 +- .../datasets/transforms/transformers_glip.py | 66 ++++ mmdet/datasets/transforms/transforms.py | 1 - mmdet/models/dense_heads/atss_head.py | 2 +- .../models/dense_heads/atss_vlfusion_head.py | 321 +++++++++++++++++- mmdet/models/detectors/glip.py | 106 +++++- mmdet/models/necks/__init__.py | 4 +- mmdet/models/necks/fpn_dropblock.py | 90 +++++ .../models/task_modules/assigners/__init__.py | 4 +- .../assigners/iou2d_calculator.py | 20 ++ mmdet/models/utils/vlfuse_helper.py | 41 ++- 20 files changed, 864 insertions(+), 97 deletions(-) create mode 100644 configs/glip/glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco.py create mode 100644 configs/glip/glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco.py create mode 100644 configs/glip/glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco.py create mode 100644 configs/glip/glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco.py create mode 100644 configs/glip/glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco.py create mode 100644 mmdet/datasets/transforms/transformers_glip.py create mode 100644 mmdet/models/necks/fpn_dropblock.py diff --git a/configs/glip/README.md b/configs/glip/README.md index 5f7c8d3ccb7..ebf5226b109 100644 --- a/configs/glip/README.md +++ b/configs/glip/README.md @@ -31,25 +31,34 @@ wget https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_a_mmdet-b365 python demo/image_demo.py demo/demo.jpg \ configs/glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py \ -glip_tiny_a_mmdet-b3654169.pth \ ---texts 'bench . car .' +--weights glip_tiny_a_mmdet-b3654169.pth \ +--texts 'bench. car' ```
- +
## Results and Models -| Model | Zero-shot or Funetune | COCO mAP | Pre-Train Data | Config | Download | -| :--------: | :-------------------: | :------: | :------------------------: | :---------------------------------------------------------------------: | :------------------------------------------------------------------------------------------: | -| GLIP-T (A) | Zero-shot | 43.0 | O365 | [config](glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_a_mmdet-b3654169.pth) | -| GLIP-T (B) | Zero-shot | 44.9 | O365 | [config](glip_atss_swin-t_b_fpn_dyhead_pretrain_obj365.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_b_mmdet-6dfbd102.pth) | -| GLIP-T (C) | Zero-shot | 46.7 | O365,GoldG | [config](glip_atss_swin-t_c_fpn_dyhead_pretrain_obj365-goldg.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_c_mmdet-2fc427dd.pth) | -| GLIP-T | Zero-shot | 46.4 | O365,GoldG,CC3M,SBU | [config](glip_atss_swin-t_fpn_dyhead_pretrain_obj365-goldg-cc3m-sub.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_mmdet-c24ce662.pth) | -| GLIP-L | Zero-shot | 51.3 | FourODs,GoldG,CC3M+12M,SBU | [config](glip_atss_swin-l_fpn_dyhead_pretrain_mixeddata.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_l_mmdet-abfe026b.pth) | +| Model | Zero-shot or Funetune | COCO mAP | Pre-Train Data | Config | Download | +| :--------: | :-------------------: | :------: | :------------------------: | :---------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| GLIP-T (A) | Zero-shot | 43.0 | O365 | [config](glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_a_mmdet-b3654169.pth) | +| GLIP-T (A) | Funetune | 53.1 | O365 | [config](glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230908_091856-39f01d03.pth)\| [log](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230908_091856.log.json) | +| GLIP-T (B) | Zero-shot | 44.9 | O365 | [config](glip_atss_swin-t_b_fpn_dyhead_pretrain_obj365.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_b_mmdet-6dfbd102.pth) | +| GLIP-T (B) | Funetune | 54.1 | O365 | [config](glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230909_175354-e0c0c6d7.pth)\| [log](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230909_175354.log.json) | +| GLIP-T (C) | Zero-shot | 46.7 | O365,GoldG | [config](glip_atss_swin-t_c_fpn_dyhead_pretrain_obj365-goldg.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_c_mmdet-2fc427dd.pth) | +| GLIP-T (C) | Funetune | 55.2 | O365,GoldG | [config](glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230909_175543-5fcb4b97.pth)\| [log](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230909_175543.log.json) | +| GLIP-T | Zero-shot | 46.4 | O365,GoldG,CC3M,SBU | [config](glip_atss_swin-t_fpn_dyhead_pretrain_obj365-goldg-cc3m-sub.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_mmdet-c24ce662.pth) | +| GLIP-T | Funetune | 55.2 | O365,GoldG,CC3M,SBU | [config](glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230910_125111-ad1025a0.pth)\| [log](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230910_125111.log.json) | +| GLIP-L | Zero-shot | 51.3 | FourODs,GoldG,CC3M+12M,SBU | [config](glip_atss_swin-l_fpn_dyhead_pretrain_mixeddata.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_l_mmdet-abfe026b.pth) | +| GLIP-L | Funetune | 59.4 | FourODs,GoldG,CC3M+12M,SBU | [config](glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230910_100800-e9be4274.pth)\| [log](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230910_100800.log.json) | Note: 1. The weights corresponding to the zero-shot model are adopted from the official weights and converted using the [script](../../tools/model_converters/glip_to_mmdet.py). We have not retrained the model for the time being. -2. We will soon support fine-tuning on COCO. +2. Funetune refers to fine-tuning on the COCO 2017 dataset. The L model is trained using 16 A100 GPUs, while the remaining models are trained using 16 NVIDIA GeForce 3090 GPUs. +3. Taking the GLIP-T(A) model as an example, I trained it twice using the official code, and the fine-tuning mAP were 52.5 and 52.6. Therefore, the mAP we achieved in our reproduction is higher than the official results. The main reason is that we modified the `weight_decay` parameter. +4. Our experiments revealed that training for 24 epochs leads to overfitting. Therefore, we chose the best-performing model. If users want to train on a custom dataset, it is advisable to shorten the number of epochs and save the best-performing model. +5. Due to the official absence of fine-tuning hyperparameters for the GLIP-L model, we have not yet reproduced the official accuracy. I have found that overfitting can also occur, so it may be necessary to consider custom modifications to data augmentation and model enhancement. Given the high cost of training, we have not conducted any research on this matter at the moment. +6. We noticed that there is a discrepancy between the performance evaluation of the checkpoint and the evaluation logs during training. This is because the buffers of different ranks are not the same during training, but we only saved the weights of rank 0. If you want to avoid this issue, you can add the parameter `broadcast_buffers=True` in the configuration. diff --git a/configs/glip/glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco.py b/configs/glip/glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco.py new file mode 100644 index 00000000000..92a85a11d57 --- /dev/null +++ b/configs/glip/glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco.py @@ -0,0 +1,14 @@ +_base_ = './glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco.py' + +model = dict( + backbone=dict( + embed_dims=192, + depths=[2, 2, 18, 2], + num_heads=[6, 12, 24, 48], + window_size=12, + drop_path_rate=0.4, + ), + neck=dict(in_channels=[384, 768, 1536]), + bbox_head=dict(early_fuse=True, num_dyhead_blocks=8, use_checkpoint=True)) + +load_from = 'https://download.openmmlab.com/mmdetection/v3.0/glip/glip_l_mmdet-abfe026b.pth' # noqa diff --git a/configs/glip/glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco.py b/configs/glip/glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco.py new file mode 100644 index 00000000000..4b280657b31 --- /dev/null +++ b/configs/glip/glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco.py @@ -0,0 +1,155 @@ +_base_ = [ + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] +load_from = 'https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_a_mmdet-b3654169.pth' # noqa +lang_model_name = 'bert-base-uncased' + +model = dict( + type='GLIP', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[103.53, 116.28, 123.675], + std=[57.375, 57.12, 58.395], + bgr_to_rgb=False, + pad_size_divisor=32), + backbone=dict( + type='SwinTransformer', + embed_dims=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + patch_norm=True, + out_indices=(1, 2, 3), + with_cp=False, + convert_weights=False), + neck=dict( + type='FPN_DropBlock', + in_channels=[192, 384, 768], + out_channels=256, + start_level=0, + relu_before_extra_convs=True, + add_extra_convs='on_output', + num_outs=5), + bbox_head=dict( + type='ATSSVLFusionHead', + lang_model_name=lang_model_name, + num_classes=80, + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + octave_base_scale=8, + scales_per_octave=1, + strides=[8, 16, 32, 64, 128], + center_offset=0.5), + bbox_coder=dict( + type='DeltaXYWHBBoxCoderForGLIP', + target_means=[.0, .0, .0, .0], + target_stds=[0.1, 0.1, 0.2, 0.2]), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=2.0), + loss_centerness=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)), + language_model=dict(type='BertModel', name=lang_model_name), + train_cfg=dict( + assigner=dict( + type='ATSSAssigner', + topk=9, + iou_calculator=dict(type='BboxOverlaps2D_GLIP')), + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.6), + max_per_img=100)) + +# dataset settings +train_pipeline = [ + dict( + type='LoadImageFromFile', + imdecode_backend='pillow', + backend_args=_base_.backend_args), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='GTBoxSubOne_GLIP'), + dict( + type='RandomChoiceResize', + scales=[(1333, 480), (1333, 560), (1333, 640), (1333, 720), + (1333, 800)], + keep_ratio=True, + resize_type='FixScaleResize', + backend='pillow'), + dict(type='RandomFlip_GLIP', prob=0.5), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1)), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 'flip_direction', 'text', + 'custom_entities')) +] + +test_pipeline = [ + dict( + type='LoadImageFromFile', + backend_args=_base_.backend_args, + imdecode_backend='pillow'), + dict( + type='FixScaleResize', + scale=(800, 1333), + keep_ratio=True, + backend='pillow'), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'text', 'custom_entities')) +] + +train_dataloader = dict( + dataset=dict( + _delete_=True, + type='RepeatDataset', + times=2, + dataset=dict( + type=_base_.dataset_type, + data_root=_base_.data_root, + ann_file='annotations/instances_train2017.json', + data_prefix=dict(img='train2017/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + return_classes=True, + backend_args=_base_.backend_args))) + +val_dataloader = dict( + dataset=dict(pipeline=test_pipeline, return_classes=True)) +test_dataloader = val_dataloader + +# We did not adopt the official 24e optimizer strategy +# because the results indicate that the current strategy is superior. +optim_wrapper = dict( + _delete_=True, + type='OptimWrapper', + optimizer=dict( + type='AdamW', lr=0.00002, betas=(0.9, 0.999), weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'absolute_pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) + }), + clip_grad=None) diff --git a/configs/glip/glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco.py b/configs/glip/glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco.py new file mode 100644 index 00000000000..3487de3f3a2 --- /dev/null +++ b/configs/glip/glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco.py @@ -0,0 +1,9 @@ +_base_ = './glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco.py' + +model = dict(bbox_head=dict(early_fuse=True, use_checkpoint=True)) + +load_from = 'https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_b_mmdet-6dfbd102.pth' # noqa + +optim_wrapper = dict( + optimizer=dict(lr=0.00001), + clip_grad=dict(_delete_=True, max_norm=1, norm_type=2)) diff --git a/configs/glip/glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco.py b/configs/glip/glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco.py new file mode 100644 index 00000000000..5c315e490e7 --- /dev/null +++ b/configs/glip/glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco.py @@ -0,0 +1,3 @@ +_base_ = './glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco.py' + +load_from = 'https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_c_mmdet-2fc427dd.pth' # noqa diff --git a/configs/glip/glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco.py b/configs/glip/glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco.py new file mode 100644 index 00000000000..3391272e608 --- /dev/null +++ b/configs/glip/glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco.py @@ -0,0 +1,3 @@ +_base_ = './glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco.py' + +load_from = 'https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_mmdet-c24ce662.pth' # noqa diff --git a/configs/glip/metafile.yml b/configs/glip/metafile.yml index 588d1c8d6b8..6fc245604aa 100644 --- a/configs/glip/metafile.yml +++ b/configs/glip/metafile.yml @@ -64,3 +64,48 @@ Models: Metrics: box AP: 51.3 Weights: https://download.openmmlab.com/mmdetection/v3.0/glip/glip_l_mmdet-abfe026b.pth + - Name: glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco + In Collection: GLIP + Config: configs/glip/glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco.py + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 53.1 + Weights: https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230908_091856-39f01d03.pth + - Name: glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco + In Collection: GLIP + Config: configs/glip/glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco.py + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 54.1 + Weights: https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230909_175354-e0c0c6d7.pth + - Name: glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco + In Collection: GLIP + Config: configs/glip/glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco.py + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 55.2 + Weights: https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230909_175543-5fcb4b97.pth + - Name: glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco + In Collection: GLIP + Config: configs/glip/glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco.py + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 55.2 + Weights: https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230910_125111-ad1025a0.pth + - Name: glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco + In Collection: GLIP + Config: configs/glip/glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco.py + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 59.4 + Weights: https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230910_100800-e9be4274.pth diff --git a/docs/en/user_guides/inference.md b/docs/en/user_guides/inference.md index 3a531ede1d6..49186d23695 100644 --- a/docs/en/user_guides/inference.md +++ b/docs/en/user_guides/inference.md @@ -395,10 +395,10 @@ Demo result will be similar to this: -If users would like to detect multiple targets, please declare them in the format of `xx . xx .` after the `--texts`. +If users would like to detect multiple targets, please declare them in the format of `xx. xx` after the `--texts`. ```shell -python demo/image_demo.py demo/demo.jpg glip_tiny_a_mmdet-b3654169.pth --texts 'bench . car .' +python demo/image_demo.py demo/demo.jpg glip_tiny_a_mmdet-b3654169.pth --texts 'bench. car' ``` And the result will be like this one: @@ -438,20 +438,3 @@ python tools/test.py configs/glip/glip_atss_swin-t_fpn_dyhead_pretrain_obj365.py # 8 GPU ./tools/dist_test.sh configs/glip/glip_atss_swin-t_fpn_dyhead_pretrain_obj365.py glip_tiny_a_mmdet-b3654169.pth 8 ``` - -The result will be similar to this: - -```shell -Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.428 -Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=1000 ] = 0.594 -Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=1000 ] = 0.466 -Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.300 -Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.477 -Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.534 -Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.634 -Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=300 ] = 0.634 -Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=1000 ] = 0.634 -Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.473 -Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.690 -Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.789 -``` diff --git a/docs/zh_cn/user_guides/inference.md b/docs/zh_cn/user_guides/inference.md index 206e3bfde59..a0fb08faeb0 100644 --- a/docs/zh_cn/user_guides/inference.md +++ b/docs/zh_cn/user_guides/inference.md @@ -393,10 +393,10 @@ demo 效果如下图所示: -如果想进行多种类型的识别,需要使用 `xx . xx .` 的格式在 `--texts` 字段后声明目标类型: +如果想进行多种类型的识别,需要使用 `xx. xx` 的格式在 `--texts` 字段后声明目标类型: ```shell -python demo/image_demo.py demo/demo.jpg glip_tiny_a_mmdet-b3654169.pth --texts 'bench . car .' +python demo/image_demo.py demo/demo.jpg glip_tiny_a_mmdet-b3654169.pth --texts 'bench. car' ``` 结果如下图所示: @@ -436,20 +436,3 @@ python tools/test.py configs/glip/glip_atss_swin-t_fpn_dyhead_pretrain_obj365.py # 8 GPU ./tools/dist_test.sh configs/glip/glip_atss_swin-t_fpn_dyhead_pretrain_obj365.py glip_tiny_a_mmdet-b3654169.pth 8 ``` - -验证结果大致如下: - -```shell -Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.428 -Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=1000 ] = 0.594 -Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=1000 ] = 0.466 -Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.300 -Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.477 -Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.534 -Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.634 -Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=300 ] = 0.634 -Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=1000 ] = 0.634 -Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.473 -Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.690 -Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.789 -``` diff --git a/mmdet/datasets/transforms/__init__.py b/mmdet/datasets/transforms/__init__.py index b5ab3758382..1f30d6c1352 100644 --- a/mmdet/datasets/transforms/__init__.py +++ b/mmdet/datasets/transforms/__init__.py @@ -13,6 +13,7 @@ LoadEmptyAnnotations, LoadImageFromNDArray, LoadMultiChannelImageFromFiles, LoadPanopticAnnotations, LoadProposals, LoadTrackAnnotations) +from .transformers_glip import GTBoxSubOne_GLIP, RandomFlip_GLIP from .transforms import (Albu, CachedMixUp, CachedMosaic, CopyPaste, CutOut, Expand, FixScaleResize, FixShapeResize, MinIoURandomCrop, MixUp, Mosaic, Pad, @@ -37,5 +38,6 @@ 'LoadEmptyAnnotations', 'RandomOrder', 'CachedMosaic', 'CachedMixUp', 'FixShapeResize', 'ProposalBroadcaster', 'InferencerLoader', 'LoadTrackAnnotations', 'BaseFrameSample', 'UniformRefFrameSample', - 'PackTrackInputs', 'PackReIDInputs', 'FixScaleResize', 'ResizeShortestEdge' + 'PackTrackInputs', 'PackReIDInputs', 'FixScaleResize', + 'ResizeShortestEdge', 'GTBoxSubOne_GLIP', 'RandomFlip_GLIP' ] diff --git a/mmdet/datasets/transforms/transformers_glip.py b/mmdet/datasets/transforms/transformers_glip.py new file mode 100644 index 00000000000..60c4f87d1b8 --- /dev/null +++ b/mmdet/datasets/transforms/transformers_glip.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +import numpy as np +from mmcv.transforms import BaseTransform + +from mmdet.registry import TRANSFORMS +from mmdet.structures.bbox import HorizontalBoxes, autocast_box_type +from .transforms import RandomFlip + + +@TRANSFORMS.register_module() +class GTBoxSubOne_GLIP(BaseTransform): + """Subtract 1 from the x2 and y2 coordinates of the gt_bboxes.""" + + def transform(self, results: dict) -> dict: + if 'gt_bboxes' in results: + gt_bboxes = results['gt_bboxes'] + if isinstance(gt_bboxes, np.ndarray): + gt_bboxes[:, 2:] -= 1 + results['gt_bboxes'] = gt_bboxes + elif isinstance(gt_bboxes, HorizontalBoxes): + gt_bboxes = results['gt_bboxes'].tensor + gt_bboxes[:, 2:] -= 1 + results['gt_bboxes'] = HorizontalBoxes(gt_bboxes) + else: + raise NotImplementedError + return results + + +@TRANSFORMS.register_module() +class RandomFlip_GLIP(RandomFlip): + """Flip the image & bboxes & masks & segs horizontally or vertically. + + When using horizontal flipping, the corresponding bbox x-coordinate needs + to be additionally subtracted by one. + """ + + @autocast_box_type() + def _flip(self, results: dict) -> None: + """Flip images, bounding boxes, and semantic segmentation map.""" + # flip image + results['img'] = mmcv.imflip( + results['img'], direction=results['flip_direction']) + + img_shape = results['img'].shape[:2] + + # flip bboxes + if results.get('gt_bboxes', None) is not None: + results['gt_bboxes'].flip_(img_shape, results['flip_direction']) + # Only change this line + if results['flip_direction'] == 'horizontal': + results['gt_bboxes'].translate_([-1, 0]) + + # TODO: check it + # flip masks + if results.get('gt_masks', None) is not None: + results['gt_masks'] = results['gt_masks'].flip( + results['flip_direction']) + + # flip segs + if results.get('gt_seg_map', None) is not None: + results['gt_seg_map'] = mmcv.imflip( + results['gt_seg_map'], direction=results['flip_direction']) + + # record homography matrix for flip + self._record_homography_matrix(results) diff --git a/mmdet/datasets/transforms/transforms.py b/mmdet/datasets/transforms/transforms.py index 97b3f636934..4ac2bf75b54 100644 --- a/mmdet/datasets/transforms/transforms.py +++ b/mmdet/datasets/transforms/transforms.py @@ -7,7 +7,6 @@ import cv2 import mmcv -import numpy import numpy as np from mmcv.image import imresize from mmcv.image.geometric import _scale_size diff --git a/mmdet/models/dense_heads/atss_head.py b/mmdet/models/dense_heads/atss_head.py index fcccc2fef92..2ce71b3eff5 100644 --- a/mmdet/models/dense_heads/atss_head.py +++ b/mmdet/models/dense_heads/atss_head.py @@ -281,7 +281,7 @@ def loss_by_feat( Returns: dict[str, Tensor]: A dictionary of loss components. """ - featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + featmap_sizes = [featmap.size()[-2:] for featmap in bbox_preds] assert len(featmap_sizes) == self.prior_generator.num_levels device = cls_scores[0].device diff --git a/mmdet/models/dense_heads/atss_vlfusion_head.py b/mmdet/models/dense_heads/atss_vlfusion_head.py index 5dadc4c4975..c5cd28b4a04 100644 --- a/mmdet/models/dense_heads/atss_vlfusion_head.py +++ b/mmdet/models/dense_heads/atss_vlfusion_head.py @@ -20,9 +20,10 @@ from mmdet.registry import MODELS from mmdet.structures.bbox import cat_boxes -from mmdet.utils import InstanceList +from mmdet.utils import InstanceList, OptInstanceList, reduce_mean from ..utils import (BertEncoderLayer, VLFuse, filter_scores_and_topk, - permute_and_flatten, select_single_mlvl) + permute_and_flatten, select_single_mlvl, + unpack_gt_instances) from ..utils.vlfuse_helper import MAX_CLAMP_VALUE from .atss_head import ATSSHead @@ -389,8 +390,9 @@ def __init__(self, use_checkpoint: bool = False, num_dyhead_blocks: int = 6, lang_model_name: str = 'bert-base-uncased', + init_cfg=None, **kwargs): - super().__init__(*args, **kwargs) + super().__init__(*args, **kwargs, init_cfg=init_cfg) self.head = VLFusionModule( in_channels=self.in_channels, feat_channels=self.feat_channels, @@ -399,6 +401,7 @@ def __init__(self, use_checkpoint=use_checkpoint, num_dyhead_blocks=num_dyhead_blocks, lang_model_name=lang_model_name) + self.text_masks = None def _init_layers(self) -> None: """No need to initialize the ATSS head layer.""" @@ -409,7 +412,309 @@ def forward(self, visual_feats: Tuple[Tensor], """Forward function.""" bbox_preds, centerness, cls_logits = self.head(visual_feats, language_feats) - return bbox_preds, centerness, cls_logits + return cls_logits, bbox_preds, centerness + + def loss(self, visual_feats: Tuple[Tensor], language_feats: dict, + batch_data_samples): + outputs = unpack_gt_instances(batch_data_samples) + (batch_gt_instances, batch_gt_instances_ignore, + batch_img_metas) = outputs + + outs = self(visual_feats, language_feats) + self.text_masks = language_feats['masks'] + loss_inputs = outs + (batch_gt_instances, batch_img_metas, + batch_gt_instances_ignore) + losses = self.loss_by_feat(*loss_inputs) + return losses + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + centernesses: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W) + centernesses (list[Tensor]): Centerness for each scale + level with shape (N, num_anchors * 1, H, W) + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + featmap_sizes = [featmap.size()[-2:] for featmap in bbox_preds] + assert len(featmap_sizes) == self.prior_generator.num_levels + + device = cls_scores[0].device + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + + cls_reg_targets = self.get_targets( + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore) + + (anchor_list, labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, avg_factor) = cls_reg_targets + avg_factor = reduce_mean( + torch.tensor(avg_factor, dtype=torch.float, device=device)).item() + + anchors = torch.cat(anchor_list, dim=1) + labels = torch.cat(labels_list, dim=1) + label_weights = torch.cat(label_weights_list, dim=1) + bbox_targets = torch.cat(bbox_targets_list, dim=1) + cls_scores = torch.cat(cls_scores, dim=1) + + centernesses_ = [] + bbox_preds_ = [] + for bbox_pred, centerness in zip(bbox_preds, centernesses): + centernesses_.append( + centerness.permute(0, 2, 3, + 1).reshape(cls_scores.size(0), -1, 1)) + bbox_preds_.append( + bbox_pred.permute(0, 2, 3, + 1).reshape(cls_scores.size(0), -1, 4)) + bbox_preds = torch.cat(bbox_preds_, dim=1) + centernesses = torch.cat(centernesses_, dim=1) + + losses_cls, losses_bbox, loss_centerness, bbox_avg_factor = \ + self._loss_by_feat( + anchors, + cls_scores, + bbox_preds, + centernesses, + labels, + label_weights, + bbox_targets, + avg_factor=avg_factor) + + bbox_avg_factor = reduce_mean(bbox_avg_factor).clamp_(min=1).item() + losses_bbox = losses_bbox / bbox_avg_factor + return dict( + loss_cls=losses_cls, + loss_bbox=losses_bbox, + loss_centerness=loss_centerness) + + def _loss_by_feat(self, anchors: Tensor, cls_score: Tensor, + bbox_pred: Tensor, centerness: Tensor, labels: Tensor, + label_weights: Tensor, bbox_targets: Tensor, + avg_factor: float) -> dict: + """Calculate the loss of all scale level based on the features + extracted by the detection head. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + + anchors = anchors.reshape(-1, 4) + + # ===== this change ===== + pos_inds = (labels.sum(-1) > 0).reshape(-1) + + # Loss is not computed for the padded regions of the text. + assert (self.text_masks.dim() == 2) + text_mask = (self.text_masks > 0).unsqueeze(1) + text_mask = text_mask.repeat(1, cls_score.size(1), 1) + cls_score = torch.masked_select(cls_score, text_mask).contiguous() + labels = torch.masked_select(labels, text_mask) + label_weights = label_weights[..., + None].repeat(1, 1, text_mask.size(-1)) + label_weights = torch.masked_select(label_weights, text_mask) + + bbox_pred = bbox_pred.reshape(-1, 4) + centerness = centerness.reshape(-1) + bbox_targets = bbox_targets.reshape(-1, 4) + labels = labels.reshape(-1) + label_weights = label_weights.reshape(-1) + + # classification loss + loss_cls = self.loss_cls( + cls_score, labels, label_weights, avg_factor=avg_factor) + + if pos_inds.sum() > 0: + pos_bbox_targets = bbox_targets[pos_inds] + pos_bbox_pred = bbox_pred[pos_inds] + pos_anchors = anchors[pos_inds] + pos_centerness = centerness[pos_inds] + + centerness_targets = self.centerness_target( + pos_anchors, pos_bbox_targets) + + if torch.isnan(centerness_targets).any(): + print('=====Centerness includes NaN=====') + mask = ~torch.isnan(centerness_targets) + centerness_targets = centerness_targets[mask] + pos_centerness = pos_centerness[mask] + pos_anchors = pos_anchors[mask] + pos_bbox_targets = pos_bbox_targets[mask] + pos_bbox_pred = pos_bbox_pred[mask] + + if pos_bbox_targets.shape[0] == 0: + loss_bbox = bbox_pred.sum() * 0 + loss_centerness = centerness.sum() * 0 + centerness_targets = bbox_targets.new_tensor(0.) + return loss_cls, loss_bbox, loss_centerness, \ + centerness_targets.sum() + + # The decoding process takes the offset into consideration. + pos_anchors[:, 2:] += 1 + pos_decode_bbox_pred = self.bbox_coder.decode( + pos_anchors, pos_bbox_pred) + + # regression loss + loss_bbox = self.loss_bbox( + pos_decode_bbox_pred, + pos_bbox_targets, + weight=centerness_targets, + avg_factor=1.0) + + # centerness loss + loss_centerness = self.loss_centerness( + pos_centerness, centerness_targets, avg_factor=avg_factor) + else: + loss_bbox = bbox_pred.sum() * 0 + loss_centerness = centerness.sum() * 0 + centerness_targets = bbox_targets.new_tensor(0.) + + return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum() + + def _get_targets_single(self, + flat_anchors: Tensor, + valid_flags: Tensor, + num_level_anchors: List[int], + gt_instances: InstanceData, + img_meta: dict, + gt_instances_ignore: Optional[InstanceData] = None, + unmap_outputs: bool = True) -> tuple: + """Compute regression, classification targets for anchors in a single + image. + + Args: + flat_anchors (Tensor): Multi-level anchors of the image, which are + concatenated into a single tensor of shape (num_anchors ,4) + valid_flags (Tensor): Multi level valid flags of the image, + which are concatenated into a single tensor of + shape (num_anchors,). + num_level_anchors (List[int]): Number of anchors of each scale + level. + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It usually includes ``bboxes`` and ``labels`` + attributes. + img_meta (dict): Meta information for current image. + gt_instances_ignore (:obj:`InstanceData`, optional): Instances + to be ignored during training. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. + + Returns: + tuple: N is the number of total anchors in the image. + labels (Tensor): Labels of all anchors in the image with shape + (N,). + label_weights (Tensor): Label weights of all anchor in the + image with shape (N,). + bbox_targets (Tensor): BBox targets of all anchors in the + image with shape (N, 4). + bbox_weights (Tensor): BBox weights of all anchors in the + image with shape (N, 4) + pos_inds (Tensor): Indices of positive anchor with shape + (num_pos,). + neg_inds (Tensor): Indices of negative anchor with shape + (num_neg,). + sampling_result (:obj:`SamplingResult`): Sampling results. + """ + anchors = flat_anchors + # Align the official implementation + anchors[:, 2:] -= 1 + + num_level_anchors_inside = num_level_anchors + pred_instances = InstanceData(priors=anchors) + assign_result = self.assigner.assign(pred_instances, + num_level_anchors_inside, + gt_instances, gt_instances_ignore) + + sampling_result = self.sampler.sample(assign_result, pred_instances, + gt_instances) + + num_valid_anchors = anchors.shape[0] + bbox_targets = torch.zeros_like(anchors) + bbox_weights = torch.zeros_like(anchors) + + # ===== this change ===== + labels = anchors.new_full((num_valid_anchors, self.feat_channels), + 0, + dtype=torch.float32) + label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + if len(pos_inds) > 0: + if self.reg_decoded_bbox: + pos_bbox_targets = sampling_result.pos_gt_bboxes + else: + pos_bbox_targets = self.bbox_coder.encode( + sampling_result.pos_priors, sampling_result.pos_gt_bboxes) + + bbox_targets[pos_inds, :] = pos_bbox_targets + bbox_weights[pos_inds, :] = 1.0 + + # ===== this change ===== + labels[pos_inds] = gt_instances.positive_maps[ + sampling_result.pos_assigned_gt_inds] + if self.train_cfg['pos_weight'] <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = self.train_cfg['pos_weight'] + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + return (anchors, labels, label_weights, bbox_targets, bbox_weights, + pos_inds, neg_inds, sampling_result) + + def centerness_target(self, anchors: Tensor, gts: Tensor) -> Tensor: + """Calculate the centerness between anchors and gts. + + Only calculate pos centerness targets, otherwise there may be nan. + + Args: + anchors (Tensor): Anchors with shape (N, 4), "xyxy" format. + gts (Tensor): Ground truth bboxes with shape (N, 4), "xyxy" format. + + Returns: + Tensor: Centerness between anchors and gts. + """ + anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2 + anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2 + l_ = anchors_cx - gts[:, 0] + t_ = anchors_cy - gts[:, 1] + r_ = gts[:, 2] - anchors_cx + b_ = gts[:, 3] - anchors_cy + + left_right = torch.stack([l_, r_], dim=1) + top_bottom = torch.stack([t_, b_], dim=1) + centerness = torch.sqrt( + (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * + (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])) + # assert not torch.isnan(centerness).any() + return centerness def predict(self, visual_feats: Tuple[Tensor], @@ -450,9 +755,9 @@ def predict(self, return predictions def predict_by_feat(self, + cls_logits: List[Tensor], bbox_preds: List[Tensor], score_factors: List[Tensor], - cls_logits: List[Tensor], batch_img_metas: Optional[List[dict]] = None, batch_token_positive_maps: Optional[List[dict]] = None, cfg: Optional[ConfigDict] = None, @@ -466,15 +771,15 @@ def predict_by_feat(self, such as CenterNess in FCOS, IoU branch in ATSS. Args: + cls_logits (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). bbox_preds (list[Tensor]): Box energies / deltas for all scale levels, each is a 4D-tensor, has shape (batch_size, num_priors * 4, H, W). score_factors (list[Tensor], optional): Score factor for all scale level, each is a 4D-tensor, has shape (batch_size, num_priors * 1, H, W). Defaults to None. - cls_logits (list[Tensor]): Classification scores for all - scale levels, each is a 4D-tensor, has shape - (batch_size, num_priors * num_classes, H, W). batch_img_metas (list[dict], Optional): Batch image meta info. Defaults to None. batch_token_positive_maps (list[dict], Optional): Batch token diff --git a/mmdet/models/detectors/glip.py b/mmdet/models/detectors/glip.py index 5f7212f7f40..9b96eda6112 100644 --- a/mmdet/models/detectors/glip.py +++ b/mmdet/models/detectors/glip.py @@ -2,7 +2,7 @@ import copy import re import warnings -from typing import Tuple +from typing import Tuple, Union import torch from torch import Tensor @@ -206,38 +206,41 @@ def __init__(self, self.language_model = MODELS.build(language_model) self._text_prompts = None - self._positive_maps = None + self._token_positive_maps = None self._language_dict_features = None self._entities = None + self._special_tokens = '. ' - def get_tokens_positive_and_prompts( + def get_tokens_and_prompts( self, - original_caption: str, - custom_entities: bool = False) -> Tuple[dict, str]: + original_caption: Union[str, list, tuple], + custom_entities: bool = False) -> Tuple[dict, str, list]: """Get the tokens positive and prompts for the caption.""" if isinstance(original_caption, (list, tuple)) or custom_entities: if custom_entities and isinstance(original_caption, str): - if not original_caption.endswith('.'): - original_caption = original_caption + ' . ' - original_caption = original_caption.split(' . ') + if original_caption.endswith(self._special_tokens): + original_caption = original_caption.replace( + self._special_tokens, '') + original_caption = original_caption.split(self._special_tokens) original_caption = list( filter(lambda x: len(x) > 0, original_caption)) caption_string = '' tokens_positive = [] - seperation_tokens = ' . ' - for word in original_caption: + for idx, word in enumerate(original_caption): tokens_positive.append( [[len(caption_string), len(caption_string) + len(word)]]) caption_string += word - caption_string += seperation_tokens + if idx != len(original_caption) - 1: + caption_string += self._special_tokens tokenized = self.language_model.tokenizer([caption_string], return_tensors='pt') self._entities = original_caption else: - if not original_caption.endswith('.'): - original_caption = original_caption + ' . ' + if original_caption.endswith(self._special_tokens): + original_caption = original_caption.replace( + self._special_tokens, '') tokenized = self.language_model.tokenizer([original_caption], return_tensors='pt') @@ -245,10 +248,78 @@ def get_tokens_positive_and_prompts( self._entities = noun_phrases caption_string = original_caption + return tokenized, caption_string, tokens_positive + + def get_positive_map(self, tokenized, tokens_positive): positive_map = create_positive_map(tokenized, tokens_positive) positive_map_label_to_token = create_positive_map_label_to_token( positive_map, plus=1) - return positive_map_label_to_token, caption_string + return positive_map_label_to_token, positive_map + + def get_tokens_positive_and_prompts( + self, + original_caption: Union[str, list, tuple], + custom_entities: bool = False) -> Tuple[dict, str, Tensor]: + tokenized, caption_string, tokens_positive = \ + self.get_tokens_and_prompts( + original_caption, custom_entities) + positive_map_label_to_token, positive_map = self.get_positive_map( + tokenized, tokens_positive) + return positive_map_label_to_token, caption_string, positive_map + + def loss(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> Union[dict, list]: + # TODO: Only open vocabulary tasks are supported for training now. + text_prompts = [ + data_samples.text for data_samples in batch_data_samples + ] + + gt_labels = [ + data_samples.gt_instances.labels + for data_samples in batch_data_samples + ] + + new_text_prompts = [] + positive_maps = [] + if len(set(text_prompts)) == 1: + # All the text prompts are the same, + # so there is no need to calculate them multiple times. + tokenized, caption_string, tokens_positive = \ + self.get_tokens_and_prompts( + text_prompts[0], True) + new_text_prompts = [caption_string] * len(batch_inputs) + for gt_label in gt_labels: + new_tokens_positive = [ + tokens_positive[label] for label in gt_label + ] + _, positive_map = self.get_positive_map( + tokenized, new_tokens_positive) + positive_maps.append(positive_map) + else: + for text_prompt, gt_label in zip(text_prompts, gt_labels): + tokenized, caption_string, tokens_positive = \ + self.get_tokens_and_prompts( + text_prompt, True) + new_tokens_positive = [ + tokens_positive[label] for label in gt_label + ] + _, positive_map = self.get_positive_map( + tokenized, new_tokens_positive) + positive_maps.append(positive_map) + new_text_prompts.append(caption_string) + + language_dict_features = self.language_model(new_text_prompts) + for i, data_samples in enumerate(batch_data_samples): + # .bool().float() is very important + positive_map = positive_maps[i].to( + batch_inputs.device).bool().float() + data_samples.gt_instances.positive_maps = positive_map + + visual_features = self.extract_feat(batch_inputs) + + losses = self.bbox_head.loss(visual_features, language_dict_features, + batch_data_samples) + return losses def predict(self, batch_inputs: Tensor, @@ -307,12 +378,13 @@ def predict(self, for text_prompt in text_prompts ] - self._positive_maps, text_prompts = zip( + self._token_positive_maps, text_prompts, _ = zip( *_positive_maps_and_prompts) - self._language_dict_features = self.language_model(text_prompts) + self._language_dict_features = self.language_model( + list(text_prompts)) for i, data_samples in enumerate(batch_data_samples): - data_samples.token_positive_map = self._positive_maps[i] + data_samples.token_positive_map = self._token_positive_maps[i] visual_features = self.extract_feat(batch_inputs) diff --git a/mmdet/models/necks/__init__.py b/mmdet/models/necks/__init__.py index 2194780c853..343fbfefbd8 100644 --- a/mmdet/models/necks/__init__.py +++ b/mmdet/models/necks/__init__.py @@ -8,6 +8,7 @@ from .fpg import FPG from .fpn import FPN from .fpn_carafe import FPN_CARAFE +from .fpn_dropblock import FPN_DropBlock from .hrfpn import HRFPN from .nas_fpn import NASFPN from .nasfcos_fpn import NASFCOS_FPN @@ -21,5 +22,6 @@ __all__ = [ 'FPN', 'BFP', 'ChannelMapper', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN', 'NASFCOS_FPN', 'RFP', 'YOLOV3Neck', 'FPG', 'DilatedEncoder', - 'CTResNetNeck', 'SSDNeck', 'YOLOXPAFPN', 'DyHead', 'CSPNeXtPAFPN', 'SSH' + 'CTResNetNeck', 'SSDNeck', 'YOLOXPAFPN', 'DyHead', 'CSPNeXtPAFPN', 'SSH', + 'FPN_DropBlock' ] diff --git a/mmdet/models/necks/fpn_dropblock.py b/mmdet/models/necks/fpn_dropblock.py new file mode 100644 index 00000000000..473af924cda --- /dev/null +++ b/mmdet/models/necks/fpn_dropblock.py @@ -0,0 +1,90 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple + +import torch.nn.functional as F +from torch import Tensor + +from mmdet.registry import MODELS +from .fpn import FPN + + +@MODELS.register_module() +class FPN_DropBlock(FPN): + + def __init__(self, + *args, + plugin: Optional[dict] = dict( + type='DropBlock', + drop_prob=0.3, + block_size=3, + warmup_iters=0), + **kwargs) -> None: + super().__init__(*args, **kwargs) + self.plugin = None + if plugin is not None: + self.plugin = MODELS.build(plugin) + + def forward(self, inputs: Tuple[Tensor]) -> tuple: + """Forward function. + + Args: + inputs (tuple[Tensor]): Features from the upstream network, each + is a 4D-tensor. + + Returns: + tuple: Feature maps, each is a 4D-tensor. + """ + assert len(inputs) == len(self.in_channels) + + # build laterals + laterals = [ + lateral_conv(inputs[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + # In some cases, fixing `scale factor` (e.g. 2) is preferred, but + # it cannot co-exist with `size` in `F.interpolate`. + if 'scale_factor' in self.upsample_cfg: + # fix runtime error of "+=" inplace operation in PyTorch 1.10 + laterals[i - 1] = laterals[i - 1] + F.interpolate( + laterals[i], **self.upsample_cfg) + else: + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + F.interpolate( + laterals[i], size=prev_shape, **self.upsample_cfg) + + if self.plugin is not None: + laterals[i - 1] = self.plugin(laterals[i - 1]) + + # build outputs + # part 1: from original levels + outs = [ + self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) + ] + # part 2: add extra levels + if self.num_outs > len(outs): + # use max pool to get more levels on top of outputs + # (e.g., Faster R-CNN, Mask R-CNN) + if not self.add_extra_convs: + for i in range(self.num_outs - used_backbone_levels): + outs.append(F.max_pool2d(outs[-1], 1, stride=2)) + # add conv layers on top of original feature maps (RetinaNet) + else: + if self.add_extra_convs == 'on_input': + extra_source = inputs[self.backbone_end_level - 1] + elif self.add_extra_convs == 'on_lateral': + extra_source = laterals[-1] + elif self.add_extra_convs == 'on_output': + extra_source = outs[-1] + else: + raise NotImplementedError + outs.append(self.fpn_convs[used_backbone_levels](extra_source)) + for i in range(used_backbone_levels + 1, self.num_outs): + if self.relu_before_extra_convs: + outs.append(self.fpn_convs[i](F.relu(outs[-1]))) + else: + outs.append(self.fpn_convs[i](outs[-1])) + return tuple(outs) diff --git a/mmdet/models/task_modules/assigners/__init__.py b/mmdet/models/task_modules/assigners/__init__.py index a98b0ed499a..bd71020e56e 100644 --- a/mmdet/models/task_modules/assigners/__init__.py +++ b/mmdet/models/task_modules/assigners/__init__.py @@ -7,7 +7,7 @@ from .dynamic_soft_label_assigner import DynamicSoftLabelAssigner from .grid_assigner import GridAssigner from .hungarian_assigner import HungarianAssigner -from .iou2d_calculator import BboxOverlaps2D +from .iou2d_calculator import BboxOverlaps2D, BboxOverlaps2D_GLIP from .match_cost import (BBoxL1Cost, ClassificationCost, CrossEntropyLossCost, DiceCost, FocalLossCost, IoUCost) from .max_iou_assigner import MaxIoUAssigner @@ -26,5 +26,5 @@ 'TaskAlignedAssigner', 'TopkHungarianAssigner', 'BBoxL1Cost', 'ClassificationCost', 'CrossEntropyLossCost', 'DiceCost', 'FocalLossCost', 'IoUCost', 'BboxOverlaps2D', 'DynamicSoftLabelAssigner', - 'MultiInstanceAssigner' + 'MultiInstanceAssigner', 'BboxOverlaps2D_GLIP' ] diff --git a/mmdet/models/task_modules/assigners/iou2d_calculator.py b/mmdet/models/task_modules/assigners/iou2d_calculator.py index 0e85d1e422c..b6daa94feb4 100644 --- a/mmdet/models/task_modules/assigners/iou2d_calculator.py +++ b/mmdet/models/task_modules/assigners/iou2d_calculator.py @@ -66,3 +66,23 @@ def __repr__(self): repr_str = self.__class__.__name__ + f'(' \ f'scale={self.scale}, dtype={self.dtype})' return repr_str + + +@TASK_UTILS.register_module() +class BboxOverlaps2D_GLIP(BboxOverlaps2D): + + def __call__(self, bboxes1, bboxes2, mode='iou', is_aligned=False): + TO_REMOVE = 1 + area1 = (bboxes1[:, 2] - bboxes1[:, 0] + TO_REMOVE) * ( + bboxes1[:, 3] - bboxes1[:, 1] + TO_REMOVE) + area2 = (bboxes2[:, 2] - bboxes2[:, 0] + TO_REMOVE) * ( + bboxes2[:, 3] - bboxes2[:, 1] + TO_REMOVE) + + lt = torch.max(bboxes1[:, None, :2], bboxes2[:, :2]) # [N,M,2] + rb = torch.min(bboxes1[:, None, 2:], bboxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt + TO_REMOVE).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + iou = inter / (area1[:, None] + area2 - inter) + return iou diff --git a/mmdet/models/utils/vlfuse_helper.py b/mmdet/models/utils/vlfuse_helper.py index e33c54de749..f6112bf5051 100644 --- a/mmdet/models/utils/vlfuse_helper.py +++ b/mmdet/models/utils/vlfuse_helper.py @@ -2,7 +2,7 @@ # Modified from https://github.com/microsoft/GLIP/blob/main/maskrcnn_benchmark/utils/fuse_helper.py # noqa # and https://github.com/microsoft/GLIP/blob/main/maskrcnn_benchmark/modeling/rpn/modeling_bert.py # noqa import math -from typing import Dict, List, Optional, Tuple +from typing import Dict, Optional, Tuple import torch import torch.nn as nn @@ -94,7 +94,7 @@ def __init__(self, self.l_dim = l_dim assert ( - self.head_dim * self.num_heads == self.embed_dim + self.head_dim * self.num_heads == self.embed_dim ), 'embed_dim must be divisible by num_heads ' \ f'(got `embed_dim`: {self.embed_dim} ' \ f'and `num_heads`: {self.num_heads}).' @@ -288,13 +288,15 @@ def __init__(self, self.gamma_l = nn.Parameter( init_values * torch.ones(l_dim), requires_grad=True) - def forward( - self, - visual_features: Tuple[Tensor], - lang_feature: Tensor, - attention_mask_l: Optional[Tensor] = None - ) -> Tuple[List[Tensor], Tensor]: - + def forward(self, + vf0: Tensor, + vf1: Tensor, + vf2: Tensor, + vf3: Tensor, + vf4: Tensor, + lang_feature: Tensor, + attention_mask_l=None): + visual_features = [vf0, vf1, vf2, vf3, vf4] size_per_level, visual_features_flatten = [], [] for i, feat_per_level in enumerate(visual_features): bs, c, h, w = feat_per_level.shape @@ -310,15 +312,16 @@ def forward( new_v = new_v.transpose(1, 2).contiguous() start = 0 - fusion_visual_features = [] + # fvfs is mean fusion_visual_features + fvfs = [] for (h, w) in size_per_level: new_v_per_level = new_v[:, :, start:start + h * w].view(bs, -1, h, w).contiguous() - fusion_visual_features.append(new_v_per_level) + fvfs.append(new_v_per_level) start += h * w - return fusion_visual_features, new_lang_feature + return fvfs[0], fvfs[1], fvfs[2], fvfs[3], fvfs[4], new_lang_feature def single_attention_call( self, @@ -387,19 +390,23 @@ def forward(self, x: dict) -> dict: language_dict_features = x['lang'] if self.use_checkpoint: - fused_visual_features, language_features = checkpoint.checkpoint( - self.b_attn, visual_features, language_dict_features['hidden'], + # vf is mean visual_features + # checkpoint does not allow complex data structures as input, + # such as list, so we must split them. + vf0, vf1, vf2, vf3, vf4, language_features = checkpoint.checkpoint( + self.b_attn, *visual_features, + language_dict_features['hidden'], language_dict_features['masks']) else: - fused_visual_features, language_features = self.b_attn( - visual_features, language_dict_features['hidden'], + vf0, vf1, vf2, vf3, vf4, language_features = self.b_attn( + *visual_features, language_dict_features['hidden'], language_dict_features['masks']) language_dict_features['hidden'] = language_features fused_language_dict_features = language_dict_features features_dict = { - 'visual': fused_visual_features, + 'visual': [vf0, vf1, vf2, vf3, vf4], 'lang': fused_language_dict_features }