Skip to content

Commit

Permalink
Merge branch 'kfold-crossvalidation-verbose' of https://github.com/di…
Browse files Browse the repository at this point in the history
…ce-group/dice-embeddings into kfold-crossvalidation-verbose
  • Loading branch information
sshivam95 committed Nov 26, 2024
2 parents 8b4474a + 11864b2 commit 27d9dd6
Showing 1 changed file with 79 additions and 57 deletions.
136 changes: 79 additions & 57 deletions dicee/eval_static_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@torch.no_grad()
def evaluate_link_prediction_performance(model: KGE, triples, er_vocab: Dict[Tuple, List],
re_vocab: Dict[Tuple, List]) -> Dict:
re_vocab: Dict[Tuple, List], batch_size=128) -> Dict:
"""
Parameters
Expand All @@ -16,6 +16,7 @@ def evaluate_link_prediction_performance(model: KGE, triples, er_vocab: Dict[Tup
triples
er_vocab
re_vocab
batch_size
Returns
-------
Expand All @@ -31,67 +32,88 @@ def evaluate_link_prediction_performance(model: KGE, triples, er_vocab: Dict[Tup
all_entities = torch.arange(0, num_entities).long()
all_entities = all_entities.reshape(len(all_entities), )
# Iterating one by one is not good when you are using batch norm
for i in tqdm(range(0, len(triples))):
# (1) Get a triple (head entity, relation, tail entity
data_point = triples[i]
str_h, str_r, str_t = data_point[0], data_point[1], data_point[2]

h, r, t = model.get_entity_index(str_h), model.get_relation_index(str_r), model.get_entity_index(str_t)
# Iterate over test triples in batches
for batch_start in tqdm(range(0, len(triples)), batch_size):
batch_end = min(batch_start + batch_size, len(triples))
batch_triples = triples[batch_start:batch_end]

# Prepare batch data
str_h_batch = [data_point[0] for data_point in batch_triples]
str_r_batch = [data_point[1] for data_point in batch_triples]
str_t_batch = [data_point[2] for data_point in batch_triples]

h_batch = [model.get_entity_index(str_h) for str_h in str_h_batch]
r_batch = [model.get_entity_index(str_r) for str_r in str_r_batch]
t_batch = [model.get_entity_index(str_t) for str_t in str_t_batch]

h_batch_tensor = torch.tensor(h_batch)
r_batch_tensor = torch.tensor(r_batch)
t_batch_tensor = torch.tensor(t_batch)

batch_size_current = len(batch_triples)
num_entities = len(all_entities)

# (2) Predict missing heads and tails
x = torch.stack((torch.tensor(h).repeat(num_entities, ),
torch.tensor(r).repeat(num_entities, ),
all_entities), dim=1)

predictions_tails = model.model.forward_triples(x)
x = torch.stack((all_entities,
torch.tensor(r).repeat(num_entities, ),
torch.tensor(t).repeat(num_entities)
x = torch.stack((h_batch_tensor.repeat_interleave(num_entities), r_batch_tensor.repeat_interleave(num_entities), all_entities.repeat(batch_size_current)), dim=1)
predictions_tails = model.model.forward_triples(x).view(batch_size_current, num_entities)

x = torch.stack((all_entities.repeat(batch_size_current),
r_batch_tensor.repeat_interleave(num_entities),
t_batch_tensor.repeat_interleave(num_entities)
), dim=1)

predictions_heads = model.model.forward_triples(x)
predictions_heads = model.model.forward_triples(x).view(batch_size_current, num_entities)
del x

# 3. Computed filtered ranks for missing tail entities.
# 3.1. Compute filtered tail entity rankings
filt_tails = [model.entity_to_idx[i] for i in er_vocab[(str_h, str_r)]]
# 3.2 Get the predicted target's score
target_value = predictions_tails[t].item()
# 3.3 Filter scores of all triples containing filtered tail entities
predictions_tails[filt_tails] = -np.Inf
# 3.4 Reset the target's score
predictions_tails[t] = target_value
# 3.5. Sort the score
_, sort_idxs = torch.sort(predictions_tails, descending=True)
sort_idxs = sort_idxs.detach()
filt_tail_entity_rank = np.where(sort_idxs == t)[0][0]

# 4. Computed filtered ranks for missing head entities.
# 4.1. Retrieve head entities to be filtered
filt_heads = [model.entity_to_idx[i] for i in re_vocab[(str_r, str_t)]]
# 4.2 Get the predicted target's score
target_value = predictions_heads[h].item()
# 4.3 Filter scores of all triples containing filtered head entities.
predictions_heads[filt_heads] = -np.Inf
predictions_heads[h] = target_value
_, sort_idxs = torch.sort(predictions_heads, descending=True)
sort_idxs = sort_idxs.detach()
filt_head_entity_rank = np.where(sort_idxs == h)[0][0]

# 4. Add 1 to ranks as numpy array first item has the index of 0.
filt_head_entity_rank += 1
filt_tail_entity_rank += 1

rr = 1.0 / filt_head_entity_rank + (1.0 / filt_tail_entity_rank)
# 5. Store reciprocal ranks.
reciprocal_ranks.append(rr)
# print(f'{i}.th triple: mean reciprical rank:{rr}')

# 4. Compute Hit@N
for hits_level in range(1, 11):
res = 1 if filt_head_entity_rank <= hits_level else 0
res += 1 if filt_tail_entity_rank <= hits_level else 0
if res > 0:
hits.setdefault(hits_level, []).append(res)
# Now process each triple in the batch
for i in range(batch_size_current):
h = h_batch[i]
r = r_batch[i]
t = t_batch[i]
str_h = str_h_batch[i]
str_r = str_r_batch[i]
str_t = str_t_batch[i]
# 3. Computed filtered ranks for missing tail entities.
# 3.1. Compute filtered tail entity rankings
filt_tails = [model.entity_to_idx[i] for i in er_vocab[(str_h, str_r)]]
# 3.2 Get the predicted target's score
target_value = predictions_tails[t].item()
# 3.3 Filter scores of all triples containing filtered tail entities
predictions_tails[filt_tails] = -np.Inf
# 3.4 Reset the target's score
predictions_tails[t] = target_value
# 3.5. Sort the score
_, sort_idxs = torch.sort(predictions_tails, descending=True)
sort_idxs = sort_idxs.detach()
filt_tail_entity_rank = np.where(sort_idxs == t)[0][0]

# 4. Computed filtered ranks for missing head entities.
# 4.1. Retrieve head entities to be filtered
filt_heads = [model.entity_to_idx[i] for i in re_vocab[(str_r, str_t)]]
# 4.2 Get the predicted target's score
target_value = predictions_heads[h].item()
# 4.3 Filter scores of all triples containing filtered head entities.
predictions_heads[filt_heads] = -np.Inf
predictions_heads[h] = target_value
_, sort_idxs = torch.sort(predictions_heads, descending=True)
sort_idxs = sort_idxs.detach()
filt_head_entity_rank = np.where(sort_idxs == h)[0][0]

# 4. Add 1 to ranks as numpy array first item has the index of 0.
filt_head_entity_rank += 1
filt_tail_entity_rank += 1

rr = 1.0 / filt_head_entity_rank + (1.0 / filt_tail_entity_rank)
# 5. Store reciprocal ranks.
reciprocal_ranks.append(rr)
# print(f'{i}.th triple: mean reciprical rank:{rr}')

# 4. Compute Hit@N
for hits_level in range(1, 11):
res = 1 if filt_head_entity_rank <= hits_level else 0
res += 1 if filt_tail_entity_rank <= hits_level else 0
if res > 0:
hits.setdefault(hits_level, []).append(res)

mean_reciprocal_rank = sum(reciprocal_ranks) / (float(len(triples) * 2))

Expand Down

0 comments on commit 27d9dd6

Please sign in to comment.