# read_answers_and_pickle.py
# Forked from thunlp/OpenKE.
import os
import sys
import json
from random import shuffle
import random
import argparse
import pickle
import numpy as np
from tqdm import tqdm
def parse_args(argv=None):
    """Parse command-line options for preparing answer triples.

    Reads answer embeddings based on topk and prepares triples
    (training/test) for LSTM/other RNN models.

    Args:
        argv: Optional list of argument strings. Defaults to None, in
            which case argparse falls back to ``sys.argv[1:]`` — this
            keeps the original no-argument call sites working while
            making the parser testable.

    Returns:
        argparse.Namespace with embfile, result_dir, combine_emb, topk,
        db, mode and ansfile attributes.
    """
    parser = argparse.ArgumentParser(
        description='Read answer embeddings based on topk and prepare '
                    'triples(training/test) for LSTM/other RNN models.')
    parser.add_argument('--embfile', dest='embfile', type=str,
                        help='File containing embeddings.')
    parser.add_argument('-rd', '--result-dir', dest='result_dir', type=str,
                        default="/var/scratch2/uji300/OpenKE-results/",
                        help='Output dir.')
    parser.add_argument('--combemb', dest='combine_emb', action='store_true',
                        help='Whether to combine embeddings of ent and rel in input')
    # NOTE: with required=True the defaults below are never used; kept for
    # interface compatibility with the original signature.
    parser.add_argument('--topk', dest='topk', required=True, type=int, default=10)
    parser.add_argument('--db', required=True, dest='db', type=str, default=None)
    parser.add_argument('--mode', required=True, dest='mode', type=str, default=None,
                        help="train or test")
    parser.add_argument('--ansfile', dest='ansfile', type=str,
                        help='File containing answers as predicted by the model.')
    return parser.parse_args(argv)
# ---- Script setup: parse CLI options and load model artefacts ----
args = parse_args()
topk = args.topk
db = args.db
# Bug fix: the original `args.result_dir + args.db + "/"` produced a broken
# path whenever --result-dir was given without a trailing slash.
# os.path.join inserts the separator; the trailing "" component keeps the
# trailing "/" that later string concatenations in this script expect.
result_dir = os.path.join(args.result_dir, args.db, "")
os.makedirs(result_dir, exist_ok=True)
combine_embeddings = args.combine_emb

# Read embedding file: a JSON dump of the trained model's parameters.
print("Reading embeddings file...", end=" ")
with open(args.embfile, "r") as fin:
    params = json.loads(fin.read())
embeddings = params['ent_embeddings.weight']
rel_embeddings = params['rel_embeddings.weight']
print("DONE")

# Read the answers file (generated from the test option of the model)
print("Reading answers file...", end=" ")
with open(args.ansfile, "r") as fin:
    res = json.loads(fin.read())
print("DONE")
# ---- Build feature/label arrays for each prediction direction ----
# For training only the plain answers exist; for testing both the raw and
# the filtered rankings are processed.
triples = {}
if args.mode == "train":
    rf_arr = [""]
else:
    rf_arr = ["_raw", "_fil"]
ht = ["head", "tail"]
for index in range(len(ht)):
    target = ht[index]            # side being predicted
    anchor = ht[(index + 1) % 2]  # known side of the query (loop-invariant, hoisted)
    for rf in rf_arr:
        x_data = []
        y_data = []
        unique_pairs = set()
        dup_count = 0
        for r in tqdm(res):
            pair = (r['rel'], r[anchor])
            if pair in unique_pairs:
                # Skip repeated (relation, anchor-entity) queries so each
                # query contributes its top-k answers exactly once.
                dup_count += 1
                continue
            unique_pairs.add(pair)
            preds = r[target + '_predictions' + rf]
            for rank, (e, s, c) in enumerate(zip(preds['entities'],
                                                 preds['scores'],
                                                 preds['correctness'])):
                # Feature layout (consumed downstream — keep this order):
                # anchor entity id, relation id, candidate id, score, rank,
                # followed by the embedding features.
                features = [r[anchor], r['rel'], e, s, rank]
                if combine_embeddings:
                    if target == "tail":
                        # Combined query vector: anchor entity + relation.
                        temp = np.array(rel_embeddings[r['rel']], dtype=np.float64) \
                             + np.array(embeddings[r[anchor]], dtype=np.float64)
                    else:
                        # Head prediction: anchor entity - relation.
                        temp = np.array(embeddings[r[anchor]], dtype=np.float64) \
                             - np.array(rel_embeddings[r['rel']], dtype=np.float64)
                    features.extend(temp)
                else:
                    features.extend(rel_embeddings[r['rel']])
                    features.extend(embeddings[r[anchor]])
                features.extend(embeddings[e])
                x_data.append(features)
                y_data.append(c)
        # add the x and y arrays to the dictionary
        triples['x_' + target + rf] = x_data
        triples['y_' + target + rf] = y_data
        print(target + " : " + rf)
        print("# records : ", len(x_data))
        print("# duplicates : ", dup_count)
print("DONE")
# ---- Persist the assembled feature dictionary ----
# Output goes under <result_dir>/data/, named after the answers file, with a
# "-combined" suffix when embeddings were combined.
ans_file = os.path.basename(args.ansfile)
query_features_combined = "-combined" if combine_embeddings else ""
data_dir = os.path.join(result_dir, "data")
# Bug fix: only result_dir was created earlier in the script; without this
# makedirs a fresh run crashed with FileNotFoundError on open().
os.makedirs(data_dir, exist_ok=True)
answers_features_file = os.path.join(
    data_dir, ans_file.split('.')[0] + query_features_combined + ".pkl")
with open(answers_features_file, "wb") as fout:
    pickle.dump(triples, fout, protocol=pickle.HIGHEST_PROTOCOL)