Additions in the training procedure
aniket-agarwal1999 committed Dec 1, 2019
1 parent 80e412b commit d2706d2
Showing 6 changed files with 84 additions and 6 deletions.
6 changes: 6 additions & 0 deletions README.md
@@ -0,0 +1,6 @@
# vGraph: A Generative Model For Joint Community Detection and Node Representational Learning

This is a PyTorch implementation of the paper [*vGraph*](https://arxiv.org/abs/1906.07159), built with PyTorch Geometric as part of the NeurIPS 2019 Reproducibility Challenge.

## Setup Instructions and Dependencies

14 changes: 13 additions & 1 deletion data.py
@@ -2,11 +2,23 @@
import torch_geometric as pyg
from torch_geometric.data import Dataset
import torch_geometric.datasets as datasets
import pandas as pd

def get_cora():
dataset = datasets.Planetoid(root='./dataset/Cora', name='Cora')
return dataset.data

def get_citeseer():
dataset = datasets.Planetoid(root='./dataset/Citeseer', name='CiteSeer')
return dataset.data

def get_facebook(code):
'''
code: which graph to use from the available facebook social circles subgraphs
options of code : [0, 107, 1684, 1912, 3437, 348, 3980, 414, 686, 698]
'''
assert code in [0, 107, 1684, 1912, 3437, 348, 3980, 414, 686, 698]
edge_file = './dataset/Facebook/'+str(code)+'.edges'
label_file = './dataset/Facebook/'+str(code)+'.circles' ### The .circles file holds the ground-truth community assignments

for
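The loop that follows is cut off in this diff view. As a hedged illustration only, one plausible shape for the remaining body (the column handling, return values, and the assumption that torch is imported at the top of data.py are guesses, not taken from the commit):

### Illustrative sketch only -- not part of this commit.
edges = pd.read_csv(edge_file, sep=' ', header=None).values          # one "u v" pair per line
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()  # shape [2, num_edges]
circles = []                                                          # ground-truth communities
with open(label_file) as f:
    for line in f:
        circles.append([int(v) for v in line.strip().split()[1:]])   # drop the leading "circleK" tag
return edge_index, circles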
4 changes: 2 additions & 2 deletions model.py
@@ -42,7 +42,7 @@ def negative_sampling_line(z, edge_index, num_negative_samples = 5):
Parameters:
z: the sampled community using gumbel softmax reparametrization trick
edge_index: edges in the graph
negative_samples: number of negative samples to be used for the optimization
num_negative_samples: number of negative samples to be used for the optimization
The function has been partially inspired from this file: https://github.com/DMPierre/LINE/blob/master/utils/line.py
'''
@@ -84,6 +84,6 @@ def forward(self, w, c, edge_index, num_negative_samples=5):
else:
recon_c = self.negative_sampling_line(z, edge_index, num_negative_samples)

return prior, recon_c, F.softmax(q, dim=-1)
return prior, recon_c, F.softmax(q, dim=-1) ### The last term must be the plain softmax posterior of q, not a gumbel_softmax sample
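For context, a minimal sketch of the distinction the comment draws (tensor names and shapes here are assumptions, not from this commit): the Gumbel-softmax draw is the stochastic sample fed to the decoder, while the plain softmax of the same logits is the posterior q(z|c,w) needed by the KL term and the evaluation code.

### Illustrative sketch only -- not part of this commit.
q_logits = torch.randn(8, 16)                         # per-edge community logits (hypothetical shape)
z = F.gumbel_softmax(q_logits, tau=1.0, hard=False)   # stochastic, differentiable sample -> decoder
posterior = F.softmax(q_logits, dim=-1)               # deterministic q(z|c,w) -> KL term / clustering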


14 changes: 14 additions & 0 deletions train.py → train_nonoverlapping.py
@@ -68,6 +68,9 @@
start_epoch = 0

for epoch in range(start_epoch, args.epochs):

optimizer.zero_grad()
model.train()

w = torch.cat((edge_index[0, :], edge_index[1, :]))
c = torch.cat((edge_index[1, :], edge_index[0, :]))
@@ -93,11 +96,22 @@
optimizer.step()

print('Epoch: ', epoch+1, ' done!!')
print('Total loss: ', total_loss)

if epoch % 100 == 0:
lr_scheduler.step()
modularity, macro_F1, micro_F1 = utils.calculate_nonoverlap_losses(model, dataset, edge_index)
f = open(args.dataset + '_results.txt', 'a+')
f.write('Epoch: {} modularity: {} macro_F1: {} micro_F1: {}\n'.format(epoch, modularity, macro_F1, micro_F1)) ## file.write takes a single string
f.close()


writer_tensorboard.add_scalars('Total Loss', {'vgraph_loss':vgraph_loss, 'regularization_loss':regularization_loss}, epoch)

### Saving the checkpoint
utils.save_checkpoint({'epoch':epoch+1,
'model':model.state_dict(),
'optimizer':optimizer.state_dict()},
args.checkpoint_dir + '/latest_model_'+args.dataset+'.ckpt')
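A hedged sketch of how this checkpoint could later be restored with the load_checkpoint helper from utils.py; the resume logic itself is an assumption, not part of this diff.

### Illustrative sketch only -- not part of this commit.
ckpt = utils.load_checkpoint(args.checkpoint_dir + '/latest_model_' + args.dataset + '.ckpt')
start_epoch = ckpt['epoch']
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])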


writer_tensorboard.close()
Empty file added train_overlapping.py
Empty file.
52 changes: 49 additions & 3 deletions utils.py
@@ -2,6 +2,9 @@
import torch.nn as nn
import torch_geometric as pyg
import torch.nn.functional as F
from sklearn import metrics
from sklearn.cluster import KMeans
import community ## For calculating modularity

def vGraph_loss(c, recon_c, prior, q):
'''
@@ -11,8 +14,9 @@ def vGraph_loss(c, recon_c, prior, q):
q = q(z|c,w)
'''

BCE_loss = F.cross_entropy(recon_c, c)
KL_div_loss = F.kl_div(torch.log(prior + 1e-20), q, reduction='batchmean')
BCE_loss = F.cross_entropy(recon_c, c) / c.shape[0] ### Normalization is needed; otherwise this term grows with the size of c and dominates the loss
# KL_div_loss = F.kl_div(torch.log(prior + 1e-20), q, reduction='batchmean')
KL_div_loss = torch.sum(q*(torch.log(q + 1e-20) - torch.log(prior)), -1).mean() ## Explicit KL(q || prior), averaged with mean()

loss = BCE_loss + KL_div_loss
return loss
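As a quick, hedged sanity check (not part of the commit), the explicit formula agrees with the commented-out F.kl_div call up to the 1e-20 epsilon:

### Illustrative check only -- not part of this commit.
q_test = torch.softmax(torch.randn(4, 3), dim=-1)
prior_test = torch.softmax(torch.randn(4, 3), dim=-1)
manual = torch.sum(q_test * (torch.log(q_test + 1e-20) - torch.log(prior_test)), -1).mean()
builtin = F.kl_div(torch.log(prior_test + 1e-20), q_test, reduction='batchmean')
assert torch.allclose(manual, builtin, atol=1e-6)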
@@ -22,6 +26,12 @@ def load_checkpoint(ckpt_path, map_location='cpu'):
print(' [*] Loading checkpoint from %s succeed!' % ckpt_path)
return ckpt

def save_checkpoint(state, save_path):
'''
Saving checkpoints(state) at the specified save_path location
'''
torch.save(state, save_path)

def similarity_measure(edge_index, w, c, gpu_id):
'''
Used for calculating the coefficient alpha in the case of community smoothness loss
@@ -54,5 +64,41 @@ def cuda(xs, gpu_id):
return xs


def calculate_nonoverlap_losses(model, dataset, edge_index):
'''
For calculating losses pertaining to the non-overlapping dataset, namely, Macro F1, Micro F1, Modularity, NMI
'''
model.eval()
labels = dataset.y
w = edge_index[0, :]
c = edge_index[1, :]
_, _, q = model(w, c, edge_index)

new_labels = torch.zeros(w.shape[0], 1)
for i in range(w.shape[0]):
new_labels[i] = labels[w[i]]


kmeans = KMeans(n_clusters=torch.unique(labels).shape[0], random_state=0).fit(q.detach().cpu().numpy())  ## detach the tensor before handing it to scikit-learn

### For calculating modularity
assignment = {i: int(kmeans.labels_[i]) for i in range(q.shape[0])}
networkx_graph = pyg.utils.to_networkx(dataset).to_undirected()  ## python-louvain modularity expects an undirected graph
modularity = community.modularity(assignment, networkx_graph)  ## use the graph built above, not the Data object

###For calculating macro and micro F1 score
macro_F1 = metrics.f1_score(new_labels.numpy(), kmeans.labels_, average='macro')
micro_F1 = metrics.f1_score(new_labels.numpy(), kmeans.labels_, average='micro')

return modularity, macro_F1, micro_F1
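The docstring also lists NMI, which the body does not yet compute; a hedged one-line addition using scikit-learn (its exact placement in the function is an assumption):

### Illustrative sketch only -- not part of this commit.
nmi = metrics.normalized_mutual_info_score(new_labels.numpy().ravel(), kmeans.labels_)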

def calculate_jaccard():
'''
## This is for the overlapping case
'''



def calculate_f1():
'''
## This is for the overlapping case
'''
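For the two stubs above, a hedged sketch of the usual best-match average Jaccard used for overlapping community detection; the exact protocol here follows common practice, not anything in this commit:

### Illustrative sketch only -- not part of this commit.
def jaccard(a, b):
    a, b = set(a), set(b)
    return len(a & b) / max(len(a | b), 1)

def avg_best_jaccard(pred_comms, true_comms):
    ## Average over predicted communities of the best Jaccard match in the ground truth
    return sum(max(jaccard(p, t) for t in true_comms) for p in pred_comms) / len(pred_comms)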
