# create_subgraphs.py
# (repository forked from thunlp/OpenKE)
import copy
import os
import sys
import json
from random import shuffle
import random
import argparse
import pickle
import numpy as np
from collections import defaultdict
from subgraphs import Subgraph
from subgraphs import SubgraphFactory
from subgraphs import SUBTYPE
from subgraphs import read_triples
import torch
import kge.model
def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the parameterless call did.

    Returns:
        argparse.Namespace with the attributes ``db``, ``embfile``,
        ``subtype``, ``result_dir``, ``infile``, ``ms`` and ``model``.
    """
    parser = argparse.ArgumentParser(
        description='Read embeddings and prepare subgraphs.'
    )
    # --db, --embfile and --type are required; the ``default`` values given
    # for --db and --type are therefore never consulted by argparse.
    parser.add_argument('--db', required=True, dest='db', type=str, default=None)
    parser.add_argument('--embfile', required=True, dest='embfile', type=str,
                        help='File containing embeddings.')
    parser.add_argument('--type', required=True, dest='subtype', type=str,
                        default="star", help='Subgraph type.')
    parser.add_argument('-rd', '--result-dir', dest='result_dir', type=str,
                        default="/var/scratch/dvs254/OpenKE-results/",
                        help='Output dir.')
    parser.add_argument('--infile', dest='infile', type=str,
                        help='File containing training triples.',
                        default="/home/dvs254/OpenKE/benchmarks/fb15k237/train2id.txt")
    parser.add_argument('--ms', dest='ms', type=int, default=10,
                        help='Minimum subgraph Size')
    parser.add_argument('--model', dest='model', type=str, default="transe",
                        help='Embedding model')
    return parser.parse_args(argv)
# Parse the command line once at import time; the rest of the script reads
# these module-level names.
args = parse_args()
ms = args.ms
db = args.db
# Join the path components instead of concatenating strings, so that a
# --result-dir given without a trailing slash still yields
# <result_dir>/<db>/subgraphs/ rather than <result_dir><db>/subgraphs/.
# The final "" component keeps the trailing separator the concatenated
# string used to have.
result_dir = os.path.join(args.result_dir, db, "subgraphs", "")
os.makedirs(result_dir, exist_ok=True)
def read_json_file(filename):
    """Open ``filename`` for reading and return its decoded JSON content."""
    with open(filename, "r") as handle:
        return json.load(handle)
def read_complex_embeddings(filename):
    """Read embeddings stored as separate real and imaginary weight entries.

    The file is parsed with ``read_json_file`` and must provide the keys
    ``ent_re_embeddings.weight``, ``ent_im_embeddings.weight``,
    ``rel_re_embeddings.weight`` and ``rel_im_embeddings.weight``.

    Returns:
        Tuple ``(E, R)``: for entities and relations respectively, the real
        entry combined with the imaginary entry using ``+``.
    """
    # NOTE(review): the earlier comment on this function mentioned LibKGE and
    # the sample path './local/fb15k-237-complex.pt', yet the file is parsed
    # as JSON here — TODO confirm which producer/format is expected.
    # NOTE(review): JSON arrays decode to Python lists, and ``+`` on lists
    # concatenates instead of adding element-wise — confirm this is intended.
    weights = read_json_file(filename)

    entity_real = weights['ent_re_embeddings.weight']
    entity_imag = weights['ent_im_embeddings.weight']
    entities = entity_real + entity_imag

    relation_real = weights['rel_re_embeddings.weight']
    relation_imag = weights['rel_im_embeddings.weight']
    relations = relation_real + relation_imag

    return entities, relations
def read_embeddings(filename):
    """Read entity and relation embeddings from a JSON file.

    The file is parsed with ``read_json_file`` and must provide the keys
    ``ent_embeddings.weight`` and ``rel_embeddings.weight``.

    Returns:
        Tuple ``(E, R)`` holding the values stored under those two keys,
        entities first.
    """
    weights = read_json_file(filename)
    return weights['ent_embeddings.weight'], weights['rel_embeddings.weight']
# Load the training triples first, then the embeddings for the chosen model.
triples = read_triples(args.infile)

# Only --model "complex" uses the real/imaginary reader; every other value
# falls back to the plain reader.
E, R = (
    read_complex_embeddings if args.model == "complex" else read_embeddings
)(args.embfile)

# Only the entity embeddings E are handed to the factory; R is loaded but not
# used below.
sub_factory = SubgraphFactory(args.db, int(args.ms), triples, E)
sub_factory.make_subgraphs(args.subtype)
sub_factory.save(result_dir, args.model)