-
Notifications
You must be signed in to change notification settings - Fork 1
/
macs_profiling.py
79 lines (64 loc) · 3.35 KB
/
macs_profiling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import torch.nn as nn
import torchvision
class ProfileConv(nn.Module):
    """Profile per-layer MACs and parameter counts of a model via forward hooks.

    Wraps ``model`` and registers forward hooks on every ``Conv2d``,
    ``Linear``, ``GELU``, ``LayerNorm`` and ``AvgPool2d`` submodule.
    Calling the profiler with an example input runs one forward pass,
    collects the per-layer statistics, removes the hooks, and returns
    the two lists ``(macs, params)`` — one entry per hooked layer call.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.hooks = []   # forward-hook handles; removed after one forward pass
        self.macs = []    # multiply-accumulate count per hooked layer call
        self.params = []  # parameter count per hooked layer call

        def hook_conv(module, input, output):
            # MACs = out_H * out_W * out_C * k * k * in_C / groups
            self.macs.append(output.size(1) * output.size(2) * output.size(3) *
                             module.weight.size(-1) * module.weight.size(-1) * input[0].size(1) / module.groups)
            # BUG FIX: the bias has one element per *output* channel
            # (weight.size(0)); the original added weight.size(1) (input
            # channels).  Also guard against Conv2d(bias=False), which is
            # the norm in conv+BN architectures such as MobileNetV2.
            self.params.append(module.weight.numel() +
                               (module.bias.numel() if module.bias is not None else 0))

        def hook_linear(module, input, output):
            # For token sequences (B, N, in_features) count the matmul once
            # per token; otherwise a single (in x out) product.
            if len(input[0].size()) > 2:
                self.macs.append(module.weight.size(0) * module.weight.size(1) * input[0].size(-2))
            else:
                self.macs.append(module.weight.size(0) * module.weight.size(1))
            # Guard against Linear(bias=False).
            self.params.append(module.weight.numel() +
                               (module.bias.numel() if module.bias is not None else 0))

        def hook_gelu(module, input, output):
            # One activation op per spatial/feature element (batch excluded).
            # BUG FIX: test the full output's rank — the original tested
            # output[0], which drops the batch dim, so ordinary 4-D
            # activations fell through to the 3-D formula.
            if len(output.size()) > 3:
                self.macs.append(output.size(1) * output.size(2) * output.size(3))
            else:
                self.macs.append(output.size(1) * output.size(2))

        def hook_layernorm(module, input, output):
            # One multiply + one add per normalized element of a (B, N, C) input.
            self.macs.append(2 * input[0].size(1) * input[0].size(2))
            self.params.append(module.weight.size(0) + module.bias.size(0))

        def hook_avgpool(module, input, output):
            # k*k additions per output element.
            self.macs.append(output.size(1) * output.size(2) * output.size(3) *
                             module.kernel_size * module.kernel_size)

        def hook_attention(module, input, output):
            # For a custom Attention module exposing key_dim / resolution /
            # num_heads / dh (registration kept disabled below, as upstream).
            self.macs.append(module.key_dim * (module.resolution ** 4) * module.num_heads +
                             module.dh * (module.resolution ** 4))

        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                self.hooks.append(module.register_forward_hook(hook_conv))
            elif isinstance(module, nn.Linear):
                self.hooks.append(module.register_forward_hook(hook_linear))
            elif isinstance(module, nn.GELU):
                self.hooks.append(module.register_forward_hook(hook_gelu))
            elif isinstance(module, nn.LayerNorm):
                self.hooks.append(module.register_forward_hook(hook_layernorm))
            elif isinstance(module, nn.AvgPool2d):
                self.hooks.append(module.register_forward_hook(hook_avgpool))
            # elif isinstance(module, Attention):
            #     self.hooks.append(module.register_forward_hook(hook_attention))

    def forward(self, x):
        """Run one forward pass of ``x`` and return ``(macs, params)`` lists."""
        self.model.to(x.device)
        _ = self.model(x)
        # Hooks are one-shot: detach them so later model use is unaffected.
        for handle in self.hooks:
            handle.remove()
        return self.macs, self.params
if __name__ == '__main__':
    # Demo: the profiler is hook-based, so wherever your own code runs
    # `out = model(x)`, the same model can be profiled without changes.
    dummy_input = torch.randn(1, 3, 224, 224)                # example input
    net = torchvision.models.mobilenet_v2(pretrained=False)  # example model
    MACs, params = ProfileConv(net)(dummy_input)
    print('number of conv&fc layers:', len(MACs))
    print(sum(MACs) / 1e9, 'GMACs')
    print(sum(params) / 1e6, 'M parameters')