deduplication in GISTEmbedLoss and CachedGISTEmbedLoss #3074
base: master
Conversation
cc @yjoonjang as you also did some brainstorming on this topic |
Great work. Thank you! |
I did some more tests surrounding #3063 vs this PR.

```python
ap_sim[guided_ap_sim >= guided_sim] = -torch.inf
aa_sim[guided_aa_sim >= guided_sim] = -torch.inf
pp_sim[guided_pp_sim >= guided_sim] = -torch.inf
```

results in 0 loss and the model doesn't learn - not ideal. Beyond that, this PR gets the same performance as my baseline in my tests; this could be because duplicates aren't necessarily a concern. However, I would still like to remove the duplicates if possible. I also tested with

```python
ap_sim[guided_ap_sim > guided_sim] = -torch.inf
aa_sim[guided_aa_sim >= guided_sim] = -torch.inf
pp_sim[guided_pp_sim >= guided_sim] = -torch.inf
```

which also performed equivalently.

For each sample, the model learns with:
And then for anchor
Your PR will remove the first one because we now do
My guess is that didn't actually end up affecting the model training because e.g.
In short, my suggestion is to also add
My test models:
I didn't complete my run with
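For anyone following along, here is a minimal toy sketch (not the actual loss code; the numbers are made up) of why the `>=` comparison on `ap_sim` is degenerate: `guided_sim` is the diagonal of `guided_ap_sim`, so the condition is always true for each anchor's own positive.

```python
import torch

# Toy illustration with made-up numbers, mirroring the masking logic discussed above.
guided_ap_sim = torch.tensor([[0.9, 0.3],
                              [0.2, 0.8]])
# guided_sim is the diagonal of guided_ap_sim: each anchor's guide similarity to its own positive.
guided_sim = guided_ap_sim.diagonal().view(-1, 1)

ap_sim = torch.tensor([[0.7, 0.1],
                       [0.0, 0.6]])
ap_sim[guided_ap_sim >= guided_sim] = -torch.inf
print(ap_sim)
# tensor([[  -inf, 0.1000],
#         [0.0000,   -inf]])
# With `>=`, the diagonal (each anchor's true positive) is always masked out,
# since guided_ap_sim[i, i] == guided_sim[i] by construction.
```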
|
I agree with you, @tomaarsen. The code is now

```python
if len(concatenated_reps) > 2:
    for i in range(2, len(concatenated_reps)):  # Start from 2 since first 2 are anchor-positive
        guided_neg_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[i])
        neg_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[i])
        neg_sim[guided_neg_sim > guided_sim] = -torch.inf
        scores = torch.cat([scores, neg_sim], dim=1)
```

But I thought it should be reformatted to

```python
if len(concatenated_reps) > 2:
    for i in range(2, len(concatenated_reps)):  # Start from 2 since first 2 are anchor-positive
        guided_neg_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[i])
        neg_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[i])
        neg_sim[guided_neg_sim >= guided_sim] = -torch.inf  # equal sign added
        scores = torch.cat([scores, neg_sim], dim=1)
```

Since the hard negatives are mined from positives, I think this might reduce the duplicates effectively.
|
Could you perhaps add `>=` for the 3 `aa_sim` cases as well?
- Tom Aarsen
Sure @tomaarsen, should I add
|
Am I able to revise this code? |
Oh, apologies, I missed your reply #3074 (comment) when I added my Changes Requested message; it was intended for @JINO-ROHIT. I think only JINO and myself can make changes to this PR, so I'll leave it to JINO.
As for your negatives proposal - I think that the
But you're right that we're missing a case here: if
P.S. this is easier to figure out first for GISTEmbedLoss & then propagate it to CachedGISTEmbedLoss; the former is a bit easier to understand.
|
Actually, because
sentence-transformers/sentence_transformers/losses/GISTEmbedLoss.py, Lines 157 to 163 in efbf3ee
So maybe
Apologies for the confusion. I'm training a new model with a |
Here are my results:
Sadly the baseline gets stronger performance (0.4104 vs 0.4066 NDCG@10, higher is better). Also on an unseen part of gooaq the baseline is better:
(Granted, this is n=1, so it's hard to say that the baseline is always better here.)
I'm not sure about the intuition of why the 3
Perhaps because the
|
We might as well let this PR be then, until we can zero in with more experiments. |
Hmm.. interesting. |
If I may offer my cautious opinion, wouldn't the issues you mentioned above be resolved by this code: `batch_sampler=BatchSamplers.NO_DUPLICATES`? |
Well spotted! Yes, it would be. If you use that batch sampler, then this won't be an issue to begin with. I mention this briefly in the comments here: #2756 (comment)

```python
# batch_sampler=BatchSamplers.NO_DUPLICATES,  # Although the loss benefits from having no duplicate samples in a batch
# we want to specifically test with duplicate samples as those should start being ignored.
```

I figured that it may be possible to avoid the duplicate samples always, but it's seemingly not as simple as I thought.
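For reference, a minimal sketch of opting into that sampler via the training arguments (the output directory and batch size here are just placeholders):

```python
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

args = SentenceTransformerTrainingArguments(
    output_dir="models/gist-dedup-test",        # placeholder path
    per_device_train_batch_size=64,
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # no duplicate samples within a batch
)
```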
|
Fixes #2756
This is an enhancement under discussion for the duplicate positives in GISTEmbedLoss and CachedGISTEmbedLoss; all wandb results are logged here: https://wandb.ai/jinooo/sentence-transformers?nw=nwuserjinooo
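For context, a minimal sketch of how the two affected losses are constructed (the model names are placeholders, not the ones used in the experiments):

```python
from sentence_transformers import SentenceTransformer, losses

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # placeholder model
guide = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # placeholder guide model

loss = losses.GISTEmbedLoss(model, guide)
# Gradient-cached variant for larger effective batch sizes:
cached_loss = losses.CachedGISTEmbedLoss(model, guide, mini_batch_size=32)
```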
@tomaarsen