forked from open-mmlab/mmyolo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
yolox_s_fast_8xb8-300e_coco.py
331 lines (308 loc) · 10.1 KB
/
yolox_s_fast_8xb8-300e_coco.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
_base_ = ['../_base_/default_runtime.py', 'yolox_p5_tta.py']
# ========================Frequently modified parameters======================
# -----data related-----
data_root = 'data/coco/'  # Root path of data
# Path of train annotation file, relative to data_root
train_ann_file = 'annotations/instances_train2017.json'
train_data_prefix = 'train2017/'  # Prefix of train image path
# Path of val annotation file, relative to data_root
val_ann_file = 'annotations/instances_val2017.json'
val_data_prefix = 'val2017/'  # Prefix of val image path
num_classes = 80  # Number of classes for classification
# Batch size of a single GPU during training
train_batch_size_per_gpu = 8
# Workers to pre-fetch data for each single GPU during training
train_num_workers = 8
# persistent_workers must be False if num_workers is 0
persistent_workers = True
# -----train val related-----
# Base learning rate for optim_wrapper. Corresponds to a total batch size of
# 8xb8=64 (8 GPUs x train_batch_size_per_gpu), the same base batch size that
# auto_scale_lr is configured with.
base_lr = 0.01
max_epochs = 300  # Maximum training epochs
# Post-processing settings used at test time (passed to the model as test_cfg)
model_test_cfg = dict(
    yolox_style=True,  # better
    # The config of multi-label for multi-class prediction
    multi_label=True,  # 40.5 -> 40.7
    score_thr=0.001,  # Threshold to filter out boxes
    max_per_img=300,  # Max number of detections of each image
    nms=dict(type='nms', iou_threshold=0.65))  # NMS type and threshold
# ========================Possible modified parameters========================
# -----data related-----
# (width, height) target size for the resize / Mosaic / MixUp transforms
img_scale = (640, 640)
# Dataset class shared by the train and val dataloaders
dataset_type = 'YOLOv5CocoDataset'
# Per-GPU batch size used for validation
val_batch_size_per_gpu = 1
# Per-GPU number of data-loading workers used for validation
val_num_workers = 2

# -----model related-----
# Depth multiplier handed to the backbone and the neck
deepen_factor = 0.33
# Width multiplier handed to the backbone, the neck and the head module
widen_factor = 0.5
# Normalization settings shared by backbone, neck and head
norm_cfg = dict(type='BN', momentum=0.03, eps=1e-3)
# Interval at which YOLOXBatchSyncRandomResize draws a new random resize shape
batch_augments_interval = 10

