-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbuild_graph.py
99 lines (80 loc) · 3.11 KB
/
build_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import argparse
import os
import timeit
import dgl
import torch
from dgl.data.utils import save_graphs
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import Vectors
from tqdm import tqdm
# Module-level tokenizer shared by get_edge_and_node_fatures, which uses it to
# split each mapping key into tokens before looking up their word vectors.
tokenizer = get_tokenizer('basic_english')
def get_edge_and_node_fatures(MeSH_id_pair_file, parent_children_file, vectors):
    """
    Read the mapping and hierarchy files and derive the inputs of the graph.

    :param MeSH_id_pair_file: path to a text file with one ``key=value`` pair
        per line; every distinct key becomes one node, numbered in file order
    :param parent_children_file: path to a text file with one edge per line;
        the first two space-separated fields are keys from the mapping file
    :param vectors: word-vector lookup indexed by token (``vectors[token]``
        returns a tensor), e.g. the torchtext ``Vectors`` object built in main
    :return: edge: a list of nodes pairs [(node1, node2), (node3, node4), ...]
             node_count: int, number of nodes in the graph
             node_features: a Tensor with size [num_of_nodes, embedding_dim]
    :raises ValueError: if a non-blank mapping line does not split into exactly
        two parts on ``=``, or an edge names a key absent from the mapping file
    """
    print('load MeSH id and names')
    # get descriptor and MeSH mapped
    mapping_id = {}
    with open(MeSH_id_pair_file, 'r') as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            key, value = line.split('=')
            mapping_id[key] = value.strip()

    # count number of nodes and get edges
    print('count number of nodes and get edges of the graph')
    node_count = len(mapping_id)
    print('number of nodes: ', node_count)

    # Key -> node index, built once. Dicts preserve insertion order, so this
    # yields the same indices as list(mapping_id).index(key) without scanning
    # the whole key list for every edge endpoint.
    node_index = {key: idx for idx, key in enumerate(mapping_id)}

    edges = []
    with open(parent_children_file, 'r') as f:
        for line in f:
            fields = line.strip().split(" ")
            if fields == ['']:
                # Blank line: nothing to add.
                continue
            try:
                edges.append((node_index[fields[0]], node_index[fields[1]]))
            except KeyError as err:
                # Keep raising ValueError, as list.index() did for unknown keys.
                raise ValueError(
                    '{!r} is not in the MeSH mapping file'.format(err.args[0])
                ) from err
    print('number of edges: ', len(edges))

    print('get label embeddings')
    rows = []
    for key in tqdm(mapping_id):
        tokens = [tok.lower() for tok in tokenizer(key)]
        token_vectors = torch.stack([vectors[tok] for tok in tokens])
        # A node's feature is the mean of its tokens' vectors, kept as [1, dim].
        rows.append(torch.mean(token_vectors, dim=0, keepdim=True))

    # Concatenate once at the end; growing a tensor with torch.cat inside the
    # loop recopies all previous rows on every iteration.
    label_embedding = torch.cat(rows, dim=0) if rows else torch.zeros(0)
    return edges, node_count, label_embedding
def build_MeSH_graph(edge_list, nodes, label_embedding):
    """
    Assemble a DGL graph from an edge list, a node count and node features.

    :param edge_list: iterable of (source, destination) node-index pairs;
        may be empty, in which case the graph has nodes but no edges
    :param nodes: int, number of nodes to add to the graph
    :param label_embedding: node feature tensor, stored as ``g.ndata['feat']``
    :return: the populated ``dgl.DGLGraph``
    """
    print('start building the graph')
    g = dgl.DGLGraph()

    # add nodes into the graph
    print('add nodes into the graph')
    g.add_nodes(nodes)

    # add edges, directional graph
    print('add edges into the graph')
    pairs = list(edge_list)
    if pairs:
        # zip(*pairs) transposes [(s1, d1), (s2, d2), ...] into (s1, s2, ...)
        # and (d1, d2, ...). With no pairs it yields nothing to unpack, hence
        # the guard.
        src, dst = zip(*pairs)
        g.add_edges(src, dst)

    # add node features into the graph
    print('add node features into the graph')
    g.ndata['feat'] = label_embedding
    return g
def main():
    """
    Command-line entry point: load word vectors, build the MeSH graph, save it.

    All four options are required; argparse reports a usage error if one is
    missing instead of failing later on a ``None`` path.
    """
    parser = argparse.ArgumentParser(
        description='Build a DGL graph from MeSH mapping and hierarchy files.')
    parser.add_argument('--word2vec_path', required=True,
                        help='path to the pre-trained word vector file')
    parser.add_argument('--meSH_pair_path', required=True,
                        help='path to the key=value MeSH mapping file')
    parser.add_argument('--mesh_parent_children_path', required=True,
                        help='path to the space-separated parent/child edge file')
    parser.add_argument('--output', required=True,
                        help='path the graph is saved to')
    args = parser.parse_args()

    print('Load pre-trained vectors')
    # Vectors takes the file name and its directory separately.
    cache, name = os.path.split(args.word2vec_path)
    vectors = Vectors(name=name, cache=cache)

    edges, node_count, label_embedding = get_edge_and_node_fatures(
        args.meSH_pair_path, args.mesh_parent_children_path, vectors)
    graph = build_MeSH_graph(edges, node_count, label_embedding)
    save_graphs(args.output, graph)
if __name__ == "__main__":
    # Run the pipeline and report how long it took, measured with
    # timeit.default_timer.
    start = timeit.default_timer()
    main()
    elapsed = timeit.default_timer() - start
    print('Time: ', elapsed)