""" ConvNeXt

Original code and weights from https://github.com/facebookresearch/ConvNeXt, original copyright below.

Model defs atto, femto, pico, nano and _ols / _hnf variants are timm specific.

Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
"""
# Copyright (c) Meta Platforms, Inc. and affiliates.
1820
1921from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
2022from .helpers import named_apply , build_model_with_cfg , checkpoint_seq
21- from .layers import trunc_normal_ , SelectAdaptivePool2d , DropPath , ConvMlp , Mlp , LayerNorm2d , create_conv2d
23+ from .layers import trunc_normal_ , SelectAdaptivePool2d , DropPath , ConvMlp , Mlp , LayerNorm2d ,\
24+ create_conv2d , make_divisible
2225from .registry import register_model
2326
2427
@@ -43,11 +46,26 @@ def _cfg(url='', **kwargs):
4346 convnext_large = _cfg (url = "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth" ),
4447
4548 # timm specific variants
49+ convnext_atto = _cfg (url = '' ),
50+ convnext_atto_ols = _cfg (url = '' ),
51+ convnext_femto = _cfg (
52+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth' ,
53+ test_input_size = (3 , 288 , 288 ), test_crop_pct = 0.95 ),
54+ convnext_femto_ols = _cfg (
55+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth' ,
56+ test_input_size = (3 , 288 , 288 ), test_crop_pct = 0.95 ),
57+ convnext_pico = _cfg (
58+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth' ,
59+ test_input_size = (3 , 288 , 288 ), test_crop_pct = 0.95 ),
60+ convnext_pico_ols = _cfg (
61+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth' ,
62+ crop_pct = 0.95 , test_input_size = (3 , 288 , 288 ), test_crop_pct = 1.0 ),
4663 convnext_nano = _cfg (
4764 url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth' ,
4865 crop_pct = 0.95 , test_input_size = (3 , 288 , 288 ), test_crop_pct = 1.0 ),
49- convnext_nano_hnf = _cfg (url = '' ),
50- convnext_nano_ols = _cfg (url = '' ),
66+ convnext_nano_ols = _cfg (
67+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth' ,
68+ crop_pct = 0.95 , test_input_size = (3 , 288 , 288 ), test_crop_pct = 1.0 ),
5169 convnext_tiny_hnf = _cfg (
5270 url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth' ,
5371 crop_pct = 0.95 , test_input_size = (3 , 288 , 288 ), test_crop_pct = 1.0 ),
@@ -236,8 +254,7 @@ def __init__(
236254 dims = (96 , 192 , 384 , 768 ),
237255 ls_init_value = 1e-6 ,
238256 stem_type = 'patch' ,
239- stem_kernel_size = 4 ,
240- stem_stride = 4 ,
257+ patch_size = 4 ,
241258 head_init_scale = 1. ,
242259 head_norm_first = False ,
243260 conv_mlp = False ,
@@ -260,21 +277,22 @@ def __init__(
260277 self .drop_rate = drop_rate
261278 self .feature_info = []
262279
263- assert stem_type in ('patch' , 'overlap' )
280+ assert stem_type in ('patch' , 'overlap' , 'overlap_tiered' )
264281 if stem_type == 'patch' :
265- assert stem_kernel_size == stem_stride
266282 # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
267283 self .stem = nn .Sequential (
268- nn .Conv2d (in_chans , dims [0 ], kernel_size = stem_kernel_size , stride = stem_stride , bias = conv_bias ),
284+ nn .Conv2d (in_chans , dims [0 ], kernel_size = patch_size , stride = patch_size , bias = conv_bias ),
269285 norm_layer (dims [0 ])
270286 )
287+ stem_stride = patch_size
271288 else :
289+ mid_chs = make_divisible (dims [0 ] // 2 ) if 'tiered' in stem_type else dims [0 ]
272290 self .stem = nn .Sequential (
273- nn .Conv2d (
274- in_chans , dims [0 ], kernel_size = stem_kernel_size , stride = stem_stride ,
275- padding = stem_kernel_size // 2 , bias = conv_bias ),
291+ nn .Conv2d (in_chans , mid_chs , kernel_size = 3 , stride = 2 , padding = 1 , bias = conv_bias ),
292+ nn .Conv2d (mid_chs , dims [0 ], kernel_size = 3 , stride = 2 , padding = 1 , bias = conv_bias ),
276293 norm_layer (dims [0 ]),
277294 )
295+ stem_stride = 4
278296
279297 self .stages = nn .Sequential ()
280298 dp_rates = [x .tolist () for x in torch .linspace (0 , drop_path_rate , sum (depths )).split (depths )]
@@ -415,29 +433,73 @@ def _create_convnext(variant, pretrained=False, **kwargs):
415433
416434
@register_model
def convnext_atto(pretrained=False, **kwargs):
    """ConvNeXt-Atto (timm-specific variant).

    NOTE: depths still being tweaked; will vary between 3-4M params, current is ~3.7M.

    Args:
        pretrained: load pretrained weights if available.
        **kwargs: forwarded to the ConvNeXt constructor (must not duplicate
            depths/dims/conv_mlp, which are fixed for this variant).
    """
    # Fixed: original comment mislabeled this as the "femto" variant.
    model_args = dict(
        depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, **kwargs)
    model = _create_convnext('convnext_atto', pretrained=pretrained, **model_args)
    return model
442+
443+
@register_model
def convnext_atto_ols(pretrained=False, **kwargs):
    """ConvNeXt-Atto with overlapping tiered 3x3 conv stem (timm-specific variant).

    Current param count ~3.7M.

    Args:
        pretrained: load pretrained weights if available.
        **kwargs: forwarded to the ConvNeXt constructor (must not duplicate
            depths/dims/conv_mlp/stem_type, which are fixed for this variant).
    """
    # Fixed: original comment mislabeled this as the "femto" variant.
    model_args = dict(
        depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True,
        stem_type='overlap_tiered', **kwargs)
    model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **model_args)
    return model
451+
452+
@register_model
def convnext_femto(pretrained=False, **kwargs):
    """ConvNeXt-Femto (timm-specific variant) with standard patch stem."""
    cfg = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, **kwargs)
    return _create_convnext('convnext_femto', pretrained=pretrained, **cfg)
460+
461+
@register_model
def convnext_femto_ols(pretrained=False, **kwargs):
    """ConvNeXt-Femto (timm-specific variant) with overlapping tiered 3x3 conv stem."""
    cfg = dict(
        depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True,
        stem_type='overlap_tiered', **kwargs)
    return _create_convnext('convnext_femto_ols', pretrained=pretrained, **cfg)
469+
470+
@register_model
def convnext_pico(pretrained=False, **kwargs):
    """ConvNeXt-Pico (timm-specific variant) with standard patch stem."""
    cfg = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, **kwargs)
    return _create_convnext('convnext_pico', pretrained=pretrained, **cfg)
424478
425479
@register_model
def convnext_pico_ols(pretrained=False, **kwargs):
    """ConvNeXt-Pico with overlapping tiered 3x3 conv stem (timm-specific variant).

    Args:
        pretrained: load pretrained weights if available.
        **kwargs: forwarded to the ConvNeXt constructor (must not duplicate
            depths/dims/conv_mlp/stem_type, which are fixed for this variant).
    """
    # Fixed: original comment mislabeled this as the "nano" variant.
    model_args = dict(
        depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True,
        stem_type='overlap_tiered', **kwargs)
    model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **model_args)
    return model
487+
488+
@register_model
def convnext_nano(pretrained=False, **kwargs):
    """ConvNeXt-Nano (timm-specific variant) with standard stem and head."""
    cfg = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, **kwargs)
    return _create_convnext('convnext_nano', pretrained=pretrained, **cfg)
433496
434497
@register_model
def convnext_nano_ols(pretrained=False, **kwargs):
    """ConvNeXt-Nano, experimental timm variant with overlapping conv stem."""
    cfg = dict(
        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True,
        stem_type='overlap', **kwargs)
    return _create_convnext('convnext_nano_ols', pretrained=pretrained, **cfg)
443505
0 commit comments