# -----train val related-----
# SGD weight decay (norm and bias params are exempted via paramwise_cfg)
weight_decay = 5e-4
# Weights of the four YOLOX head losses
loss_cls_weight = 1.0
loss_bbox_weight = 5.0
loss_obj_weight = 1.0
loss_bbox_aux_weight = 1.0
center_radius = 2.5  # passed to mmdet.SimOTAAssigner
# Length of the final training phase: pipeline switched to stage 2,
# constant LR, and validation after every epoch
num_last_epochs = 15
# scaling_ratio_range of mmdet.RandomAffine in the stage-1 pipeline
random_affine_scaling_ratio_range = (0.1, 2)
# ratio_range of YOLOXMixUp in the stage-1 pipeline
mixup_ratio_range = (0.8, 1.6)
# Epoch interval shared by checkpoint saving and validation
save_epoch_intervals = 10
# Upper bound on the number of checkpoints kept
max_keep_ckpts = 3
# Momentum of the ExpMomentumEMA maintained by EMAHook
ema_momentum = 1e-4
# ===============================Unmodified in most cases====================
# model settings
# YOLODetector = YOLOXCSPDarknet backbone -> YOLOXPAFPN neck -> YOLOXHead.
# The depth/width multipliers, norm_cfg, loss weights, center_radius and
# model_test_cfg all come from the parameter sections above.
model = dict(
    type='YOLODetector',
    # Kaiming init applied to Conv2d layers (uniform, fan_in, a = sqrt(5))
    init_cfg=dict(
        type='Kaiming',
        layer='Conv2d',
        a=2.23606797749979,  # math.sqrt(5)
        distribution='uniform',
        mode='fan_in',
        nonlinearity='leaky_relu'),
    # TODO: Waiting for mmengine support
    use_syncbn=False,
    data_preprocessor=dict(
        type='YOLOv5DetDataPreprocessor',
        pad_size_divisor=32,
        # Batch-level multi-scale augmentation. The exact meaning of
        # random_size_range / size_divisor / interval is defined by
        # YOLOXBatchSyncRandomResize (presumably random sizes within
        # [480, 800] that are multiples of 32 -- confirm).
        batch_augments=[
            dict(
                type='YOLOXBatchSyncRandomResize',
                random_size_range=(480, 800),
                size_divisor=32,
                interval=batch_augments_interval)
        ]),
    backbone=dict(
        type='YOLOXCSPDarknet',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        # Three output stages, matching the three neck in_channels below
        out_indices=(2, 3, 4),
        # NOTE: the key is spelled `kernal`; renaming it changes the kwarg
        # name passed to YOLOXCSPDarknet.
        spp_kernal_sizes=(5, 9, 13),
        norm_cfg=norm_cfg,
        act_cfg=dict(type='SiLU', inplace=True),
    ),
    neck=dict(
        type='YOLOXPAFPN',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        # NOTE(review): these look like pre-widen_factor channel counts
        # (widen_factor is passed alongside) -- confirm in YOLOXPAFPN.
        in_channels=[256, 512, 1024],
        out_channels=256,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='SiLU', inplace=True)),
    bbox_head=dict(
        type='YOLOXHead',
        head_module=dict(
            type='YOLOXHeadModule',
            num_classes=num_classes,
            in_channels=256,
            feat_channels=256,
            widen_factor=widen_factor,
            stacked_convs=2,
            # One stride per feature level (three levels, as in the neck)
            featmap_strides=(8, 16, 32),
            use_depthwise=False,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='SiLU', inplace=True),
        ),
        # All four losses use reduction='sum'; their weights are the
        # loss_*_weight parameters defined earlier in this file.
        loss_cls=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=True,
            reduction='sum',
            loss_weight=loss_cls_weight),
        loss_bbox=dict(
            type='mmdet.IoULoss',
            mode='square',
            eps=1e-16,
            reduction='sum',
            loss_weight=loss_bbox_weight),
        loss_obj=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=True,
            reduction='sum',
            loss_weight=loss_obj_weight),
        loss_bbox_aux=dict(
            type='mmdet.L1Loss',
            reduction='sum',
            loss_weight=loss_bbox_aux_weight)),
    # Label assignment during training: SimOTA with 2D IoU overlaps
    train_cfg=dict(
        assigner=dict(
            type='mmdet.SimOTAAssigner',
            center_radius=center_radius,
            iou_calculator=dict(type='mmdet.BboxOverlaps2D'))),
    # Test-time post-processing, defined in the frequently modified section
    test_cfg=model_test_cfg)
