@@ -109,6 +109,8 @@ def _cfg(url='', **kwargs):
     'vit_giant_patch14_224': _cfg(url=''),
     'vit_gigantic_patch14_224': _cfg(url=''),

+    'vit_base2_patch32_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
+
     # patch models, imagenet21k (weights from official Google JAX impl)
     'vit_tiny_patch16_224_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
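The new default config gives vit_base2_patch32_256 a 256x256 eval resolution with crop_pct=0.95. As a minimal sketch of how timm's eval pipeline consumes these fields (assuming a timm install from around this commit; transform internals may differ by version), the image is resized so the short side is floor(256 / 0.95) = 269, then center-cropped back to 256:

import math
from timm.data import create_transform

input_size, crop_pct = 256, 0.95
# Eval pre-processing resizes to floor(input_size / crop_pct), then
# center-crops to input_size, so crop_pct=0.95 -> resize 269, crop 256.
print(math.floor(input_size / crop_pct))  # 269

# create_transform builds the torchvision pipeline from the same cfg fields.
eval_tf = create_transform(input_size=(3, 256, 256), is_training=False, crop_pct=0.95)
print(eval_tf)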
@@ -202,6 +204,7 @@ def _cfg(url='', **kwargs):
 class Attention(nn.Module):
     def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
         super().__init__()
+        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
         self.num_heads = num_heads
         head_dim = dim // num_heads
         self.scale = head_dim ** -0.5
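The added assert fails fast instead of letting a bad dim/num_heads pair surface later as a confusing reshape error: head_dim = dim // num_heads only recovers dim when the division is exact. A hypothetical standalone check illustrating the invariant (check_heads is not part of timm):

def check_heads(dim, num_heads):
    # Mirrors the invariant the new assert enforces in Attention.__init__.
    head_dim = dim // num_heads
    return head_dim, head_dim * num_heads == dim

print(check_heads(896, 14))  # (64, True)  -- the new vit_base2 config below
print(check_heads(768, 14))  # (54, False) -- would previously error in the qkv reshape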
@@ -634,6 +637,16 @@ def vit_base_patch32_224(pretrained=False, **kwargs):
     return model


+@register_model
+def vit_base2_patch32_256(pretrained=False, **kwargs):
+    """ ViT-Base (ViT-B/32)
+    # FIXME experiment
+    """
+    model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, **kwargs)
+    model = _create_vision_transformer('vit_base2_patch32_256', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_base_patch32_384(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
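A quick smoke test of the new registration, assuming this branch of timm is installed. Note there is nothing to load with pretrained=True, since the config URL is empty and the model is flagged as an experiment:

import torch
import timm

model = timm.create_model('vit_base2_patch32_256', pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))  # 256x256 input per the new cfg
print(out.shape)  # expected: torch.Size([1, 1000]) with the default 1000-class head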