Skip to content

Commit 562e783

Browse files
committed
handle Timm pretrained backbones
1 parent 88ee5a0 commit 562e783

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

icevision/models/mmdet/common/utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,30 @@ def build_model(
6060
cfg.model.roi_head.bbox_head.num_classes = num_classes - 1
6161
cfg.model.roi_head.mask_head.num_classes = num_classes - 1
6262

63-
if pretrained == True:
64-
cfg.model.pretrained = True
63+
# When using Timm backbone, loading the pretained weights are done icevision and not by mmdet code
64+
# Check out here below
65+
# Set cfg.model.pretrained to avoid mmdet loading them
66+
if pretrained == True and (
67+
isinstance(backbone, MMDetTimmBackboneConfig) and backbone.pretrained == True
68+
):
69+
cfg.model.pretrained = None # Timm pretrained backbones
6570
elif (pretrained == False) or (weights_path is not None):
6671
cfg.model.pretrained = None
6772

6873
_model = build_detector(cfg.model, cfg.get("train_cfg"), cfg.get("test_cfg"))
6974
_model.init_weights()
7075

76+
# Load pretrained weights either:
77+
# - by loading the whole pretrained model (COCO in general)
78+
# - or only the pretrained backbone like Timm ones
7179
if pretrained:
7280
if weights_path is not None:
7381
print(
7482
f"loading pretrained weights from user-provided url: {backbone.weights_url}"
7583
)
7684
load_checkpoint(_model, str(weights_path))
85+
# We handle loading Timm (backbone) pretrained weights here
86+
# the weights_url are stored in the backbone dict
7787
elif _model.backbone.weights_url is not None:
7888
weights_url = _model.backbone.weights_url
7989
print(f"loading default pretrained weights: {weights_url}")

0 commit comments

Comments
 (0)