@@ -35,7 +35,69 @@ def evaluate_lp(model, triple_idx, num_entities, er_vocab: Dict[Tuple, List], re
    all_entities = torch.arange(0, num_entities).long()
    all_entities = all_entities.reshape(len(all_entities), )
-    # Iterate over test triples in batches
+    # Evaluation without Batching
+    # for i in tqdm(range(0, len(triple_idx))):
+    #     # (1) Get a triple (head entity, relation, tail entity)
+    #     data_point = triple_idx[i]
+    #     h, r, t = data_point[0], data_point[1], data_point[2]
+
+    #     # (2) Predict missing heads and tails
+    #     x = torch.stack((torch.tensor(h).repeat(num_entities, ),
+    #                      torch.tensor(r).repeat(num_entities, ),
+    #                      all_entities), dim=1)
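+    #     # x has shape (num_entities, 3): one (h, r, candidate_tail) row per entity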
+
+    #     predictions_tails = model(x)
+    #     x = torch.stack((all_entities,
+    #                      torch.tensor(r).repeat(num_entities, ),
+    #                      torch.tensor(t).repeat(num_entities)
+    #                      ), dim=1)
+
+    #     predictions_heads = model(x)
+    #     del x
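+    #     # predictions_tails[e] / predictions_heads[e]: the model's score for entity e as tail / head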
+
+    #     # 3. Compute filtered ranks for missing tail entities.
+    #     # 3.1 Retrieve tail entities to be filtered
+    #     filt_tails = er_vocab[(h, r)]
+    #     # 3.2 Get the predicted target's score
+    #     target_value = predictions_tails[t].item()
+    #     # 3.3 Filter scores of all triples containing filtered tail entities
+    #     predictions_tails[filt_tails] = -np.inf
+    #     # 3.4 Reset the target's score
+    #     predictions_tails[t] = target_value
+    #     # 3.5 Sort the scores
+    #     _, sort_idxs = torch.sort(predictions_tails, descending=True)
+    #     sort_idxs = sort_idxs.detach()
+    #     filt_tail_entity_rank = np.where(sort_idxs == t)[0][0]
+
+    #     # 4. Compute filtered ranks for missing head entities.
+    #     # 4.1 Retrieve head entities to be filtered
+    #     filt_heads = re_vocab[(r, t)]
+    #     # 4.2 Get the predicted target's score
+    #     target_value = predictions_heads[h].item()
+    #     # 4.3 Filter scores of all triples containing filtered head entities.
+    #     predictions_heads[filt_heads] = -np.inf
+    #     predictions_heads[h] = target_value
+    #     _, sort_idxs = torch.sort(predictions_heads, descending=True)
+    #     sort_idxs = sort_idxs.detach()
+    #     filt_head_entity_rank = np.where(sort_idxs == h)[0][0]
+
+    #     # 5. Add 1 to ranks, since numpy indexing is 0-based.
+    #     filt_head_entity_rank += 1
+    #     filt_tail_entity_rank += 1
+
+    #     rr = (1.0 / filt_head_entity_rank) + (1.0 / filt_tail_entity_rank)
+    #     # 6. Store reciprocal ranks.
+    #     reciprocal_ranks.append(rr)
+    #     # print(f'{i}-th triple: mean reciprocal rank: {rr}')
+
+    #     # 7. Compute Hits@N
+    #     for hits_level in range(1, 11):
+    #         res = 1 if filt_head_entity_rank <= hits_level else 0
+    #         res += 1 if filt_tail_entity_rank <= hits_level else 0
+    #         if res > 0:
+    #             hits.setdefault(hits_level, []).append(res)
+
+    # Evaluation with Batching
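+    # Scoring a whole batch per forward pass is faster than the per-triple loop above
+    # and, as noted in the inner loop below, per-triple iteration is problematic with batch norm.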
    for batch_start in tqdm(range(0, len(triple_idx), batch_size), desc="Evaluating Batches"):
        batch_end = min(batch_start + batch_size, len(triple_idx))
        batch_triples = triple_idx[batch_start:batch_end]
@@ -85,7 +147,7 @@ def evaluate_lp(model, triple_idx, num_entities, er_vocab: Dict[Tuple, List], re
        # Iterating one by one is not good when you are using batch norm
        for i in range(0, batch_size_current):
-            # (1) Get a triple (head entity, relation, tail entity
+            # (1) Get a triple (head entity, relation, tail entity)
            h = h_batch[i].item()
            r = r_batch[i].item()
            t = t_batch[i].item()
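
Side note for reviewers: both evaluation paths share the same filtered-ranking core (steps 3.1-3.5 in the commented block). Below is a minimal, self-contained sketch of that core; the helper name `filtered_rank` and the toy scores are illustrative only and not part of this commit.

```python
import numpy as np
import torch

def filtered_rank(scores: torch.Tensor, target: int, known: list) -> int:
    """Return the 1-based rank of `target` after masking other known-true candidates."""
    scores = scores.clone()
    target_score = scores[target].item()  # remember the target's own score
    scores[known] = -np.inf               # mask every known true candidate (filtered setting)
    scores[target] = target_score         # restore the target
    sort_idxs = torch.sort(scores, descending=True).indices
    return int((sort_idxs == target).nonzero(as_tuple=True)[0][0]) + 1

# Toy example: 5 candidate entities; entity 2 is the target, entity 0 is another known answer.
scores = torch.tensor([3.0, 1.0, 2.5, 0.1, -0.4])
print(filtered_rank(scores, target=2, known=[0, 2]))  # -> 1, because entity 0 is filtered out
```

With head and tail ranks in hand, the loop accumulates `1.0 / rank` for both directions into `reciprocal_ranks` and counts Hits@N for N = 1..10.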