[KD] add uld and jsd #2253

Open · wants to merge 1 commit into kd-trainer

Conversation

@kashif commented on Jan 11, 2025

Description

Add the Universal Logit Distillation (ULD) loss and the Jensen-Shannon divergence (JSD) KD loss.

Paper

https://arxiv.org/abs/2402.12030
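
The linked paper introduces the ULD loss: a Wasserstein-1 distance between the rank-sorted student and teacher probability distributions, which works even when the two models use different tokenizers. The JSD loss, by contrast, assumes a shared vocabulary. For context, here is a minimal sketch of a generalized JSD KD loss; the function name, beta parameter, and reduction are illustrative assumptions, not the PR's exact implementation.

import math
import torch
import torch.nn.functional as F

def jsd_kd_loss(student_logits, teacher_logits, beta=0.5):
    # Generalized Jensen-Shannon divergence between the student and teacher
    # next-token distributions; beta=0.5 recovers the standard symmetric JSD.
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
    # log of the mixture m = beta * p_student + (1 - beta) * p_teacher,
    # computed stably in log space via logsumexp
    mixture_log_probs = torch.logsumexp(
        torch.stack([
            student_log_probs + math.log(beta),
            teacher_log_probs + math.log(1.0 - beta),
        ]),
        dim=0,
    )
    # KL(student || m) and KL(teacher || m); log_target=True because both
    # arguments are log-probabilities
    kl_student = F.kl_div(mixture_log_probs, student_log_probs,
                          log_target=True, reduction="batchmean")
    kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs,
                          log_target=True, reduction="batchmean")
    return beta * kl_student + (1.0 - beta) * kl_teacher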

@winglian force-pushed the kd-trainer branch 2 times, most recently from 2dcbc0d to 35a84f2 on January 15, 2025 04:30
Comment on lines +60 to +70
# Get masked student probabilities
student_probs_masked = student_probs[valid_mask]

# Get masked teacher probabilities
teacher_probs_masked = teacher_probs[valid_mask]

# Sort student probabilities in descending order
student_probs_sorted, _ = torch.sort(student_probs_masked, dim=-1, descending=True)

# For teacher probs, we already have top-K, so just ensure they're sorted
teacher_probs_sorted, _ = torch.sort(teacher_probs_masked, dim=-1, descending=True)
Collaborator:

@kashif since the token_ids don't match, would it be better to sort, then just take the top_k of the student distribution?
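
A minimal sketch of that suggestion (hypothetical code, assuming teacher_probs_masked keeps only the teacher's top-K entries, so K can be read off its last dimension):

# Sort the full student distribution, then keep only its K largest entries so
# the student and teacher tensors line up by rank rather than by token_id.
K = teacher_probs_masked.size(-1)
student_probs_sorted, _ = torch.sort(student_probs_masked, dim=-1, descending=True)
student_probs_topk = student_probs_sorted[..., :K]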

kashif (Author):

So let me check this against the definition of the Wasserstein loss.
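
For reference, the usual 1-D optimal-transport identity (stated here as background, not quoted from the paper): once both distributions are sorted in the same order and brought to a common length K, the Wasserstein-1 term reduces to an elementwise L1 sum,

$$\mathcal{L}_{\text{ULD}} = \sum_{i=1}^{K} \left| p^{S}_{(i)} - p^{T}_{(i)} \right|$$

where $p_{(i)}$ denotes the $i$-th largest probability, which is what the two sorted tensors in the diff above would feed into.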
