Support GLIP Finetune (#10866)
hhaAndroid authored Sep 11, 2023
1 parent af816d3 commit 82d2a6e
Showing 20 changed files with 864 additions and 97 deletions.
31 changes: 20 additions & 11 deletions configs/glip/README.md
@@ -31,25 +31,34 @@ wget https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_a_mmdet-b365

python demo/image_demo.py demo/demo.jpg \
configs/glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py \
--weights glip_tiny_a_mmdet-b3654169.pth \
--texts 'bench. car'
```

<div align=center>
<img src="https://github.com/open-mmlab/mmdetection/assets/17425982/de370086-a5ae-4b77-8cbd-4592abf4afb0" width="40%"/>
<img src="https://github.com/open-mmlab/mmdetection/assets/17425982/7b450d96-81ac-462a-92bc-0d4ae7b8721c" width="40%"/>
</div>
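
The `--texts` argument is a single string in which category names are separated by `. `, as in the command above. A small hypothetical helper (not part of this commit) for building such a prompt from a list of class names:

```python
# Hypothetical helper: build the GLIP text prompt used by the demo command above.
categories = ['bench', 'car']
text_prompt = '. '.join(categories)  # -> 'bench. car'
print(text_prompt)
```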

## Results and Models

| Model      | Zero-shot or Finetune | COCO mAP | Pre-Train Data             | Config                                                                   | Download |
| :--------: | :-------------------: | :------: | :------------------------: | :----------------------------------------------------------------------: | :------: |
| GLIP-T (A) | Zero-shot             | 43.0     | O365                       | [config](glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py)               | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_a_mmdet-b3654169.pth) |
| GLIP-T (A) | Finetune              | 53.1     | O365                       | [config](glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco.py)      | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230908_091856-39f01d03.pth)\| [log](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230908_091856.log.json) |
| GLIP-T (B) | Zero-shot             | 44.9     | O365                       | [config](glip_atss_swin-t_b_fpn_dyhead_pretrain_obj365.py)               | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_b_mmdet-6dfbd102.pth) |
| GLIP-T (B) | Finetune              | 54.1     | O365                       | [config](glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco.py)      | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230909_175354-e0c0c6d7.pth)\| [log](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230909_175354.log.json) |
| GLIP-T (C) | Zero-shot             | 46.7     | O365,GoldG                 | [config](glip_atss_swin-t_c_fpn_dyhead_pretrain_obj365-goldg.py)         | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_c_mmdet-2fc427dd.pth) |
| GLIP-T (C) | Finetune              | 55.2     | O365,GoldG                 | [config](glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco.py)      | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230909_175543-5fcb4b97.pth)\| [log](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230909_175543.log.json) |
| GLIP-T     | Zero-shot             | 46.4     | O365,GoldG,CC3M,SBU        | [config](glip_atss_swin-t_fpn_dyhead_pretrain_obj365-goldg-cc3m-sub.py)  | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_mmdet-c24ce662.pth) |
| GLIP-T     | Finetune              | 55.2     | O365,GoldG,CC3M,SBU        | [config](glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco.py)        | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230910_125111-ad1025a0.pth)\| [log](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230910_125111.log.json) |
| GLIP-L     | Zero-shot             | 51.3     | FourODs,GoldG,CC3M+12M,SBU | [config](glip_atss_swin-l_fpn_dyhead_pretrain_mixeddata.py)              | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_l_mmdet-abfe026b.pth) |
| GLIP-L     | Finetune              | 59.4     | FourODs,GoldG,CC3M+12M,SBU | [config](glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco.py)        | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230910_100800-e9be4274.pth)\| [log](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230910_100800.log.json) |

Note:

1. The zero-shot model weights are adopted from the official release and converted using the [script](../../tools/model_converters/glip_to_mmdet.py); we have not retrained the models for the time being.
2. Finetune refers to fine-tuning on the COCO 2017 dataset. The GLIP-L model is trained on 16 A100 GPUs, while the remaining models are trained on 16 NVIDIA GeForce RTX 3090 GPUs.
3. Taking the GLIP-T (A) model as an example, we trained it twice with the official code and obtained fine-tuned mAPs of 52.5 and 52.6, so the mAP achieved in our reproduction is higher than the official result. The main reason is that we modified the `weight_decay` parameter.
4. Our experiments revealed that training for 24 epochs leads to overfitting, so we report the best-performing checkpoint. If you train on a custom dataset, it is advisable to shorten the number of epochs and save the best-performing checkpoint (see the sketch after these notes).
5. Because the official fine-tuning hyperparameters for the GLIP-L model are not available, we have not yet reproduced the official accuracy. Overfitting can also occur here, so custom modifications to the data augmentation and the model may be necessary; given the high training cost, we have not investigated this further for now.
6. There is a discrepancy between evaluating the saved checkpoint and the evaluation logs produced during training: the buffers of different ranks diverge during training, but only the rank-0 weights are saved. To avoid this, set `broadcast_buffers=True` in the configuration (also shown in the sketch below).
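
For fine-tuning on a custom dataset, the following override is a minimal sketch (not part of this commit) of how the advice in notes 4 and 6 can be expressed: `max_epochs=6` and `save_best='auto'` are illustrative values, `model_wrapper_cfg` assumes MMEngine's DDP-wrapper option, and dataset-specific settings (classes, annotation paths) are omitted. Training is then launched as usual, e.g. `bash tools/dist_train.sh <your_config>.py <num_gpus>`.

```python
# Minimal sketch: override the GLIP-T (A) fine-tune config for a custom dataset.
# Assumes MMEngine's CheckpointHook `save_best` option and `model_wrapper_cfg`.
_base_ = './glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco.py'

# Shorten the schedule to reduce overfitting (note 4); 6 is an illustrative value.
train_cfg = dict(max_epochs=6)

# Keep the checkpoint that scores best on the validation metric.
default_hooks = dict(
    checkpoint=dict(type='CheckpointHook', interval=1, save_best='auto'))

# Keep buffers consistent across ranks so the saved rank-0 weights match the
# evaluation results logged during training (note 6).
model_wrapper_cfg = dict(
    type='MMDistributedDataParallel', broadcast_buffers=True)
```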
14 changes: 14 additions & 0 deletions configs/glip/glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco.py
@@ -0,0 +1,14 @@
_base_ = './glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco.py'

model = dict(
backbone=dict(
embed_dims=192,
depths=[2, 2, 18, 2],
num_heads=[6, 12, 24, 48],
window_size=12,
drop_path_rate=0.4,
),
neck=dict(in_channels=[384, 768, 1536]),
bbox_head=dict(early_fuse=True, num_dyhead_blocks=8, use_checkpoint=True))

load_from = 'https://download.openmmlab.com/mmdetection/v3.0/glip/glip_l_mmdet-abfe026b.pth' # noqa
155 changes: 155 additions & 0 deletions configs/glip/glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco.py
@@ -0,0 +1,155 @@
_base_ = [
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
load_from = 'https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_a_mmdet-b3654169.pth' # noqa
lang_model_name = 'bert-base-uncased'

model = dict(
type='GLIP',
data_preprocessor=dict(
type='DetDataPreprocessor',
mean=[103.53, 116.28, 123.675],
std=[57.375, 57.12, 58.395],
bgr_to_rgb=False,
pad_size_divisor=32),
backbone=dict(
type='SwinTransformer',
embed_dims=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.2,
patch_norm=True,
out_indices=(1, 2, 3),
with_cp=False,
convert_weights=False),
neck=dict(
type='FPN_DropBlock',
in_channels=[192, 384, 768],
out_channels=256,
start_level=0,
relu_before_extra_convs=True,
add_extra_convs='on_output',
num_outs=5),
bbox_head=dict(
type='ATSSVLFusionHead',
lang_model_name=lang_model_name,
num_classes=80,
in_channels=256,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128],
center_offset=0.5),
bbox_coder=dict(
type='DeltaXYWHBBoxCoderForGLIP',
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
language_model=dict(type='BertModel', name=lang_model_name),
train_cfg=dict(
assigner=dict(
type='ATSSAssigner',
topk=9,
iou_calculator=dict(type='BboxOverlaps2D_GLIP')),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100))

# dataset settings
train_pipeline = [
dict(
type='LoadImageFromFile',
imdecode_backend='pillow',
backend_args=_base_.backend_args),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='GTBoxSubOne_GLIP'),
dict(
type='RandomChoiceResize',
scales=[(1333, 480), (1333, 560), (1333, 640), (1333, 720),
(1333, 800)],
keep_ratio=True,
resize_type='FixScaleResize',
backend='pillow'),
dict(type='RandomFlip_GLIP', prob=0.5),
dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1)),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'flip', 'flip_direction', 'text',
'custom_entities'))
]

test_pipeline = [
dict(
type='LoadImageFromFile',
backend_args=_base_.backend_args,
imdecode_backend='pillow'),
dict(
type='FixScaleResize',
scale=(800, 1333),
keep_ratio=True,
backend='pillow'),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'text', 'custom_entities'))
]

train_dataloader = dict(
dataset=dict(
_delete_=True,
type='RepeatDataset',
times=2,
dataset=dict(
type=_base_.dataset_type,
data_root=_base_.data_root,
ann_file='annotations/instances_train2017.json',
data_prefix=dict(img='train2017/'),
filter_cfg=dict(filter_empty_gt=True, min_size=32),
pipeline=train_pipeline,
return_classes=True,
backend_args=_base_.backend_args)))

val_dataloader = dict(
dataset=dict(pipeline=test_pipeline, return_classes=True))
test_dataloader = val_dataloader

# We did not adopt the official 24e optimizer strategy
# because the results indicate that the current strategy is superior.
optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(
type='AdamW', lr=0.00002, betas=(0.9, 0.999), weight_decay=0.05),
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}),
clip_grad=None)
9 changes: 9 additions & 0 deletions configs/glip/glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco.py
@@ -0,0 +1,9 @@
_base_ = './glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco.py'

model = dict(bbox_head=dict(early_fuse=True, use_checkpoint=True))

load_from = 'https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_b_mmdet-6dfbd102.pth' # noqa

optim_wrapper = dict(
optimizer=dict(lr=0.00001),
clip_grad=dict(_delete_=True, max_norm=1, norm_type=2))
3 changes: 3 additions & 0 deletions configs/glip/glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco.py
@@ -0,0 +1,3 @@
_base_ = './glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco.py'

load_from = 'https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_c_mmdet-2fc427dd.pth' # noqa
3 changes: 3 additions & 0 deletions configs/glip/glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco.py
@@ -0,0 +1,3 @@
_base_ = './glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco.py'

load_from = 'https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_mmdet-c24ce662.pth' # noqa
45 changes: 45 additions & 0 deletions configs/glip/metafile.yml
@@ -64,3 +64,48 @@ Models:
Metrics:
box AP: 51.3
Weights: https://download.openmmlab.com/mmdetection/v3.0/glip/glip_l_mmdet-abfe026b.pth
- Name: glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco
In Collection: GLIP
Config: configs/glip/glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco.py
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 53.1
Weights: https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_a_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230908_091856-39f01d03.pth
- Name: glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco
In Collection: GLIP
Config: configs/glip/glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco.py
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 54.1
Weights: https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_b_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230909_175354-e0c0c6d7.pth
- Name: glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco
In Collection: GLIP
Config: configs/glip/glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco.py
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 55.2
Weights: https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_c_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230909_175543-5fcb4b97.pth
- Name: glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco
In Collection: GLIP
Config: configs/glip/glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco.py
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 55.2
Weights: https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-t_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230910_125111-ad1025a0.pth
- Name: glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco
In Collection: GLIP
Config: configs/glip/glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco.py
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 59.4
Weights: https://download.openmmlab.com/mmdetection/v3.0/glip/glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco/glip_atss_swin-l_fpn_dyhead_16xb2_ms-2x_funtune_coco_20230910_100800-e9be4274.pth