-
Notifications
You must be signed in to change notification settings - Fork 0
/
classification_models.py
74 lines (58 loc) · 1.65 KB
/
classification_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import torch.nn as nn
import numpy as np
import components as c
from collections import OrderedDict
# Fix the seeds of PyTorch's and NumPy's global random number generators.
# This runs as a side effect of importing the module.
torch.manual_seed(225530)
np.random.seed(225530)
class FastNet(nn.Module):
    """Base class for networks assembled by a subclass-provided builder.

    Subclasses implement ``_build_network`` and return a two-element
    result: the callable that ``forward`` applies to its input, and the
    value stored in ``no_fast_weights``.
    """

    def __init__(self):
        super().__init__()
        self.prototype = None
        self.relu = nn.ReLU()
        # The builder runs during construction, so anything it reads from
        # ``self`` must already be set by the subclass at this point.
        network, fast_weight_count = self._build_network()
        self.network = network
        self.no_fast_weights = fast_weight_count

    @staticmethod
    def _count_weights(weights):
        """Sum ``no_fast_weights`` over an iterable of ``(name, module)`` pairs.

        Entries whose module has no ``no_fast_weights`` attribute
        contribute nothing to the total.
        """
        return sum(
            getattr(module, "no_fast_weights", 0) for _name, module in weights
        )

    def _build_network(self):
        """Build the network; must be overridden by subclasses."""
        raise NotImplementedError

    def reset(self):
        """Rebuild the network, discarding the count the builder returns."""
        rebuilt, _ignored_count = self._build_network()
        self.network = rebuilt

    def forward(self, x):
        """Apply the current network callable to ``x`` and return the result."""
        apply_network = self.network
        return apply_network(x)
class FeedForwardBaseline(nn.Module):
    """Plain PyTorch baseline: Linear(3, 100) -> Linear(100, 1) -> Sigmoid.

    The two linear layers are applied back-to-back with no activation
    between them.
    NOTE(review): confirm that the missing hidden non-linearity is
    intended for this baseline.

    Args:
        device: Device on which both linear layers are created; also kept
            on the instance as ``self.device``.
    """

    def __init__(self, device):
        super().__init__()
        self.device = device
        # Creation order is kept (linear1, then linear2) so that seeded
        # parameter initialisation stays the same.
        self.linear1 = nn.Linear(3, 100, device=device)
        self.linear2 = nn.Linear(100, 1, device=device)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid of the two stacked linear layers applied to ``x``."""
        hidden = self.linear1(x)
        logits = self.linear2(hidden)
        return self.sigmoid(logits)
class FeedForwardFastNet(FastNet):
    """FastNet variant: components.Linear(3, 100) -> components.Linear(100, 1) -> Sigmoid.

    The layers are kept in ``self.prototype`` as a plain list of
    ``(name, module)`` pairs and applied in order by the callable that
    ``_build_network`` returns.
    NOTE(review): because the layers live in a plain list, ``nn.Module``
    does not register them as submodules; confirm this is intended.

    Args:
        device: Device passed to both ``components.Linear`` layers.
            Defaults to ``"cpu"``.
    """

    def __init__(self, device="cpu"):
        # ``device`` must be set before the base constructor runs, because
        # FastNet.__init__ calls _build_network(), which reads self.device.
        self.device = device
        super(FeedForwardFastNet, self).__init__()
        # NOTE(review): the meaning of these two constants is not visible
        # in this file; presumably they are read by external code.
        self.no_from = 6
        self.no_to = 4

    def _build_network(self):
        """Create a fresh layer prototype and return ``(forward, weight_count)``.

        If a prototype already exists (i.e. this is a rebuild via
        ``reset()``), every old entry that provides a ``reset`` method is
        reset before being replaced.

        Returns:
            A tuple of the callable that runs an input through the
            prototype layers in order, and the total reported by
            ``_count_weights`` for the new prototype.
        """
        if self.prototype is not None:
            for _name, mod in self.prototype:
                # Not every entry has reset(): nn.Sigmoid does not, so an
                # unconditional call raised AttributeError on every
                # reset() after construction.
                reset_fn = getattr(mod, "reset", None)
                if callable(reset_fn):
                    reset_fn()
        self.prototype = [
            ("linear1", c.Linear(3, 100, device=self.device)),
            ("linear2", c.Linear(100, 1, device=self.device)),
            ("sigmoid", nn.Sigmoid()),
        ]

        def forward(x):
            # Reads self.prototype at call time, so the closure always
            # uses the most recently built layers.
            cur = x
            for _name, module in self.prototype:
                cur = module(cur)
            return cur

        return forward, self._count_weights(self.prototype)