Skip to content

Commit

Permalink
Merge pull request #47 from sbintuitions/fix/sts_bfloat16_tensor
Browse files Browse the repository at this point in the history
[Fix] STSにおいてbfloat16 tensorが起こった不具合
  • Loading branch information
lsz05 authored Jul 31, 2024
2 parents a3a45f1 + ae8d872 commit 7881bd7
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/jmteb/evaluators/sts/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def _compute_similarity(
embeddings1: Tensor, embeddings2: Tensor, golden_scores: list, similarity_func: Callable
) -> tuple[dict[str, float], list[float]]:
sim_scores = similarity_func(embeddings1, embeddings2).cpu()
if isinstance(sim_scores, Tensor) and sim_scores.dtype is torch.bfloat16:
sim_scores = sim_scores.float()
pearson = pearsonr(golden_scores, sim_scores)[0]
spearman = spearmanr(golden_scores, sim_scores)[0]
return {
Expand Down

0 comments on commit 7881bd7

Please sign in to comment.