Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add triplet margin for distance functions in TripletEvaluator #2862

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions sentence_transformers/evaluation/TripletEvaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
class TripletEvaluator(SentenceEvaluator):
"""
Evaluate a model based on a triplet: (sentence, positive_example, negative_example).
Checks if distance(sentence, positive_example) < distance(sentence, negative_example).
Checks if distance(sentence, positive_example) + margin < distance(sentence, negative_example).

Example:
::
Expand Down Expand Up @@ -64,6 +64,7 @@ def __init__(
positives: list[str],
negatives: list[str],
main_distance_function: str | SimilarityFunction | None = None,
triplet_margins: dict[str, float] | None = None,
name: str = "",
batch_size: int = 16,
show_progress_bar: bool = False,
Expand All @@ -80,6 +81,10 @@ def __init__(
main_distance_function (Union[str, SimilarityFunction], optional):
The distance function to use. If not specified, use cosine similarity,
dot product, Euclidean, and Manhattan. Defaults to None.
triplet_margins (Dict[str, float], optional): Margins for various distance metrics.
Acceptable keys are 'cosine', 'dot', 'manhattan', and 'euclidean'. Each value
specifies the minimum margin by which the negative sample should be farther from
the anchor than the positive sample. Metrics omitted from the dict use a margin of 0.
Defaults to None, which applies a margin of 0 to every metric.
name (str): Name for the output. Defaults to "".
batch_size (int): Batch size used to compute embeddings. Defaults to 16.
show_progress_bar (bool): If true, prints a progress bar. Defaults to False.
Expand All @@ -99,6 +104,13 @@ def __init__(

self.main_distance_function = SimilarityFunction(main_distance_function) if main_distance_function else None

default_margins = {"cosine": 0, "dot": 0, "manhattan": 0, "euclidean": 0}
self.triplet_margins = default_margins if triplet_margins is None else {**default_margins, **triplet_margins}

assert set(self.triplet_margins.keys()) == set(
default_margins.keys()
), "The keys in 'triplet_margins' must be a subset of {'cosine', 'dot', 'manhattan', 'euclidean'}."

self.batch_size = batch_size
if show_progress_bar is None:
show_progress_bar = (
Expand Down Expand Up @@ -184,16 +196,16 @@ def __call__(
for idx in range(len(pos_cos_distance)):
num_triplets += 1

if pos_cos_distance[idx] < neg_cos_distances[idx]:
if pos_cos_distance[idx] + self.triplet_margins["cosine"] < neg_cos_distances[idx]:
num_correct_cos_triplets += 1

if pos_dot_distance[idx] < neg_dot_distances[idx]:
if pos_dot_distance[idx] + self.triplet_margins["dot"] < neg_dot_distances[idx]:
num_correct_dot_triplets += 1

if pos_manhattan_distance[idx] < neg_manhattan_distances[idx]:
if pos_manhattan_distance[idx] + self.triplet_margins["manhattan"] < neg_manhattan_distances[idx]:
num_correct_manhattan_triplets += 1

if pos_euclidean_distance[idx] < neg_euclidean_distances[idx]:
if pos_euclidean_distance[idx] + self.triplet_margins["euclidean"] < neg_euclidean_distances[idx]:
num_correct_euclidean_triplets += 1

accuracy_cos = num_correct_cos_triplets / num_triplets
Expand Down