@@ -50,6 +50,12 @@ def _cfg(url='', **kwargs):
5050 url = "https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth" ,
5151 crop_pct = 0.95 , test_input_size = (3 , 320 , 320 ), test_crop_pct = 1.0 ,
5252 ),
53+ # edgenext_base=_cfg(
54+ # url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth"),
55+ edgenext_base = _cfg ( # USI weights
56+ url = "https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth" ,
57+ crop_pct = 0.95 , test_input_size = (3 , 320 , 320 ), test_crop_pct = 1.0 ,
58+ ),
5359
5460 edgenext_small_rw = _cfg (
5561 url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth' ,
@@ -154,7 +160,7 @@ def __init__(
154160
155161 def forward (self , x ):
156162 B , N , C = x .shape
157- qkv = self .qkv (x ).reshape (B , N , 3 , self .num_heads , C // self . num_heads ).permute (2 , 0 , 3 , 4 , 1 )
163+ qkv = self .qkv (x ).reshape (B , N , 3 , self .num_heads , - 1 ).permute (2 , 0 , 3 , 4 , 1 )
158164 q , k , v = qkv .unbind (0 )
159165
160166 # NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map
@@ -217,7 +223,8 @@ def forward(self, x):
217223 shortcut = x
218224
219225 # scales code re-written for torchscript as per my res2net fixes -rw
220- spx = torch .split (x , self .width , 1 )
226+ # NOTE torch.split(x, self.width, 1) causing issues with ONNX export
227+ spx = x .chunk (len (self .convs ) + 1 , dim = 1 )
221228 spo = []
222229 sp = spx [0 ]
223230 for i , conv in enumerate (self .convs ):
@@ -545,13 +552,19 @@ def edgenext_small(pretrained=False, **kwargs):
545552 return _create_edgenext ('edgenext_small' , pretrained = pretrained , ** model_kwargs )
546553
547554
555+ @register_model
556+ def edgenext_base (pretrained = False , ** kwargs ):
557+ # 18.51M & 3840.93M @ 256 resolution
558+ # 82.5% (normal) 83.7% (USI) Top-1 accuracy
559+ # AA=True, Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
560+ # Jetson FPS=xx.xx versus xx.xx for MobileViT_S
561+ # For A100: FPS @ BS=1: xxx.xx & @ BS=256: xxxx.xx
562+ model_kwargs = dict (depths = [3 , 3 , 9 , 3 ], dims = [80 , 160 , 288 , 584 ], ** kwargs )
563+ return _create_edgenext ('edgenext_base' , pretrained = pretrained , ** model_kwargs )
564+
565+
548566@register_model
549567def edgenext_small_rw (pretrained = False , ** kwargs ):
550- # 5.59M & 1260.59M @ 256 resolution
551- # 79.43% Top-1 accuracy
552- # AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
553- # Jetson FPS=20.47 versus 18.86 for MobileViT_S
554- # For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S
555568 model_kwargs = dict (
556569 depths = (3 , 3 , 9 , 3 ), dims = (48 , 96 , 192 , 384 ),
557570 downsample_block = True , conv_bias = False , stem_type = 'overlap' , ** kwargs )
0 commit comments