-
Notifications
You must be signed in to change notification settings - Fork 2
/
voc_det.py
229 lines (196 loc) · 8.44 KB
/
voc_det.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
"""
Finetune a pre-trained model on a downstream task, one of those available in
Detectron2.
Supported downstream tasks:
- LVIS Instance Segmentation
- COCO Instance Segmentation
- Pascal VOC 2007+12 Object Detection
Reference: https://github.com/facebookresearch/detectron2/blob/master/tools/train_net.py
Thanks to the developers of Detectron2!
"""
import argparse
import os
import re
from typing import Any, Dict, Union
import torch
from torch.utils.tensorboard import SummaryWriter
import detectron2 as d2
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.engine import DefaultTrainer, default_setup
from detectron2.evaluation import (
LVISEvaluator,
PascalVOCDetectionEvaluator,
COCOEvaluator,
)
from detectron2.modeling.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads
from config import Config
from factories import PretrainingModelFactory
from utils.checkpointing import CheckpointManager
from utils.common import common_parser
import utils.distributed as dist
# fmt: off
parser = common_parser(
    description="Train object detectors from pretrained visual backbone."
)
parser.add_argument(
    "--d2-config", required=True,
    help="Path to a detectron2 config for downstream task finetuning."
)
parser.add_argument(
    "--d2-config-override", nargs="*", default=[],
    help="""Key-value pairs from Detectron2 config to override from file.
    Some keys will be ignored because they are set from other args or from
    the pre-training config: [DATALOADER.NUM_WORKERS,
    SOLVER.CHECKPOINT_PERIOD, OUTPUT_DIR, MODEL.RESNETS.DEPTH]""",
)

# `add_argument_group` returns the group object; arguments must be added to
# it (not to `parser`) to be listed under this heading in `--help`. Parsing
# behaves identically either way.
checkpoint_group = parser.add_argument_group("Checkpointing and Logging")
checkpoint_group.add_argument(
    "--weight-init", choices=["random", "imagenet", "torchvision", "vlinfo"],
    default="vlinfo", help="""How to initialize weights:
        1. 'random' initializes all weights randomly
        2. 'imagenet' initializes backbone weights from torchvision model zoo
        3. {'torchvision', 'vlinfo'} load state dict from --checkpoint-path
            - with 'torchvision', state dict would be from PyTorch's training
              script.
            - with 'vlinfo' it should be for our full pretrained model."""
)
checkpoint_group.add_argument(
    "--checkpoint-path",
    help="Path to load checkpoint and run downstream task evaluation."
)
checkpoint_group.add_argument(
    "--resume", action="store_true", help="""Specify this flag when resuming
    training from a checkpoint saved by Detectron2."""
)
checkpoint_group.add_argument(
    "--eval-only", action="store_true",
    help="Skip training and evaluate checkpoint provided at --checkpoint-path.",
)
checkpoint_group.add_argument(
    "--checkpoint-every", type=int, default=5000,
    help="Serialize model to a checkpoint after every these many iterations.",
)
# fmt: on
@ROI_HEADS_REGISTRY.register()
class Res5ROIHeadsExtraNorm(Res5ROIHeads):
    r"""
    Variant of detectron2's ``Res5ROIHeads`` whose ``res5`` stage is followed
    by an extra normalization layer, of the type named by
    ``cfg.MODEL.RESNETS.NORM``. Used with Faster R-CNN C4/DC5 backbones for
    VOC detection.
    """

    def _build_res5_block(self, cfg):
        # Build the stock res5 stage, then append a norm layer sized to its
        # output channels as a submodule called "norm".
        res5_stage, num_channels = super()._build_res5_block(cfg)
        res5_stage.add_module(
            "norm", d2.layers.get_norm(cfg.MODEL.RESNETS.NORM, num_channels)
        )
        return res5_stage, num_channels
def _resnet_depth_from_name(network_name: str, default: int = 50) -> int:
    r"""
    Parse ResNet depth from a visual backbone name.

    For example, ``101`` is parsed from ``"torchvision::resnet101"``, and ``50``
    from a detectron2 name containing ``"_R_50"``. Returns ``default`` when the
    name has no recognizable depth.
    """
    if "torchvision" in network_name:
        match = re.search(r"resnet(\d+)", network_name)
    elif "detectron2" in network_name:
        match = re.search(r"_R_(\d+)", network_name)
    else:
        match = None
    return int(match.group(1)) if match else default


