deduplication in GISTEmbedLoss and CachedGISTEmbedLoss #3074

Draft
wants to merge 1 commit into master
Conversation

JINO-ROHIT
Contributor

@JINO-ROHIT JINO-ROHIT commented Nov 19, 2024

Fixes #2756

This is an enhancement for the duplicate-positives issue discussed for GISTEmbedLoss and CachedGISTEmbedLoss; all wandb results are logged here: https://wandb.ai/jinooo/sentence-transformers?nw=nwuserjinooo

@tomaarsen

@tomaarsen
Collaborator

cc @yjoonjang as you also did some brainstorming on this topic

@tomaarsen tomaarsen linked an issue Nov 19, 2024 that may be closed by this pull request
@yjoonjang

Great work. Thank you!

@tomaarsen
Collaborator

I did some more tests surrounding #3063 vs this PR.
I found that

ap_sim[guided_ap_sim >= guided_sim] = -torch.inf
aa_sim[guided_aa_sim >= guided_sim] = -torch.inf
pp_sim[guided_pp_sim >= guided_sim] = -torch.inf

results in a loss of 0, so the model doesn't learn - not ideal.

Beyond that, this PR gets the same performance as my baseline in my tests, which could be because duplicates aren't necessarily a concern. However, I would still like to remove the duplicates if possible.

I also tested with

ap_sim[guided_ap_sim > guided_sim] = -torch.inf
aa_sim[guided_aa_sim >= guided_sim] = -torch.inf
pp_sim[guided_pp_sim >= guided_sim] = -torch.inf

which also performed equivalently.

For each sample, the model learns with:

sim(a_1, p_1), sim(a_1, p_2), ..., sim(a_1, p_k), sim(a_1, a_1), sim(a_1, a_2), ..., sim(a_1, a_k), sim(p_1, p_1), sim(p_1, p_2), ..., sim(p_1, p_k)

And then for anchor $i$, we want the $i$-th similarity to be high, while the others must be low. I.e. we want sim(a_i, p_i) to be high, and the others low. But this also contains sim(p_i, p_i) which is always 1 and sim(a_i, a_i) which is also always 1.
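
(For context, this is roughly how the scores above feed into the loss; the snippet below is an illustrative stand-alone sketch, not the actual GISTEmbedLoss code.)

import torch
import torch.nn.functional as F

# Illustrative stand-ins for the (already masked) similarity matrices above
batch, temperature = 4, 0.01
ap_sim = torch.randn(batch, batch)
aa_sim = torch.randn(batch, batch)
pp_sim = torch.randn(batch, batch)

# One row per anchor; the columns are all candidate similarities for that anchor
scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1) / temperature
labels = torch.arange(scores.size(0))  # target column i corresponds to sim(a_i, p_i)
loss = F.cross_entropy(scores, labels)  # pushes sim(a_i, p_i) up and every other column down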

Your PR will remove the first of those (sim(p_i, p_i)) because we now do >= over the pp_sim, but I think we can also do >= over the aa_sim to get rid of sim(a_i, a_i).

My guess is that this didn't actually end up affecting the model training, because e.g. sim(a_i, a_i) has no gradient associated with it, i.e. there's no direction to nudge the model in to lower sim(a_i, a_i) down from 1.
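
(A tiny self-contained sketch of that intuition, separate from the loss code:)

import torch
import torch.nn.functional as F

a = torch.randn(8, requires_grad=True)
self_sim = F.cosine_similarity(a, a, dim=0)  # always 1, whatever the embedding is
self_sim.backward()
print(self_sim.item())            # 1.0
print(a.grad.abs().max().item())  # ~0: there is no direction that lowers sim(a_i, a_i)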

In short, my suggestion is to also add >= to aa_sim. What do you think about that, @JINO-ROHIT?

My test models:

I didn't complete my run with >= on all 3, but the loss was constantly 0 and it didn't learn.


@yjoonjang

> [quotes @tomaarsen's comment above in full]

I agree with you, @tomaarsen.
I would also like to hear your thoughts on adding the equal sign to the negative calculation here: https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/losses/CachedGISTEmbedLoss.py#L274C1-L279C65

The code is now

if len(concatenated_reps) > 2:
    for i in range(2, len(concatenated_reps)):  # Start from 2 since first 2 are anchor-positive
        guided_neg_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[i])
        neg_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[i])
        neg_sim[guided_neg_sim > guided_sim] = -torch.inf
        scores = torch.cat([scores, neg_sim], dim=1)

But I think it should be changed to

if len(concatenated_reps) > 2:
    for i in range(2, len(concatenated_reps)):  # Start from 2 since first 2 are anchor-positive
        guided_neg_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[i])
        neg_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[i])
        neg_sim[guided_neg_sim >= guided_sim] = -torch.inf  # equal sign added
        scores = torch.cat([scores, neg_sim], dim=1)

Since the hard negatives are mined from positives, I think this might reduce the duplicates effectively.
What do you think about it?


Collaborator

@tomaarsen tomaarsen left a comment

Could you perhaps add >= for the 3 aa_sim cases as well?


@yjoonjang

yjoonjang commented Nov 20, 2024

Sure @tomaarsen, should I also add >= to neg_sim, as explained above?


@yjoonjang

yjoonjang commented Nov 20, 2024

Also, am I able to revise this code?

@tomaarsen
Collaborator

tomaarsen commented Nov 20, 2024

> Sure @tomaarsen, should I also add >= to neg_sim, as explained above?

Oh, apologies, I missed your reply #3074 (comment) when I added my Changes Requested message; it was intended for @JINO-ROHIT.

I think only JINO and myself can make changes to this PR, so I'll leave it to JINO.

As for your negatives proposal - I think that the neg_sim in your code only refers to similarities between the anchors and the negative texts. So if we use >= there, then the only difference is that we won't train with cases where the anchor and a negative are the same. But as mentioned above, the model doesn't learn from those anyway: if a_i == n_j, then sim(a_i, n_j) == 1 and there's no way to decrease it.

But you're right that we're missing a case here: if positive_i == negative_j, then we're training to increase sim(a_i, p_i) while also training to decrease sim(a_i, n_j), except those are the same value, so the two objectives contradict each other - not great for learning.
To fix this, we indeed have to set neg_sim[...] = -torch.inf, but we have to select only the cases where p_i == n_j.

P.S. it's easier to figure this out for GISTEmbedLoss first and then propagate it to CachedGISTEmbedLoss; the former is a bit easier to understand.
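
(For illustration only, a minimal sketch of one way this could look in GISTEmbedLoss, assuming positive_guide is available alongside anchor_guide/negative_guide; the near-1 threshold is an assumption, not the PR's implementation:)

# Hypothetical: mask anchor-negative scores only where the negative is (near-)identical
# to the corresponding positive according to the guide model
guided_pn_sim = self.sim_matrix(positive_guide, negative_guide)  # guided sim(p_i, n_j)
an_sim[guided_pn_sim >= 1 - 1e-6] = -torch.inf  # drop only the p_i == n_j entries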


@tomaarsen
Collaborator

tomaarsen commented Nov 20, 2024

Actually, because guided_sim denotes the similarity between anchor_i and positive_i, if a value in guided_an_sim is identical to guided_sim, then it's indeed quite possible that positive_i == negative_j and that's exactly the case that we want to throw away.

# Handle the case where we have a negative sample
if negative is not None:
    an_sim = self.sim_matrix(anchor, negative)
    guided_an_sim = self.sim_matrix(anchor_guide, negative_guide)
    an_sim[guided_an_sim > guided_sim] = -torch.inf
    scores.append(an_sim)

So maybe >= does just fix it?

Apologies for the confusion.

I'm training a new model with a triplet dataset and >= instead of > as @yjoonjang proposed. I'll have details in about an hour.

@tomaarsen
Collaborator

tomaarsen commented Nov 20, 2024

Here are my results:

Sadly the baseline gets stronger performance (0.4104 vs 0.4066 NDCG@10, higher is better).

Also on an unseen part of gooaq the baseline is better:

Baseline: 0.8350 NDCG@10
GTE Test: 0.8302 NDCG@10

(Granted, this is n=1, so it's hard to say that the baseline is always better here)

I'm not sure about the intuition for why using >= in all 3 cases would be worse.

Perhaps it's because >= doesn't just skip identical texts, but also skips texts that merely have the same guided similarity while actually providing useful negatives.


@JINO-ROHIT
Contributor Author

We might as well leave this PR as is then, until we can zero in with more experiments.

@yjoonjang

Hmm, interesting.
I don't really know why the performance degrades.
At first I thought this happened due to a small batch size, but 2048 looks big enough to me.

@tomaarsen tomaarsen marked this pull request as draft November 26, 2024 12:58
@daegonYu
Contributor

daegonYu commented Dec 2, 2024

@tomaarsen

If I may offer my cautious opinion, wouldn't the issues you mentioned above be resolved by this code: "batch_sampler=BatchSamplers.NO_DUPLICATES"?

@tomaarsen
Collaborator

Well spotted! Yes, it would be. If you use that batch sampler, then this won't be an issue to begin with. I mention this briefly in the comments here: #2756 (comment)

    # batch_sampler=BatchSamplers.NO_DUPLICATES,  # Although the loss benefits from having no duplicate samples in a batch
    # we want to specifically test with duplicate samples as those should start being ignored.

I figured that it might be possible to always avoid the duplicate-sample issue within the loss itself, but it's seemingly not as simple as I thought.
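
(For reference, this is roughly how that sampler is enabled with the v3 trainer; output_dir and the batch size below are placeholders:)

from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

args = SentenceTransformerTrainingArguments(
    output_dir="output",  # placeholder
    per_device_train_batch_size=2048,
    # Guarantees a batch never contains the same text twice, so the in-batch
    # negatives used by (Cached)GISTEmbedLoss stay valid
    batch_sampler=BatchSamplers.NO_DUPLICATES,
)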

