-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutilitis.py
69 lines (51 loc) · 2.11 KB
/
utilitis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
'''
Set up the model and the dataset loaders.
'''
import copy
import torch
import numpy as np
# from advertorch.utils import NormalizeByChannelMeanStd
from models import *
from dataset import *
__all__ = ['setup_model_dataset']
def setup_model_dataset(args):
    """Build the model and the data loaders selected by ``args``.

    Args:
        args: Namespace-like object. The attributes read here are
            ``dataset`` ('cifar10' or 'cifar100'), ``batch_size``,
            ``data`` (forwarded as ``data_dir``), ``workers`` (forwarded
            as ``num_workers``), ``arch`` (key into ``model_dict``) and
            ``imagenet_arch`` (when truthy, the model constructor is
            called with ``imagenet=True``).

    Returns:
        Tuple ``(model, train_set_loader, val_loader, test_loader)``. A
        ``NormalizeByChannelMeanStd`` module is attached to the model as
        ``model.normalize``.

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name.
    """
    if args.dataset == 'cifar10':
        classes = 10
        # NOTE(review): presumably the CIFAR-10 per-channel statistics --
        # confirm against the data pipeline.
        normalization = NormalizeByChannelMeanStd(
            mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
        train_set_loader, val_loader, test_loader = cifar10_dataloaders(
            batch_size=args.batch_size, data_dir=args.data,
            num_workers=args.workers)
    elif args.dataset == 'cifar100':
        classes = 100
        # NOTE(review): presumably the CIFAR-100 per-channel statistics --
        # confirm against the data pipeline.
        normalization = NormalizeByChannelMeanStd(
            mean=[0.5071, 0.4866, 0.4409], std=[0.2673, 0.2564, 0.2762])
        train_set_loader, val_loader, test_loader = cifar100_dataloaders(
            batch_size=args.batch_size, data_dir=args.data,
            num_workers=args.workers)
    else:
        # Name the rejected value so a mistyped --dataset is easy to spot.
        raise ValueError(
            "Dataset {!r} is not supported yet; expected 'cifar10' or "
            "'cifar100'.".format(args.dataset))

    if args.imagenet_arch:
        model = model_dict[args.arch](num_classes=classes, imagenet=True)
    else:
        model = model_dict[args.arch](num_classes=classes)

    # Attach the normalization module to the model; whether the model
    # applies it in its own forward pass is not visible from this file.
    model.normalize = normalization
    print(model)

    return model, train_set_loader, val_loader, test_loader
class NormalizeByChannelMeanStd(torch.nn.Module):
    """Module that shifts and scales its input with per-channel statistics.

    ``mean`` and ``std`` are registered as buffers named ``"mean"`` and
    ``"std"``. Non-tensor values are converted with ``torch.tensor``.
    """

    def __init__(self, mean, std):
        super().__init__()
        # Tensors are stored as given; anything else is converted first.
        mean_buffer = mean if isinstance(mean, torch.Tensor) else torch.tensor(mean)
        std_buffer = std if isinstance(std, torch.Tensor) else torch.tensor(std)
        self.register_buffer("mean", mean_buffer)
        self.register_buffer("std", std_buffer)

    def forward(self, tensor):
        """Return ``tensor`` normalized with the stored mean and std."""
        return self.normalize_fn(tensor, self.mean, self.std)

    def extra_repr(self):
        """Return the buffer values for display in the module's repr."""
        template = 'mean={}, std={}'
        return template.format(self.mean, self.std)

    def normalize_fn(self, tensor, mean, std):
        """Differentiable counterpart of torchvision.functional.normalize."""
        # The channel axis is assumed to be dim=1, so the statistics are
        # given singleton axes everywhere else before broadcasting.
        channel_mean = mean[None, :, None, None]
        channel_std = std[None, :, None, None]
        centered = tensor.sub(channel_mean)
        return centered.div(channel_std)