
Commit

Merge pull request #277 from dice-group/batched-evaluation
Batched evaluation
sshivam95 authored Nov 29, 2024
2 parents 58aa98c + a5f1648 commit fb436e6
Showing 1 changed file with 167 additions and 63 deletions.
230 changes: 167 additions & 63 deletions dicee/static_funcs_training.py
@@ -11,7 +11,7 @@ def make_iterable_verbose(iterable_object, verbose, desc="Default", position=Non


def evaluate_lp(model, triple_idx, num_entities, er_vocab: Dict[Tuple, List], re_vocab: Dict[Tuple, List],
info='Eval Starts'):
info='Eval Starts', batch_size=128, chunk_size=1000):
"""
Evaluate model in a standard link prediction task
@@ -21,78 +21,182 @@ def evaluate_lp(model, triple_idx, num_entities, er_vocab: Dict[Tuple, List], re
:param model: trained knowledge graph embedding model to evaluate
:param triple_idx: test triples given as (head, relation, tail) index triples
:param num_entities: total number of entities in the knowledge graph
:param er_vocab: maps (head, relation) to all known tail entities, for filtered ranking
:param re_vocab: maps (relation, tail) to all known head entities, for filtered ranking
:param info: message printed when evaluation starts
:param batch_size: number of test triples scored per batch
:param chunk_size: number of candidate entities scored per forward pass
:return:
"""
model.eval()
print(info)
print(f'Num of triples {len(triple_idx)}')
print('** Evaluation without batching')
print('** Evaluation with batching')
hits = dict()
reciprocal_ranks = []
# Iterate over test triples
all_entities = torch.arange(0, num_entities).long()
all_entities = all_entities.reshape(len(all_entities), )
# Iterating one by one is not good when you are using batch norm
for i in tqdm(range(0, len(triple_idx))):
# (1) Get a triple (head entity, relation, tail entity)
data_point = triple_idx[i]
h, r, t = data_point[0], data_point[1], data_point[2]

# (2) Predict missing heads and tails
x = torch.stack((torch.tensor(h).repeat(num_entities, ),
torch.tensor(r).repeat(num_entities, ),
all_entities), dim=1)

predictions_tails = model(x)
x = torch.stack((all_entities,
torch.tensor(r).repeat(num_entities, ),
torch.tensor(t).repeat(num_entities)
), dim=1)

predictions_heads = model(x)
del x

# 3. Compute filtered ranks for missing tail entities.
# 3.1. Compute filtered tail entity rankings
filt_tails = er_vocab[(h, r)]
# 3.2 Get the predicted target's score
target_value = predictions_tails[t].item()
# 3.3 Filter scores of all triples containing filtered tail entities
predictions_tails[filt_tails] = -np.inf
# 3.4 Reset the target's score
predictions_tails[t] = target_value
# 3.5. Sort the score
_, sort_idxs = torch.sort(predictions_tails, descending=True)
sort_idxs = sort_idxs.detach()
filt_tail_entity_rank = np.where(sort_idxs == t)[0][0]

# 4. Compute filtered ranks for missing head entities.
# 4.1. Retrieve head entities to be filtered
filt_heads = re_vocab[(r, t)]
# 4.2 Get the predicted target's score
target_value = predictions_heads[h].item()
# 4.3 Filter scores of all triples containing filtered head entities.
predictions_heads[filt_heads] = -np.inf
predictions_heads[h] = target_value
_, sort_idxs = torch.sort(predictions_heads, descending=True)
sort_idxs = sort_idxs.detach()
filt_head_entity_rank = np.where(sort_idxs == h)[0][0]

# 5. Add 1 to ranks since numpy arrays are 0-indexed.
filt_head_entity_rank += 1
filt_tail_entity_rank += 1

rr = 1.0 / filt_head_entity_rank + (1.0 / filt_tail_entity_rank)
# 6. Store reciprocal ranks.
reciprocal_ranks.append(rr)
# print(f'{i}.th triple: reciprocal rank sum: {rr}')

# 7. Compute Hits@N
for hits_level in range(1, 11):
res = 1 if filt_head_entity_rank <= hits_level else 0
res += 1 if filt_tail_entity_rank <= hits_level else 0
if res > 0:
hits.setdefault(hits_level, []).append(res)

# Evaluation with Batching
for batch_start in tqdm(range(0, len(triple_idx), batch_size), desc="Evaluating Batches"):
batch_end = min(batch_start + batch_size, len(triple_idx))
batch_triples = triple_idx[batch_start:batch_end]
batch_size_current = len(batch_triples)
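# Score each batch of test triples against all candidate entities in chunks,
# bounding peak memory at roughly batch_size * chunk_size scores per pass.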

# (1) Extract heads, relations, and tails for the batch
h_batch = torch.tensor([data_point[0] for data_point in batch_triples])
r_batch = torch.tensor([data_point[1] for data_point in batch_triples])
t_batch = torch.tensor([data_point[2] for data_point in batch_triples])

# Initialize score tensors
predictions_tails = torch.zeros(batch_size_current, num_entities)
predictions_heads = torch.zeros(batch_size_current, num_entities)
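# Row i of each matrix will hold the scores of test triple i against
# every candidate entity, filled in chunk by chunk below.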

# Process entities in chunks to manage memory usage
for chunk_start in range(0, num_entities, chunk_size):
chunk_end = min(chunk_start + chunk_size, num_entities)
entities_chunk = all_entities[chunk_start:chunk_end]
chunk_size_current = entities_chunk.size(0)

# (2) Predict missing heads and tails
# Prepare input tensors for tail prediction
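# repeat_interleave pairs each test triple with every entity in the chunk:
# e.g. h_batch=[h0, h1] with a chunk [e0, e1, e2] yields heads
# [h0, h0, h0, h1, h1, h1] against candidates [e0, e1, e2, e0, e1, e2].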
x_tails = torch.stack((
h_batch.repeat_interleave(chunk_size_current),
r_batch.repeat_interleave(chunk_size_current),
entities_chunk.repeat(batch_size_current)
), dim=1)

# Predict scores for missing tails
preds_tails = model.forward_triples(x_tails)
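# Fold the flat score vector back into (batch, chunk) so this chunk's
# columns can be written into the full prediction matrix.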
preds_tails = preds_tails.view(batch_size_current, chunk_size_current)
predictions_tails[:, chunk_start:chunk_end] = preds_tails
del x_tails

# Prepare input tensors for head prediction
x_heads = torch.stack((
entities_chunk.repeat(batch_size_current),
r_batch.repeat_interleave(chunk_size_current),
t_batch.repeat_interleave(chunk_size_current)
), dim=1)

# Predict scores for missing heads
preds_heads = model.forward_triples(x_heads)
preds_heads = preds_heads.view(batch_size_current, chunk_size_current)
predictions_heads[:, chunk_start:chunk_end] = preds_heads
del x_heads

# Rank each triple in the batch individually; the model scores were already computed above in batch.
for i in range(0, batch_size_current):
# (1) Get a triple (head entity, relation, tail entity)
h = h_batch[i].item()
r = r_batch[i].item()
t = t_batch[i].item()

# 3. Compute filtered ranks for missing tail entities.
# 3.1. Compute filtered tail entity rankings
filt_tails = er_vocab[(h, r)]
filt_tails_set = set(filt_tails) - {t}
filt_tails_indices = list(filt_tails_set)
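# The true tail t is excluded from the filter set, so the masking below
# cannot erase the target's own score; resetting it afterwards is a safeguard.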
# 3.2 Get the predicted target's score
target_value = predictions_tails[i, t].item()
# 3.3 Filter scores of all triples containing filtered tail entities
predictions_tails[i, filt_tails_indices] = -np.inf
# 3.4 Reset the target's score
predictions_tails[i, t] = target_value
# 3.5. Sort the score
_, sort_idxs = torch.sort(predictions_tails[i], descending=True)
sort_idxs = sort_idxs.detach()
filt_tail_entity_rank = np.where(sort_idxs == t)[0][0]
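# The position of t in the descending order is its 0-based filtered rank.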

# 4. Compute filtered ranks for missing head entities.
# 4.1. Retrieve head entities to be filtered
filt_heads = re_vocab[(r, t)]
filt_heads_set = set(filt_heads) - {h}
filt_heads_indices = list(filt_heads_set)
# 4.2 Get the predicted target's score
target_value = predictions_heads[i, h].item()
# 4.3 Filter scores of all triples containing filtered head entities.
predictions_heads[i, filt_heads_indices] = -np.inf
predictions_heads[i, h] = target_value
_, sort_idxs = torch.sort(predictions_heads[i], descending=True)
sort_idxs = sort_idxs.detach()
filt_head_entity_rank = np.where(sort_idxs == h)[0][0]

# 5. Add 1 to ranks since numpy arrays are 0-indexed.
filt_head_entity_rank += 1
filt_tail_entity_rank += 1

rr = 1.0 / filt_head_entity_rank + (1.0 / filt_tail_entity_rank)
# 6. Store reciprocal ranks.
reciprocal_ranks.append(rr)
# print(f'{i}.th triple: reciprocal rank sum: {rr}')

# 7. Compute Hits@N
for hits_level in range(1, 11):
res = 1 if filt_head_entity_rank <= hits_level else 0
res += 1 if filt_tail_entity_rank <= hits_level else 0
if res > 0:
hits.setdefault(hits_level, []).append(res)
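# Each test triple contributes two reciprocal ranks (head and tail),
# hence the normalization by 2 * len(triple_idx) below.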

mean_reciprocal_rank = sum(reciprocal_ranks) / (float(len(triple_idx) * 2))
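
For context, here is a minimal sketch of how the batched evaluator might be invoked. Only evaluate_lp and its signature come from the diff above; the build_vocabs helper and the model / *_triples / num_entities names are illustrative assumptions.

from collections import defaultdict
from dicee.static_funcs_training import evaluate_lp

# Hypothetical helper: build the filtering vocabularies evaluate_lp expects
# from integer-indexed (head, relation, tail) triples.
def build_vocabs(triples):
    er_vocab = defaultdict(list)  # (head, relation) -> known tails
    re_vocab = defaultdict(list)  # (relation, tail) -> known heads
    for h, r, t in triples:
        er_vocab[(h, r)].append(t)
        re_vocab[(r, t)].append(h)
    return er_vocab, re_vocab

er_vocab, re_vocab = build_vocabs(train_triples + valid_triples + test_triples)
evaluate_lp(model, test_triples, num_entities=num_entities,
            er_vocab=er_vocab, re_vocab=re_vocab,
            info='Eval Starts', batch_size=128, chunk_size=1000)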

