model.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from layers import GraphConvolution


class GCNv2(nn.Module):
    """Multi-layer GCN whose GraphConvolution layers also receive the initial node features."""

    def __init__(self, nfeat, nhid, nclass, dropout, layer, activation, act_before_dropout=False):
        super(GCNv2, self).__init__()
        self.dropout = dropout
        self.num_layer = layer
        self.activation = activation
        self.act_before_dropout = act_before_dropout

        # Build the stack of GraphConvolution layers.
        self.layers = nn.ModuleList()
        if self.num_layer == 1:
            # A single layer maps input features directly to class scores.
            self.layers.append(GraphConvolution(nfeat, nclass, nfeat))
        else:
            self.layers.append(GraphConvolution(nfeat, nhid, nfeat))
            for _ in range(self.num_layer - 2):
                self.layers.append(GraphConvolution(nhid, nhid, nfeat))
            self.layers.append(GraphConvolution(nhid, nclass, nfeat))

    def forward(self, adj, x):
        # Keep a detached copy of the input features; it is passed to every layer.
        h = x.clone().detach()
        for i, layer in enumerate(self.layers):
            if i == self.num_layer - 1:
                # Last layer: no activation or dropout.
                x = layer(x, adj, h)
            elif self.act_before_dropout:
                x = self.activation(layer(x, adj, h))
                x = F.dropout(x, self.dropout, training=self.training)
            else:
                x = F.dropout(x, self.dropout, training=self.training)
                x = self.activation(layer(x, adj, h))
        return x
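

# Minimal usage sketch (not part of the original file): it assumes that
# GraphConvolution from layers implements forward(x, adj, h0) and returns node
# features of the layer's output size, and that adj is a normalized N x N
# adjacency matrix. The sizes and the identity adjacency below are illustrative
# placeholders only.
if __name__ == "__main__":
    num_nodes, nfeat, nhid, nclass = 100, 16, 32, 7
    model = GCNv2(nfeat=nfeat, nhid=nhid, nclass=nclass,
                  dropout=0.5, layer=2, activation=F.relu)
    x = torch.randn(num_nodes, nfeat)
    adj = torch.eye(num_nodes)  # placeholder; use a normalized graph adjacency in practice
    logits = model(adj, x)      # note the argument order: forward(adj, x)
    print(logits.shape)         # expected: (num_nodes, nclass)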