 import torch.utils.checkpoint
 from torch.jit import Final
 
-
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
@@ -1043,7 +1042,7 @@ def _cfg(url='', **kwargs):
     }
 
 
-default_cfgs = generate_default_cfgs({
+default_cfgs = {
 
     # re-finetuned augreg 21k FT on in1k weights
     'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg(
@@ -1459,49 +1458,60 @@ def _cfg(url='', **kwargs):
     'vit_large_patch14_clip_224.dfn2b': _cfg(
         hf_hub_id='apple/DFN2B-CLIP-ViT-L-14',
         hf_hub_filename='open_clip_pytorch_model.bin',
+        notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
     'vit_huge_patch14_clip_224.dfn5b': _cfg(
         hf_hub_id='apple/DFN5B-CLIP-ViT-H-14',
         hf_hub_filename='open_clip_pytorch_model.bin',
+        notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
     'vit_huge_patch14_clip_378.dfn5b': _cfg(
         hf_hub_id='apple/DFN5B-CLIP-ViT-H-14-378',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        notes=('natively QuickGELU, use quickgelu model variant for original results',),
         crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024),
 
     'vit_base_patch32_clip_224.metaclip_2pt5b': _cfg(
         hf_hub_id='facebook/metaclip-b32-fullcc2.5b',
         hf_hub_filename='metaclip_b32_fullcc2.5b.bin',
         license='cc-by-nc-4.0',
+        notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
     'vit_base_patch16_clip_224.metaclip_2pt5b': _cfg(
         hf_hub_id='facebook/metaclip-b16-fullcc2.5b',
         hf_hub_filename='metaclip_b16_fullcc2.5b.bin',
         license='cc-by-nc-4.0',
+        notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
     'vit_large_patch14_clip_224.metaclip_2pt5b': _cfg(
         hf_hub_id='facebook/metaclip-l14-fullcc2.5b',
         hf_hub_filename='metaclip_l14_fullcc2.5b.bin',
         license='cc-by-nc-4.0',
+        notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
     'vit_huge_patch14_clip_224.metaclip_2pt5b': _cfg(
         hf_hub_id='facebook/metaclip-h14-fullcc2.5b',
         hf_hub_filename='metaclip_h14_fullcc2.5b.bin',
         license='cc-by-nc-4.0',
+        notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
 
     'vit_base_patch32_clip_224.openai': _cfg(
         hf_hub_id='timm/',
+        notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
     'vit_base_patch16_clip_224.openai': _cfg(
         hf_hub_id='timm/',
+        notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
     'vit_large_patch14_clip_224.openai': _cfg(
         hf_hub_id='timm/',
+        notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
     'vit_large_patch14_clip_336.openai': _cfg(
         hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin',
+        notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         crop_pct=1.0, input_size=(3, 336, 336), num_classes=768),
 
@@ -1677,7 +1687,25 @@ def _cfg(url='', **kwargs):
     'vit_medium_patch16_reg4_gap_256': _cfg(
         input_size=(3, 256, 256)),
     'vit_base_patch16_reg8_gap_256': _cfg(input_size=(3, 256, 256)),
+}
+
+_quick_gelu_cfgs = [
+    'vit_large_patch14_clip_224.dfn2b',
+    'vit_huge_patch14_clip_224.dfn5b',
+    'vit_huge_patch14_clip_378.dfn5b',
+    'vit_base_patch32_clip_224.metaclip_2pt5b',
+    'vit_base_patch16_clip_224.metaclip_2pt5b',
+    'vit_large_patch14_clip_224.metaclip_2pt5b',
+    'vit_huge_patch14_clip_224.metaclip_2pt5b',
+    'vit_base_patch32_clip_224.openai',
+    'vit_base_patch16_clip_224.openai',
+    'vit_large_patch14_clip_224.openai',
+    'vit_large_patch14_clip_336.openai',
+]
+default_cfgs.update({
+    n.replace('_clip_', '_clip_quickgelu_'): default_cfgs[n] for n in _quick_gelu_cfgs
 })
+default_cfgs = generate_default_cfgs(default_cfgs)
 
 
 def _create_vision_transformer(variant, pretrained=False, **kwargs):
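For reference, a minimal sketch of what the `default_cfgs.update(...)` comprehension above does: each natively-QuickGELU pretrained tag listed in `_quick_gelu_cfgs` is re-registered under the matching `quickgelu` variant name via a plain string substitution, so the same cfg is reachable from both entrypoints:

# Illustration only: the comprehension rewrites '_clip_' to '_clip_quickgelu_' in each key.
name = 'vit_base_patch16_clip_224.openai'
quickgelu_name = name.replace('_clip_', '_clip_quickgelu_')
assert quickgelu_name == 'vit_base_patch16_clip_quickgelu_224.openai'
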
@@ -2133,8 +2161,7 @@ def vit_base_patch32_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
         patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True,
         norm_layer=nn.LayerNorm, act_layer='quick_gelu')
     model = _create_vision_transformer(
-        'vit_base_patch32_clip_224',  # map to non quickgelu pretrained_cfg intentionally
-        pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_base_patch32_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
@@ -2146,8 +2173,7 @@ def vit_base_patch16_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
         patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True,
         norm_layer=nn.LayerNorm, act_layer='quick_gelu')
     model = _create_vision_transformer(
-        'vit_base_patch16_clip_224',  # map to non quickgelu pretrained_cfg intentionally
-        pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_base_patch16_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
@@ -2160,8 +2186,7 @@ def vit_large_patch14_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
         patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True,
         norm_layer=nn.LayerNorm, act_layer='quick_gelu')
     model = _create_vision_transformer(
-        'vit_large_patch14_clip_224',  # map to non quickgelu pretrained_cfg intentionally
-        pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_large_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
@@ -2173,8 +2198,7 @@ def vit_large_patch14_clip_quickgelu_336(pretrained=False, **kwargs) -> VisionTransformer:
         patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True,
         norm_layer=nn.LayerNorm, act_layer='quick_gelu')
     model = _create_vision_transformer(
-        'vit_large_patch14_clip_336',  # map to non quickgelu pretrained_cfg intentionally
-        pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_large_patch14_clip_quickgelu_336', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
@@ -2186,8 +2210,7 @@ def vit_huge_patch14_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
         patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True,
         norm_layer=nn.LayerNorm, act_layer='quick_gelu')
     model = _create_vision_transformer(
-        'vit_huge_patch14_clip_224',  # map to non quickgelu pretrained_cfg intentionally
-        pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_huge_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
@@ -2199,8 +2222,7 @@ def vit_huge_patch14_clip_quickgelu_378(pretrained=False, **kwargs) -> VisionTransformer:
         patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True,
         norm_layer=nn.LayerNorm, act_layer='quick_gelu')
     model = _create_vision_transformer(
-        'vit_huge_patch14_clip_378',  # map to non quickgelu pretrained_cfg intentionally
-        pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_huge_patch14_clip_quickgelu_378', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
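With the entrypoints above now mapped to `*_clip_quickgelu_*` pretrained cfgs, each QuickGELU variant resolves its own weights instead of borrowing the non-quickgelu cfg. A minimal usage sketch, assuming the Hugging Face Hub weights configured above are reachable from your environment:

import timm

# Build the QuickGELU image tower; with this change the pretrained weights are
# looked up directly under the quickgelu name.
model = timm.create_model('vit_base_patch16_clip_quickgelu_224.openai', pretrained=True)
model.eval()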