Skip to content

Commit 2544d3b

Browse files
committed
ConvNeXt atto, femto, pico, and nano weights and model defs, plus femto/pico/nano _ols (overlapping stem) variants
1 parent 13565aa commit 2544d3b

File tree

1 file changed

+83
-21
lines changed

1 file changed

+83
-21
lines changed

timm/models/convnext.py

Lines changed: 83 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,8 @@
44
55
Original code and weights from https://github.com/facebookresearch/ConvNeXt, original copyright below
66
7+
Model defs atto, femto, pico, nano and _ols / _hnf variants are timm specific.
8+
79
Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
810
"""
911
# Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -18,7 +20,8 @@
1820

1921
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
2022
from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
21-
from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, create_conv2d
23+
from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d,\
24+
create_conv2d, make_divisible
2225
from .registry import register_model
2326

2427

@@ -43,11 +46,26 @@ def _cfg(url='', **kwargs):
4346
convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),
4447

4548
# timm specific variants
49+
convnext_atto=_cfg(url=''),
50+
convnext_atto_ols=_cfg(url=''),
51+
convnext_femto=_cfg(
52+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
53+
test_input_size=(3, 288, 288), test_crop_pct=0.95),
54+
convnext_femto_ols=_cfg(
55+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth',
56+
test_input_size=(3, 288, 288), test_crop_pct=0.95),
57+
convnext_pico=_cfg(
58+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth',
59+
test_input_size=(3, 288, 288), test_crop_pct=0.95),
60+
convnext_pico_ols=_cfg(
61+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth',
62+
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
4663
convnext_nano=_cfg(
4764
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth',
4865
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
49-
convnext_nano_hnf=_cfg(url=''),
50-
convnext_nano_ols=_cfg(url=''),
66+
convnext_nano_ols=_cfg(
67+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth',
68+
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
5169
convnext_tiny_hnf=_cfg(
5270
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
5371
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
@@ -236,8 +254,7 @@ def __init__(
236254
dims=(96, 192, 384, 768),
237255
ls_init_value=1e-6,
238256
stem_type='patch',
239-
stem_kernel_size=4,
240-
stem_stride=4,
257+
patch_size=4,
241258
head_init_scale=1.,
242259
head_norm_first=False,
243260
conv_mlp=False,
@@ -260,21 +277,22 @@ def __init__(
260277
self.drop_rate = drop_rate
261278
self.feature_info = []
262279

263-
assert stem_type in ('patch', 'overlap')
280+
assert stem_type in ('patch', 'overlap', 'overlap_tiered')
264281
if stem_type == 'patch':
265-
assert stem_kernel_size == stem_stride
266282
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
267283
self.stem = nn.Sequential(
268-
nn.Conv2d(in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride, bias=conv_bias),
284+
nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
269285
norm_layer(dims[0])
270286
)
287+
stem_stride = patch_size
271288
else:
289+
mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
272290
self.stem = nn.Sequential(
273-
nn.Conv2d(
274-
in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride,
275-
padding=stem_kernel_size // 2, bias=conv_bias),
291+
nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
292+
nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
276293
norm_layer(dims[0]),
277294
)
295+
stem_stride = 4
278296

279297
self.stages = nn.Sequential()
280298
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
@@ -415,29 +433,73 @@ def _create_convnext(variant, pretrained=False, **kwargs):
415433

416434

417435
@register_model
418-
def convnext_nano(pretrained=False, **kwargs):
419-
# timm nano variant with standard stem and head
436+
def convnext_atto(pretrained=False, **kwargs):
437+
# timm atto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M)
420438
model_args = dict(
421-
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, **kwargs)
422-
model = _create_convnext('convnext_nano', pretrained=pretrained, **model_args)
439+
depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, **kwargs)
440+
model = _create_convnext('convnext_atto', pretrained=pretrained, **model_args)
441+
return model
442+
443+
444+
@register_model
445+
def convnext_atto_ols(pretrained=False, **kwargs):
446+
# timm atto variant with overlapping 3x3 conv stem, wider than non-ols atto above, current param count 3.7M
447+
model_args = dict(
448+
depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
449+
model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **model_args)
450+
return model
451+
452+
453+
@register_model
454+
def convnext_femto(pretrained=False, **kwargs):
455+
# timm femto variant
456+
model_args = dict(
457+
depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, **kwargs)
458+
model = _create_convnext('convnext_femto', pretrained=pretrained, **model_args)
459+
return model
460+
461+
462+
@register_model
463+
def convnext_femto_ols(pretrained=False, **kwargs):
464+
# timm femto variant
465+
model_args = dict(
466+
depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
467+
model = _create_convnext('convnext_femto_ols', pretrained=pretrained, **model_args)
468+
return model
469+
470+
471+
@register_model
472+
def convnext_pico(pretrained=False, **kwargs):
473+
# timm pico variant
474+
model_args = dict(
475+
depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, **kwargs)
476+
model = _create_convnext('convnext_pico', pretrained=pretrained, **model_args)
423477
return model
424478

425479

426480
@register_model
427-
def convnext_nano_hnf(pretrained=False, **kwargs):
428-
# experimental nano variant with normalization before pooling in head (head norm first)
481+
def convnext_pico_ols(pretrained=False, **kwargs):
482+
# timm pico variant with overlapping 3x3 conv stem
429483
model_args = dict(
430-
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs)
431-
model = _create_convnext('convnext_nano_hnf', pretrained=pretrained, **model_args)
484+
depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
485+
model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **model_args)
486+
return model
487+
488+
489+
@register_model
490+
def convnext_nano(pretrained=False, **kwargs):
491+
# timm nano variant with standard stem and head
492+
model_args = dict(
493+
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, **kwargs)
494+
model = _create_convnext('convnext_nano', pretrained=pretrained, **model_args)
432495
return model
433496

434497

435498
@register_model
436499
def convnext_nano_ols(pretrained=False, **kwargs):
437500
# experimental nano variant with overlapping conv stem
438501
model_args = dict(
439-
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True,
440-
stem_type='overlap', stem_kernel_size=9, **kwargs)
502+
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap', **kwargs)
441503
model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args)
442504
return model
443505

0 commit comments

Comments (0)