Skip to content

Commit a5f1648

Browse files
committed
Comment old code
1 parent 9b17326 commit a5f1648

File tree

1 file changed

+64
-2
lines changed

1 file changed

+64
-2
lines changed

dicee/static_funcs_training.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,69 @@ def evaluate_lp(model, triple_idx, num_entities, er_vocab: Dict[Tuple, List], re
3535
all_entities = torch.arange(0, num_entities).long()
3636
all_entities = all_entities.reshape(len(all_entities), )
3737

38-
# Iterate over test triples in batches
38+
# Evaluation without Batching
39+
# for i in tqdm(range(0, len(triple_idx))):
40+
# # (1) Get a triple (head entity, relation, tail entity
41+
# data_point = triple_idx[i]
42+
# h, r, t = data_point[0], data_point[1], data_point[2]
43+
44+
# # (2) Predict missing heads and tails
45+
# x = torch.stack((torch.tensor(h).repeat(num_entities, ),
46+
# torch.tensor(r).repeat(num_entities, ),
47+
# all_entities), dim=1)
48+
49+
# predictions_tails = model(x)
50+
# x = torch.stack((all_entities,
51+
# torch.tensor(r).repeat(num_entities, ),
52+
# torch.tensor(t).repeat(num_entities)
53+
# ), dim=1)
54+
55+
# predictions_heads = model(x)
56+
# del x
57+
58+
# # 3. Computed filtered ranks for missing tail entities.
59+
# # 3.1. Compute filtered tail entity rankings
60+
# filt_tails = er_vocab[(h, r)]
61+
# # 3.2 Get the predicted target's score
62+
# target_value = predictions_tails[t].item()
63+
# # 3.3 Filter scores of all triples containing filtered tail entities
64+
# predictions_tails[filt_tails] = -np.Inf
65+
# # 3.4 Reset the target's score
66+
# predictions_tails[t] = target_value
67+
# # 3.5. Sort the score
68+
# _, sort_idxs = torch.sort(predictions_tails, descending=True)
69+
# sort_idxs = sort_idxs.detach()
70+
# filt_tail_entity_rank = np.where(sort_idxs == t)[0][0]
71+
72+
# # 4. Computed filtered ranks for missing head entities.
73+
# # 4.1. Retrieve head entities to be filtered
74+
# filt_heads = re_vocab[(r, t)]
75+
# # 4.2 Get the predicted target's score
76+
# target_value = predictions_heads[h].item()
77+
# # 4.3 Filter scores of all triples containing filtered head entities.
78+
# predictions_heads[filt_heads] = -np.Inf
79+
# predictions_heads[h] = target_value
80+
# _, sort_idxs = torch.sort(predictions_heads, descending=True)
81+
# sort_idxs = sort_idxs.detach()
82+
# filt_head_entity_rank = np.where(sort_idxs == h)[0][0]
83+
84+
# # 4. Add 1 to ranks as numpy array first item has the index of 0.
85+
# filt_head_entity_rank += 1
86+
# filt_tail_entity_rank += 1
87+
88+
# rr = 1.0 / filt_head_entity_rank + (1.0 / filt_tail_entity_rank)
89+
# # 5. Store reciprocal ranks.
90+
# reciprocal_ranks.append(rr)
91+
# # print(f'{i}.th triple: mean reciprical rank:{rr}')
92+
93+
# # 4. Compute Hit@N
94+
# for hits_level in range(1, 11):
95+
# res = 1 if filt_head_entity_rank <= hits_level else 0
96+
# res += 1 if filt_tail_entity_rank <= hits_level else 0
97+
# if res > 0:
98+
# hits.setdefault(hits_level, []).append(res)
99+
100+
# Evaluation with Batching
39101
for batch_start in tqdm(range(0, len(triple_idx), batch_size), desc="Evaluating Batches"):
40102
batch_end = min(batch_start + batch_size, len(triple_idx))
41103
batch_triples = triple_idx[batch_start:batch_end]
@@ -85,7 +147,7 @@ def evaluate_lp(model, triple_idx, num_entities, er_vocab: Dict[Tuple, List], re
85147

86148
# Iterating one by one is not good when you are using batch norm
87149
for i in range(0, batch_size_current):
88-
# (1) Get a triple (head entity, relation, tail entity
150+
# (1) Get a triple (head entity, relation, tail entity)
89151
h = h_batch[i].item()
90152
r = r_batch[i].item()
91153
t = t_batch[i].item()

0 commit comments

Comments
 (0)