read-answers-batches.py — from a repository forked from thunlp/OpenKE.
(GitHub page metadata: 112 lines, 98 lines of code, 4.5 KB.)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import sys
import json
from random import shuffle
import random
import argparse
import pickle
import numpy as np
from tqdm import tqdm
def parse_args():
    """Build the CLI parser and return the parsed arguments.

    Options select the embedding/answer input files, the output directory,
    the top-k cutoff, the batch size, the database name and train/test mode.
    """
    parser = argparse.ArgumentParser(
        description='Read answer embeddings based on topk and prepare triples(training/test) for LSTM/other RNN models.')
    parser.add_argument('--embfile', dest='embfile', type=str,
                        help='File containing embeddings.')
    parser.add_argument('-rd', '--result-dir', dest='result_dir', type=str,
                        default="/var/scratch2/uji300/OpenKE-results/",
                        help='Output dir.')
    parser.add_argument('--combemb', dest='combine_emb', action='store_true',
                        help='Whether to combine embeddings of ent and rel in input')
    parser.add_argument('--topk', dest='topk', required=True, type=int, default=10)
    parser.add_argument('--bs', dest='batch_size', type=int, default=1000)
    parser.add_argument('--db', required=True, dest='db', type=str, default=None)
    parser.add_argument('--mode', required=True, dest='mode', type=str, default=None,
                        help="train or test")
    parser.add_argument('--ansfile', dest='ansfile', type=str,
                        help='File containing answers as predicted by the model.')
    return parser.parse_args()
# This file is a flat script: CLI parsing and input loading happen at import time.
args = parse_args()
topk = args.topk
db = args.db
batch_size = args.batch_size
# Per-database output directory: <result_dir>/<db>/, with a batch_data/ subdir
# that will hold the pickled feature batches produced below.
result_dir = args.result_dir + args.db + "/"
os.makedirs(result_dir, exist_ok = True)
data_dir = result_dir + "batch_data/"
os.makedirs(data_dir, exist_ok = True)
# Read embedding file
print("Reading embeddings file...", end=" ")
with open(args.embfile, "r") as fin:
    params = json.loads(fin.read())
# NOTE(review): assumes the embedding JSON dump contains exactly these two
# keys (entity and relation weight matrices) — confirm against the model export.
embeddings = params['ent_embeddings.weight']
rel_embeddings = params['rel_embeddings.weight']
print("DONE")
# Read the answers file (generated from the test option of the model)
print("Reading answers file...", end=" ")
with open(args.ansfile, "r") as fin:
    res = json.loads(fin.read())
print("DONE")
# Base name of the answers file (extension stripped) — used as the prefix
# of every output batch file.
ans_file = os.path.basename(args.ansfile).split('.')[0]
# Build per-query feature rows from the predicted answers and dump them to
# pickled batch files of batch_size queries (topk rows each).
triples = {}
# Train mode has a single prediction list per query; test mode has both a
# raw ("_raw") and a filtered ("_fil") ranking.
if args.mode == "train":
    rf_arr = [""]
else:
    rf_arr = ["_raw", "_fil"]
ht = ["head", "tail"]
for index in range(len(ht)):
    # The "other" side is the known entity of the query: the tail when we are
    # scoring head predictions, and vice versa.
    other = ht[(index + 1) % 2]
    for rf in rf_arr:
        batch_id = 1
        x_head = []
        y_head = []
        unique_pairs = set()
        dup_count = 0
        for r in tqdm(res):
            # Only the first occurrence of each (relation, known-entity) query
            # is featurized; repeats are counted as duplicates and skipped.
            if (r['rel'], r[other]) not in unique_pairs:
                unique_pairs.add((r['rel'], r[other]))
                predictions = r[ht[index] + '_predictions' + rf]
                for rank, (e, s, c) in enumerate(zip(
                        predictions['entities'],
                        predictions['scores'],
                        predictions['correctness'])):
                    # Feature row: [known_entity, rel, candidate, score, rank]
                    # followed by rel / known-entity / candidate embeddings.
                    features = [r[other], r['rel'], e, s, rank]
                    features.extend(rel_embeddings[r['rel']])
                    features.extend(embeddings[r[other]])
                    features.extend(embeddings[e])
                    x_head.append(features)
                    y_head.append(c)  # label: correctness of this candidate
            else:
                dup_count += 1
            # Flush a full batch: batch_size unique queries * topk rows each.
            if len(x_head) == batch_size * topk:
                triples['x_' + ht[index] + rf] = x_head
                triples['y_' + ht[index] + rf] = y_head
                answers_features_file = data_dir + ans_file + "-" + ht[index] + rf + "-batch-" + str(batch_id) + ".pkl"
                print("Creating " + answers_features_file)
                with open(answers_features_file, "wb") as fout:
                    pickle.dump(triples, fout, protocol=pickle.HIGHEST_PROTOCOL)
                batch_id += 1
                x_head.clear()
                y_head.clear()
                triples.clear()
        # Dump the remainder batch (fewer than batch_size queries).
        triples['x_' + ht[index] + rf] = x_head
        triples['y_' + ht[index] + rf] = y_head
        #TODO: making the batch of remaining queries does not work for fit_generator
        # BUGFIX: the remainder file name previously omitted the rf suffix, so in
        # test mode the "_raw" and "_fil" remainder batches overwrote each other.
        answers_features_file = data_dir + ans_file + "-" + ht[index] + rf + "-batch-" + str(batch_id) + ".pkl"
        with open(answers_features_file, "wb") as fout:
            pickle.dump(triples, fout, protocol=pickle.HIGHEST_PROTOCOL)
        # BUGFIX: clear after the remainder dump; previously stale x_*/y_* keys
        # from earlier head/tail/rf iterations leaked into later remainder files.
        triples.clear()