-
Notifications
You must be signed in to change notification settings - Fork 69
/
hubconf.py
47 lines (39 loc) · 2.52 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Third-party packages this hubconf needs at load time (torch.hub reads the
# module-level ``dependencies`` list and checks each package is importable).
dependencies = ['timm', 'torch']
import torch, timm
# Model identifiers accepted by ``meal_v2(model_name, ...)``. Each identifier
# is also bound below to a standalone hub entrypoint of the same name, so every
# entry in ``__all__`` resolves to a real module attribute.
__all__ = [
    'mealv1_resnest50',
    'mealv2_resnest50',
    'mealv2_resnest50_cutmix',
    'mealv2_resnest50_380x380',
    'mealv2_mobilenetv3_small_075',
    'mealv2_mobilenetv3_small_100',
    'mealv2_mobilenet_v3_large_100',
    'mealv2_efficientnet_b0',
]

# Pretrained checkpoint URL for each model identifier.
model_urls = {
    'mealv1_resnest50': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV1_ResNet50_224.pth',
    'mealv2_resnest50': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_ResNet50_224.pth',
    'mealv2_resnest50_cutmix': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_ResNet50_224_cutmix.pth',
    'mealv2_resnest50_380x380': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_ResNet50_380.pth',
    'mealv2_mobilenetv3_small_075': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_MobileNet_V3_Small_0.75_224.pth',
    'mealv2_mobilenetv3_small_100': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_MobileNet_V3_Small_1.0_224.pth',
    'mealv2_mobilenet_v3_large_100': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_MobileNet_V3_Large_1.0_224.pth',
    'mealv2_efficientnet_b0': 'https://github.com/szq0214/MEAL-V2/releases/download/v1.0.0/MEALV2_EfficientNet_B0_224.pth',
}

# timm architecture name used to build each model identifier.
# NOTE: the ``*_resnest50*`` identifiers build timm's plain ``resnet50``; the
# identifiers are kept unchanged because callers pass them by name.
mapping = {
    'mealv1_resnest50': 'resnet50',
    'mealv2_resnest50': 'resnet50',
    'mealv2_resnest50_cutmix': 'resnet50',
    'mealv2_resnest50_380x380': 'resnet50',
    'mealv2_mobilenetv3_small_075': 'tf_mobilenetv3_small_075',
    'mealv2_mobilenetv3_small_100': 'tf_mobilenetv3_small_100',
    'mealv2_mobilenet_v3_large_100': 'tf_mobilenetv3_large_100',
    'mealv2_efficientnet_b0': 'tf_efficientnet_b0',
}


def _make_entrypoint(model_name):
    """Return a hub entrypoint that loads ``model_name`` through ``meal_v2``."""

    def entrypoint(pretrained=True, progress=True, exportable=False):
        # ``meal_v2`` is looked up at call time, so defining it later in the
        # module is fine.
        return meal_v2(model_name, pretrained=pretrained, progress=progress, exportable=exportable)

    entrypoint.__name__ = entrypoint.__qualname__ = model_name
    entrypoint.__doc__ = (
        f"MEAL model ``{model_name}`` (timm architecture ``{mapping[model_name]}``).\n\n"
        f"Equivalent to ``meal_v2({model_name!r}, ...)``; see ``meal_v2`` for the arguments."
    )
    return entrypoint


# Bind every advertised identifier so that ``__all__`` only names attributes
# that exist (otherwise ``from hubconf import *`` raises AttributeError).
for _name in __all__:
    globals()[_name] = _make_entrypoint(_name)
del _name
def meal_v2(model_name, pretrained=True, progress=True, exportable=False):
    """ MEAL V2 models from
    `"MEAL V2: Boosting Vanilla ResNet-50 to 80%+ Top-1 Accuracy on ImageNet without Tricks" <https://arxiv.org/pdf/2009.08453.pdf>`_

    Args:
        model_name (str): Name of the model to load (case-insensitive), one of
            the keys of ``mapping`` / ``model_urls``, e.g. ``'mealv2_resnest50'``.
        pretrained (bool): If True, returns a model trained with MEAL V2 on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        exportable (bool): Forwarded unchanged to ``timm.create_model``.

    Returns:
        torch.nn.Module: With ``pretrained=True``, the network wrapped in
        ``torch.nn.DataParallel`` (moved to CUDA when a GPU is available), so
        the underlying network is ``model.module``. With ``pretrained=False``,
        the bare timm network with no checkpoint loaded.

    Raises:
        KeyError: If ``model_name`` is not a known MEAL model.
    """
    key = model_name.lower()
    if key not in mapping:
        raise KeyError(
            f"Unknown MEAL model {model_name!r}; expected one of: {', '.join(mapping)}"
        )

    model = timm.create_model(mapping[key], pretrained=False, exportable=exportable)
    if pretrained:
        # Deserialize onto the CPU: a checkpoint saved with CUDA tensors cannot
        # otherwise be loaded on a machine without a GPU. load_state_dict copies
        # the values into the model's parameters wherever they live.
        state_dict = torch.hub.load_state_dict_from_url(
            model_urls[key], progress=progress, map_location='cpu'
        )
        # Weights are loaded through the DataParallel wrapper, so the checkpoint
        # keys presumably carry the ``module.`` prefix — verify if a checkpoint
        # is ever re-exported.
        model = torch.nn.DataParallel(model)
        # DataParallel runs the wrapped module directly when no GPU is present,
        # so only move to CUDA when one is actually available.
        if torch.cuda.is_available():
            model = model.cuda()
        model.load_state_dict(state_dict)
    return model