def build_detectron2_config(_C: Config, _A: argparse.Namespace):
    r"""
    Build detectron2 config based on our pre-training config and args.

    The file at ``--d2-config`` is merged first, then ``--d2-config-override``
    pairs. Afterwards these keys are set unconditionally, so overriding them
    has no effect: ``DATALOADER.NUM_WORKERS``, ``SOLVER.CHECKPOINT_PERIOD``,
    ``OUTPUT_DIR`` and ``MODEL.RESNETS.DEPTH``.

    Parameters
    ----------
    _C: Config
        Pre-training config. ``MODEL.VISUAL.NETWORK_NAME`` (if present) decides
        the ResNet depth.
    _A: argparse.Namespace
        Parsed command-line arguments.

    Returns
    -------
    detectron2.config.CfgNode
        The assembled detectron2 config.
    """
    _D2C = d2.config.get_cfg()

    # Override some default values based on our config file.
    _D2C.merge_from_file(_A.d2_config)
    _D2C.merge_from_list(_A.d2_config_override)

    # Set some config parameters from args.
    _D2C.DATALOADER.NUM_WORKERS = _A.cpu_workers
    _D2C.SOLVER.CHECKPOINT_PERIOD = _A.checkpoint_every
    _D2C.OUTPUT_DIR = _A.checkpoints_dir

    # Set ResNet depth to override in Detectron2's config, parsed from the
    # configured backbone name. The name is read defensively because it is
    # not certain every pre-training config defines it. Names without a
    # recognizable depth fall back to 50.
    network_name = getattr(_C.MODEL.VISUAL, "NETWORK_NAME", "")
    _D2C.MODEL.RESNETS.DEPTH = _resnet_depth_from_name(str(network_name))
    return _D2C
