-
Notifications
You must be signed in to change notification settings - Fork 0
/
node_embedder.py
100 lines (88 loc) · 3.81 KB
/
node_embedder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import math
import multiprocessing as mp
from gensim.models import Word2Vec
from random_walk import RandomWalk
from torch_geometric.datasets import Flickr
class NodeEmbedder:
    def __init__(self, G, win_size, embedding_size, walks_per_vertex,
                 walk_length, num_workers, percent_data, seed=36) -> None:
        '''
        Core node embedder class (DeepWalk-style: random walks + SkipGram).
        Inputs:
        - G (torch_geometric.data.Data): torch geometric data object.
        - win_size (int): window size
        - embedding_size (int): embedding size of output representation
        - walks_per_vertex (int): number of walks per vertex
        - walk_length (int): random walk length
        - num_workers (int): number of CPU workers for SkipGram.
        - percent_data (float): percentage of data to run algorithm on.
        - seed (int): random seed. default set to 36.
        '''
        self.G = G
        # NOTE(review): num_nodes is read via G[0] but edge_index directly
        # from G — presumably G is a torch_geometric dataset that forwards
        # attribute access to its first graph; verify against the caller.
        self.num_vertices = self.G[0].num_nodes
        self.edges = self.G.edge_index
        self.win_size = win_size
        self.embedding_size = embedding_size
        self.walks_per_vertex = walks_per_vertex
        self.walk_length = walk_length
        self.random_walk = RandomWalk(self.walks_per_vertex, self.walk_length)
        self.num_workers = num_workers
        self.percent_data = percent_data
        self.seed = seed
        # Fix: `seed` was previously accepted but never used. Seed torch so
        # _shuffle()/sample_from_graph() are reproducible, and pass it to
        # Word2Vec so SkipGram initialization is reproducible too.
        torch.manual_seed(self.seed)
        self.skipgram = Word2Vec(
            vector_size=self.embedding_size,
            window=self.win_size,
            workers=self.num_workers,
            sg=1,  # Default sg to 1 to leverage SkipGram and not CBOW
            hs=1,  # Default hs to 1 to leverage Hierarchical Softmax
            min_count=1,
            seed=self.seed,
        )

    def _shuffle(self, num_vertices):
        '''Return a random permutation of [0, num_vertices).'''
        return torch.randperm(num_vertices)

    def construct_all_walks(self, O):
        '''
        Construct one random walk per vertex, in the given order.
        Inputs:
        - O (torch.Tensor): shuffled vertices.
        Outputs:
        - walks (list): one walk per vertex in O.
        '''
        walks = [self.random_walk(self.edges, vertex.item()) for vertex in O]
        return walks

    def sample_from_graph(self) -> torch.Tensor:
        '''
        Sample a random subset of the graph's vertices.
        Outputs:
        - sampled_nodes (torch.Tensor): sampled nodes from graph
          (ceil(num_vertices * percent_data) of them).
        '''
        num_sampled = math.ceil(self.num_vertices * self.percent_data)
        shuffled_indices = self._shuffle(self.num_vertices)
        sampled_nodes = shuffled_indices[:num_sampled]
        return sampled_nodes

    # NOTE(review): public name kept for backward compatibility despite the
    # "accesible" typo — callers may already depend on it.
    def construct_accesible_weights(self):
        '''
        Sort trained vertices into ascending order and expose the trained
        weights in the same order as tensor attributes.
        '''
        # gensim orders its vocab by frequency, not vertex id, so reorder
        # both the vertex ids and their weight rows consistently.
        tmp_vertices = torch.Tensor([int(i) for i in self.skipgram.wv.index_to_key])
        i = torch.argsort(tmp_vertices)
        self.trained_vertices = tmp_vertices[i]
        # NOTE(review): syn1 is the hierarchical-softmax output matrix; the
        # conventional node embeddings live in self.skipgram.wv.vectors —
        # confirm which representation downstream consumers expect.
        self.trained_weights = torch.from_numpy(self.skipgram.syn1)[i, :]

    def calculate_embeddings(self):
        '''
        Run the full pipeline: sample vertices, then for each pass shuffle
        them, build walks, and train SkipGram for one epoch per pass.
        '''
        sampled_nodes = self.sample_from_graph()
        for i in range(self.walks_per_vertex):
            O = sampled_nodes[self._shuffle(len(sampled_nodes))]
            all_walks = self.construct_all_walks(O)
            # First pass builds the vocab from scratch; later passes extend it.
            update = i != 0
            self.skipgram.build_vocab(corpus_iterable=all_walks, update=update)
            self.skipgram.train(
                corpus_iterable=all_walks,
                total_examples=self.skipgram.corpus_count,
                epochs=1  # doesn't default automatically for some reason
            )
        # Once we're good with training, make everything easily accessible
        self.construct_accesible_weights()