# Loading steps shared by both training pipelines and handed to Mosaic /
# YOLOXMixUp: read the image (file backend settings come from the inherited
# `_base_` configs), then its bounding-box annotations.
pre_transform = [
    dict(
        type='LoadImageFromFile',
        file_client_args=_base_.file_client_args,
    ),
    dict(
        type='LoadAnnotations',
        with_bbox=True,
    ),
]
# Stage-1 (strong augmentation) training pipeline. train_dataloader starts
# with it; YOLOXModeSwitchHook replaces it with train_pipeline_stage2 for the
# last `num_last_epochs` epochs.
train_pipeline_stage1 = pre_transform + [
    # Mosaic and YOLOXMixUp receive pre_transform, presumably to load the
    # extra images they combine with the current one.
    dict(
        type='Mosaic',
        img_scale=img_scale,
        pad_val=114.0,
        pre_transform=pre_transform,
    ),
    dict(
        type='mmdet.RandomAffine',
        scaling_ratio_range=random_affine_scaling_ratio_range,
        # img_scale is (width, height); border is (-w // 2, -h // 2)
        border=(-img_scale[0] // 2, -img_scale[1] // 2),
    ),
    dict(
        type='YOLOXMixUp',
        img_scale=img_scale,
        ratio_range=mixup_ratio_range,
        pad_val=114.0,
        pre_transform=pre_transform,
    ),
    # HSV colour jitter, then a random flip with probability 0.5
    dict(type='mmdet.YOLOXHSVRandomAug'),
    dict(type='mmdet.RandomFlip', prob=0.5),
    # Remove boxes smaller than 1x1; keep_empty=False presumably drops
    # samples left with no boxes -- see mmdet FilterAnnotations.
    dict(
        type='mmdet.FilterAnnotations',
        min_gt_bbox_wh=(1, 1),
        keep_empty=False,
    ),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
                   'flip_direction'),
    ),
]
# Stage-2 training pipeline, installed by YOLOXModeSwitchHook for the last
# `num_last_epochs` epochs. It drops Mosaic / RandomAffine / MixUp.
# Instead it does an aspect-preserving resize to img_scale and pads to a
# square, keeping the same colour jitter, flip and box filtering.
train_pipeline_stage2 = pre_transform + [
    dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True),
    dict(
        type='mmdet.Pad',
        pad_to_square=True,
        # A three-channel image needs one pad value per channel.
        pad_val=dict(img=(114.0, 114.0, 114.0)),
    ),
    dict(type='mmdet.YOLOXHSVRandomAug'),
    dict(type='mmdet.RandomFlip', prob=0.5),
    dict(
        type='mmdet.FilterAnnotations',
        min_gt_bbox_wh=(1, 1),
        keep_empty=False,
    ),
    # Uses PackDetInputs' default meta_keys (stage 1 lists them explicitly)
    dict(type='mmdet.PackDetInputs'),
]
# Training dataloader: shuffled sampling, yolov5_collate batching, and the
# stage-1 pipeline (swapped later by YOLOXModeSwitchHook).
train_dataloader = dict(
    batch_size=train_batch_size_per_gpu,
    num_workers=train_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    collate_fn=dict(type='yolov5_collate'),
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=train_ann_file,
        data_prefix=dict(img=train_data_prefix),
        # filter_empty_gt=False: images without GT boxes are not filtered
        # out at dataset level
        filter_cfg=dict(filter_empty_gt=False, min_size=32),
        pipeline=train_pipeline_stage1,
    ),
)
# Evaluation pipeline: aspect-preserving resize to img_scale, pad to a square
# with value 114, then load annotations and pack the inputs.
test_pipeline = [
    dict(
        type='LoadImageFromFile',
        file_client_args=_base_.file_client_args,
    ),
    dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True),
    dict(
        type='mmdet.Pad',
        pad_to_square=True,
        pad_val=dict(img=(114.0, 114.0, 114.0)),
    ),
    # Annotations are loaded after Resize/Pad (mmdet's LoadAnnotations via
    # _scope_). NOTE(review): the boxes are therefore presumably left in
    # original-image coordinates -- confirm this is intended for evaluation.
    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'),
    ),
]
# Validation dataloader: unshuffled, keeps the last partial batch, and runs
# the dataset in test mode through test_pipeline.
val_dataloader = dict(
    batch_size=val_batch_size_per_gpu,
    num_workers=val_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=val_ann_file,
        data_prefix=dict(img=val_data_prefix),
        test_mode=True,
        pipeline=test_pipeline,
    ),
)
# Testing reuses the very same dataloader config object.
test_dataloader = val_dataloader
# COCO bbox evaluation against the val annotation file. proposal_nums is set
# to (100, 1, 10) to reduce evaluation time.
# NOTE(review): 100 is listed first, which is unusual; confirm how CocoMetric
# maps this ordering onto COCO maxDets before relying on mAP_50 or the
# per-size metrics.
val_evaluator = dict(
    type='mmdet.CocoMetric',
    proposal_nums=(100, 1, 10),
    ann_file=data_root + val_ann_file,
    metric='bbox',
)
# Testing reuses the very same evaluator config object.
test_evaluator = val_evaluator
# Optimizer (lr tuned for the default 8-GPU setup): SGD with Nesterov
# momentum; paramwise_cfg zeroes weight decay for norm-layer and bias params.
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='SGD',
        lr=base_lr,
        momentum=0.9,
        weight_decay=weight_decay,
        nesterov=True,
    ),
    paramwise_cfg=dict(norm_decay_mult=0.0, bias_decay_mult=0.0),
)
# Learning-rate schedule, expressed in epochs:
#   [0, 5)                                      quadratic warm-up
#   [5, max_epochs - num_last_epochs)           cosine decay, eta_min = base_lr * 0.05
#   [max_epochs - num_last_epochs, max_epochs)  fixed lr (ConstantLR, factor=1)
# The first two phases are converted to per-iteration updates
# (convert_to_iter_based=True).
# NOTE(review): the cosine phase spans max_epochs - num_last_epochs - 5 epochs,
# while T_max is max_epochs - num_last_epochs, so it presumably stops slightly
# above eta_min. It is kept as-is; confirm before changing.
param_scheduler = [
    dict(
        # TODO: fix default scope in get function
        type='mmdet.QuadraticWarmupLR',
        by_epoch=True,
        begin=0,
        end=5,
        convert_to_iter_based=True,
    ),
    dict(
        type='CosineAnnealingLR',
        eta_min=base_lr * 0.05,
        begin=5,
        T_max=max_epochs - num_last_epochs,
        end=max_epochs - num_last_epochs,
        by_epoch=True,
        convert_to_iter_based=True,
    ),
    dict(
        type='ConstantLR',
        by_epoch=True,
        factor=1,
        begin=max_epochs - num_last_epochs,
        end=max_epochs,
    ),
]
# Save a checkpoint every `save_epoch_intervals` epochs, keep at most
# `max_keep_ckpts` of them, and also track the best one (save_best='auto').
default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook',
        interval=save_epoch_intervals,
        max_keep_ckpts=max_keep_ckpts,
        save_best='auto',
    ),
)
# Extra hooks:
# - YOLOXModeSwitchHook installs train_pipeline_stage2 for the last
#   `num_last_epochs` epochs.
# - mmdet.SyncNormHook runs at the same priority (48).
# - EMAHook keeps an ExpMomentumEMA of the model (buffers included), with
#   momentum `ema_momentum`.
custom_hooks = [
    dict(
        type='YOLOXModeSwitchHook',
        num_last_epochs=num_last_epochs,
        new_train_pipeline=train_pipeline_stage2,
        priority=48,
    ),
    dict(type='mmdet.SyncNormHook', priority=48),
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=ema_momentum,
        update_buffers=True,
        strict_load=False,
        priority=49,
    ),
]
# Epoch-based training loop. Validation runs every `save_epoch_intervals`
# epochs, and every epoch from epoch max_epochs - num_last_epochs onward.
train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=max_epochs,
    val_interval=save_epoch_intervals,
    dynamic_intervals=[(max_epochs - num_last_epochs, 1)],
)
# Reference total batch size for automatic LR scaling: 8 GPUs x per-GPU batch
auto_scale_lr = dict(base_batch_size=8 * train_batch_size_per_gpu)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')