Skip to content

Commit 13565aa

Browse files
committed
Add edgenext_base model def & weight link, update to improve ONNX export #1385
1 parent 56596e4 commit 13565aa

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

timm/models/edgenext.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ def _cfg(url='', **kwargs):
5050
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth",
5151
crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
5252
),
53+
# edgenext_base=_cfg(
54+
# url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth"),
55+
edgenext_base=_cfg( # USI weights
56+
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth",
57+
crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
58+
),
5359

5460
edgenext_small_rw=_cfg(
5561
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth',
@@ -154,7 +160,7 @@ def __init__(
154160

155161
def forward(self, x):
156162
B, N, C = x.shape
157-
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1)
163+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 4, 1)
158164
q, k, v = qkv.unbind(0)
159165

160166
# NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map
@@ -217,7 +223,8 @@ def forward(self, x):
217223
shortcut = x
218224

219225
# scales code re-written for torchscript as per my res2net fixes -rw
220-
spx = torch.split(x, self.width, 1)
226+
# NOTE torch.split(x, self.width, 1) causing issues with ONNX export
227+
spx = x.chunk(len(self.convs) + 1, dim=1)
221228
spo = []
222229
sp = spx[0]
223230
for i, conv in enumerate(self.convs):
@@ -545,13 +552,19 @@ def edgenext_small(pretrained=False, **kwargs):
545552
return _create_edgenext('edgenext_small', pretrained=pretrained, **model_kwargs)
546553

547554

555+
@register_model
556+
def edgenext_base(pretrained=False, **kwargs):
557+
# 18.51M & 3840.93M @ 256 resolution
558+
# 82.5% (normal) 83.7% (USI) Top-1 accuracy
559+
# AA=True, Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
560+
# Jetson FPS=xx.xx versus xx.xx for MobileViT_S
561+
# For A100: FPS @ BS=1: xxx.xx & @ BS=256: xxxx.xx
562+
model_kwargs = dict(depths=[3, 3, 9, 3], dims=[80, 160, 288, 584], **kwargs)
563+
return _create_edgenext('edgenext_base', pretrained=pretrained, **model_kwargs)
564+
565+
548566
@register_model
549567
def edgenext_small_rw(pretrained=False, **kwargs):
550-
# 5.59M & 1260.59M @ 256 resolution
551-
# 79.43% Top-1 accuracy
552-
# AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
553-
# Jetson FPS=20.47 versus 18.86 for MobileViT_S
554-
# For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S
555568
model_kwargs = dict(
556569
depths=(3, 3, 9, 3), dims=(48, 96, 192, 384),
557570
downsample_block=True, conv_bias=False, stem_type='overlap', **kwargs)

0 commit comments

Comments
 (0)