class DownstreamTrainer(DefaultTrainer):
    r"""
    Extension of detectron2's ``DefaultTrainer`` with custom evaluator
    selection and tensorboard logging of evaluation results.

    Parameters
    ----------
    cfg: detectron2.config.CfgNode
        Detectron2 config object containing all config params.
    weights: Union[str, Dict[str, Any]]
        Weights to load in the initialized model. If ``str``, then we assume path
        to a checkpoint, or if a ``dict``, we assume a state dict. This will be
        an ``str`` only if we resume training from a Detectron2 checkpoint.
    """

    def __init__(self, cfg, weights: Union[str, Dict[str, Any]]):
        super().__init__(cfg)

        # NOTE(review): the original comment says weights are loaded "before
        # wrapping to DDP because `ApexDDP` has some weird issue with
        # `DetectionCheckpointer`". `super().__init__` has already built the
        # model at this point; confirm whether that note still applies.
        # fmt: off
        if isinstance(weights, str):
            # A path means resume training: detectron2 resumes from the last
            # checkpoint in cfg.OUTPUT_DIR if any, else loads this path.
            self.start_iter = (
                DetectionCheckpointer(
                    self._trainer.model,
                    optimizer=self._trainer.optimizer,
                    scheduler=self.scheduler
                ).resume_or_load(weights, resume=True).get("iteration", -1) + 1
            )
        elif isinstance(weights, dict):
            # A state dict means initialization from our pretrained backbone.
            DetectionCheckpointer(self._trainer.model)._load_model(weights)
        # fmt: on

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        r"""
        Return the evaluator matching the dataset's ``evaluator_type``.

        Raises
        ------
        NotImplementedError
            For unsupported evaluator types. ``DefaultTrainer.test`` catches
            this and skips the dataset with a warning.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")

        evaluator_type = d2.data.MetadataCatalog.get(dataset_name).evaluator_type
        if evaluator_type == "pascal_voc":
            return PascalVOCDetectionEvaluator(dataset_name)
        if evaluator_type == "coco":
            return COCOEvaluator(dataset_name, cfg, True, output_folder)
        if evaluator_type == "lvis":
            return LVISEvaluator(dataset_name, cfg, True, output_folder)
        raise NotImplementedError(
            f"No evaluator available for dataset {dataset_name!r} "
            f"(evaluator_type={evaluator_type!r})."
        )

    def test(self, cfg=None, model=None, evaluators=None):
        r"""
        Evaluate the model and log results to stdout and tensorboard.

        ``cfg``/``model`` default to the trainer's own. ``evaluators`` is
        forwarded to ``DefaultTrainer.test``. Scalars are logged at step
        ``self.start_iter``.

        Returns
        -------
        The results returned by ``DefaultTrainer.test``, so that callers such
        as the periodic evaluation hook receive them.
        """
        cfg = self.cfg if cfg is None else cfg
        model = self.model if model is None else model

        results = super().test(cfg, model, evaluators)
        flat_results = d2.evaluation.testing.flatten_results_dict(results)

        # Close the writer so event files are flushed and not leaked per call.
        tensorboard_writer = SummaryWriter(log_dir=cfg.OUTPUT_DIR)
        try:
            for name, value in flat_results.items():
                tensorboard_writer.add_scalar(name, value, self.start_iter)
        finally:
            tensorboard_writer.close()
        return results
def main(_A: argparse.Namespace):
    r"""
    Per-process entry point: build configs, prepare initial weights and run
    detectron2 finetuning, or evaluation only with ``--eval-only``.

    Weights handed to ``DownstreamTrainer``:
      - ``--weight-init`` in {vlinfo, torchvision} with ``--resume``: the
        ``--checkpoint-path`` string, so detectron2 resumes training.
      - ``--weight-init`` in {vlinfo, torchvision} otherwise: backbone state
        dict of our model after loading ``--checkpoint-path``.
      - ``random`` / ``imagenet``: backbone state dict of a freshly built model;
        ``--checkpoint-path`` is not loaded.

    Parameters
    ----------
    _A: argparse.Namespace
        Parsed command-line arguments.
    """
    # Local process group is needed for detectron2. NOTE(review): every rank
    # is put in one "local" group, which looks right only for single-machine
    # runs — confirm for multi-node training.
    pg = list(range(dist.get_world_size()))
    d2.utils.comm._LOCAL_PROCESS_GROUP = torch.distributed.new_group(pg)

    # Create a config object (this will be immutable) and perform common setup
    # such as logging and setting up serialization directory.
    if _A.weight_init == "imagenet":
        _A.config_override.extend(["MODEL.VISUAL.PRETRAINED", True])
    _C = Config(_A.config, _A.config_override)

    # We use `default_setup` from detectron2 to do some common setup, such as
    # logging, setting up serialization etc. For more info, look into source.
    _D2C = build_detectron2_config(_C, _A)
    default_setup(_D2C, _A)

    # Prepare weights to pass in instantiation call of trainer.
    loads_checkpoint = _A.weight_init in {"vlinfo", "torchvision"}
    model = None
    if loads_checkpoint and _A.resume:
        # Hand detectron2 a path; the trainer's `str` branch resumes from it.
        weights = _A.checkpoint_path
    else:
        model = PretrainingModelFactory.from_config(_C)
        if loads_checkpoint:
            CheckpointManager(model=model).load(_A.checkpoint_path)
        weights = model.image_encoder.detectron2_backbone_state_dict()

    # Back up pretrain config and model checkpoint (if provided).
    _C.dump(os.path.join(_A.checkpoints_dir, "pretrain_config.yaml"))
    if _A.weight_init == "vlinfo" and not _A.resume:
        torch.save(
            model.state_dict(),
            os.path.join(_A.checkpoints_dir, "pretrain_model.pth"),
        )
    del model

    trainer = DownstreamTrainer(_D2C, weights)
    if _A.eval_only:
        trainer.test()
    else:
        trainer.train()
if __name__ == "__main__":
    _A = parser.parse_args()

    # `dist.launch` runs `main` and sets the appropriate CUDA device (GPU ID)
    # per process. `main` reads that device at its start.
    launch_options = dict(
        num_machines=_A.num_machines,
        num_gpus_per_machine=_A.num_gpus_per_machine,
        machine_rank=_A.machine_rank,
        dist_url=_A.dist_url,
    )
    dist.launch(main, args=(_A,), **launch_options)