Linear Cross Entropy #2507
Conversation
There are a couple of things that prevent this PR from moving out of the draft stage.
Hey @Andrei-Aksionov , thanks for the PR! Impressive numbers. Do you have any intuition for why tokens per second would double? From their paper, torch.compile is actually the fastest. I don't know if, when they compared against torchtune in their paper, they used compile=True. But in any case, I wouldn't have imagined that changing the output_layer + CE would have such a TPS impact (though my intuition here may be wrong: since it's a 2B model + LoRA, maybe computing the output_layer for a large vocab is the most expensive thing). https://arxiv.org/pdf/2411.09009
If you have bandwidth, could you try Llama 8B for a few steps, CCE + compile vs Cut CE + compile?
Hello @felipemello1
Well, I guess that's because it's Gemma 2 with only 2b parameters, but the vocab size is the same as for the larger models in the Gemma 2 family, so applying the classification head and calculating the loss have a significant effect on training speed. Plus, it's a LoRA variant, which makes it even more pronounced.
As I understand it, with regular cross-entropy (chunked or not), the logits have to be materialized (all at once or chunk by chunk), upcast to fp32, and only then is the loss computed. As you can see, that's a lot of copying operations compared to the number of floating point operations (see the sketch below).
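For illustration, a minimal sketch of that chunked pattern in plain PyTorch (shapes and chunking are simplified; this is not torchtune's actual CEWithChunkedOutputLoss):
import torch
import torch.nn.functional as F

def chunked_ce(hidden, output_weight, targets, num_chunks=8):
    # hidden: (N, hidden_dim), output_weight: (vocab_size, hidden_dim), targets: (N,)
    total = hidden.new_zeros((), dtype=torch.float32)
    for h, t in zip(hidden.chunk(num_chunks), targets.chunk(num_chunks)):
        # each chunk of logits is materialized and upcast to fp32 before computing the loss
        logits = (h @ output_weight.T).float()
        total = total + F.cross_entropy(logits, t, reduction="sum")
    return total / targets.numel()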
LCE, on the other hand, fuses the output projection with the loss, so the proportion of copying to actual computation is much better. The most interesting part to me is the torch_compile implementation: it could be a better option, since it requires no additional dependencies and no custom Triton kernels.
I initially thought it might be unfair to compare a regular, non-compiled model with LCE. To address this, I included a compiled version in the comparison as well.
I'll try to.
Hey @felipemello1, I've rerun the tests without compilation for both implementations of LCE (for completeness).
As you can see, even without compilation, it's still ~2x faster than non-compiled CCE. I wasn't able to run Llama 3.1 8b, since I don't have access to the repo. Tried to request access a couple of times and now I'm blocked foreeeeeeveeer.
Here the difference is less pronounced, since the vocab size no longer dwarfs everything else and so doesn't have such a significant impact. Again, what's interesting is the torch_compile implementation.
With larger and larger models the memory savings will become negligible, but the improvement in training speed is a nice thing. What do you think, @felipemello1?
Hey @Andrei-Aksionov , this is amazing. Thanks for all the insightful experiments and comments.
I think that, as a rule of thumb, we would be way more comfortable adding a torch-only implementation to torchtune than taking a dependency on a new repo. I don't know if their license allows copying it, but Horace implemented an early version of a PyTorch-only CCE and shared it on Twitter. Maybe we could use that?
We are a bit swamped with some other tasks right now, but I don't want to keep you waiting. Can I get back to you next week after I meet with the team and check if/how we want to add this loss?
The way you did it works, but a) it adds some extra logic to the recipe (which is already a bit bloated), and b) it modifies the model. So let me check what others think and we can get back to it. Does that sound like a plan?
Also, food for thought: many other losses in torchtune ended up implementing the chunked version (e.g. knowledge distillation and GRPO). Although the fused version is more efficient, I wonder how easy it would be for someone to repurpose it. Maybe we would need to keep the current chunked version around.
That was my first idea, but I decided not to go that route since it's training-regime specific, so it wasn't clear to me why the architecture of the model should care about it.
Yes, it does. 🙂
loss_class_name = (
    self._loss_fn._orig_mod.__class__.__name__
    if hasattr(self._loss_fn, "_orig_mod")
    else self._loss_fn.__class__.__name__
)
Noob question: what's the issue with always using self._loss_fn.__class__.__name__?
from omegaconf import OmegaConf

from torchtune import config

# Build a config with two different loss components
short_dict = {
    "loss1": {
        "_component_": "torchtune.modules.loss.CEWithChunkedOutputLoss",
    },
    "loss2": {
        "_component_": "cut_cross_entropy.LinearCrossEntropy",
    },
}

# Convert the dictionary to an OmegaConf DictConfig object
cfg = OmegaConf.create(short_dict)

loss1 = config.instantiate(cfg.loss1)
loss2 = config.instantiate(cfg.loss2)

print(loss1.__class__.__name__)
print(loss2.__class__.__name__)
The above toy example prints
CEWithChunkedOutputLoss
LinearCrossEntropy
Hello, @nathan-az
Unfortunately, I didn't fully understand the question 🙃.
Are you asking why we need to know the name of the loss function in the first place?
If so, then it's because different loss functions require specific changes to the model.
If not, then a clarification would be nice 😊.
I guess the question is why we need "_orig_mod". Is that right? I have no clue.
Btw, we could have a nicer abstraction for this. Maybe have the losses follow a protocol, and check something like: if hasattr(loss, "module_to_compile"), compile(loss.module_to_compile), else compile(loss) (see the sketch below).
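A rough sketch of that dispatch (module_to_compile is a hypothetical attribute, just to illustrate the protocol idea):
# in the recipe setup: compile only the heavy part if the loss exposes one
if hasattr(self._loss_fn, "module_to_compile"):
    self._loss_fn.module_to_compile = torch.compile(self._loss_fn.module_to_compile)
else:
    self._loss_fn = torch.compile(self._loss_fn)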
_orig_mod is needed because if I compile the loss with torch.compile, then the class will be wrapped into OptimizedModule (for LinearCrossEntropy), so I need to go one level deeper to retrieve the proper name.
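A quick demonstration of that wrapping in plain PyTorch (nothing torchtune-specific):
import torch
import torch.nn as nn

loss = nn.CrossEntropyLoss()
compiled = torch.compile(loss)

print(type(compiled).__name__)                # OptimizedModule: the wrapper, not the loss
print(compiled._orig_mod.__class__.__name__)  # CrossEntropyLoss: the original module kept on _orig_mod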
There are many ways to make this PR nicer.
The whole logic could be wrapped in a new loss class, so that in the training recipe only something like this would be added:
if loss_class_name == "LinearCrossEntropy":
    self.loss_fn.prepare_model(model)
and the rest, including the loss_step function, would stay the same.
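A sketch of what such a wrapper class could look like (prepare_model and the functional linear_cross_entropy call are my assumptions, not code from this PR; check the library's actual API):
import torch.nn as nn
from cut_cross_entropy import linear_cross_entropy  # assumed functional entry point

class LCELoss(nn.Module):
    # owns the model's output projection so the recipe never materializes logits
    def prepare_model(self, model: nn.Module) -> None:
        # take over the classification head; the model's forward then returns embeddings
        self.linear = model.output  # assumes a TransformerDecoder-style `output` attribute
        model.output = nn.Identity()

    def forward(self, embeddings, targets):
        return linear_cross_entropy(embeddings, self.linear.weight, targets)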
Or, if the core team decides to keep a fused version of the loss (like the torch_compile impl variant), then we could have a FusedLoss class that wraps any loss function and does the trick in its forward call; the yaml file would then have a fused argument:
loss:
  _component_: ...
  fused: True
I just need to know what the decision is.
Thanks for clarifying @felipemello1 and @Andrei-Aksionov - that does clarify and answer my question.
Agreed that a nicer abstraction is desired here. The manual handling of the output layer, logits and losses is an unfortunate side-effect too. I wonder if there's a nicer pattern for handling that.
FYI: I pulled your branch to test whether the gains were significant with larger models, and made the same adjustments as this PR to the distributed recipe, but hit an error. I don't know how to look into this further, but details are below.
Docker:
Hardware: 8x Nvidia H100 80GB
Error (rank0):
Config:
batch_size: 1
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: '00004'
  model_type: LLAMA3
  output_dir: ${output_dir}
  recipe_checkpoint: null
  checkpoint_dir: inputs/model
clip_grad_norm: null
compile: true
cudnn_deterministic_mode: false
custom_sharded_layers: []
dataset:
  _component_: torchtune.datasets.chat_dataset
  source: parquet
  conversation_column: messages
  conversation_style: openai
  split: train
  packed: true
  train_on_input: true
  data_dir: inputs/dataset
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: true
epochs: 1
fsdp_cpu_offload: false
gradient_accumulation_steps: 1
log_every_n_steps: 1
log_peak_memory_stats: true
loss:
  _component_: cut_cross_entropy.LinearCrossEntropy
max_steps_per_epoch: 20
metric_logger:
  _component_: torchtune.training.metric_logging.MLFlowLogger
optimizer:
  _component_: torchao.prototype.low_bit_optim.AdamW8bit
  lr: 5.0e-05
optimizer_in_bwd: false
output_dir: outputs
resume_from_checkpoint: false
seed: 100
shuffle: true
tokenizer:
  max_seq_len: 2048
  path: inputs/model/original/tokenizer.model
  _component_: torchtune.models.llama3.llama3_tokenizer
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b

This may be a side-effect of being in a multi-device setting that I don't understand. I recall that ...
Hey @nathan-az
From the error log I can see that the issue is in the custom Triton kernel. Could you change the config to the following:
loss:
  _component_: cut_cross_entropy.LinearCrossEntropy
  impl: torch_compile
and try it one more time? With this implementation, a special function that contains only PyTorch code wrapped in torch.compile is used instead of the custom Triton kernels.
@Andrei-Aksionov I was able to make a pretty minimal repro, so I created apple/ml-cross-entropy#31. I also tried compiling the loss, but there was no change. Maybe it'll be obvious to the maintainers what's going wrong! The above even fails with ... I even wonder if this is a PyTorch core issue, caused by some funky interaction between FSDP and Triton. EDIT: In the repro, this is "fixed" by calling ...
@nathan-az , could you try setting reshard_after_forward=False? If you could create a simple repro with just the loss and some dummy tensor, that would make it easier to ping someone from torch distributed. Thanks for investigating!
Thanks @nathan-az for investigating this 🎉
loss:
  _component_: cut_cross_entropy.LinearCrossEntropy
  impl: torch_compile
is basically
loss = LinearCrossEntropy(impl="torch_compile")
and not
loss = torch.compile(LinearCrossEntropy)
The difference is that the default implementation uses custom Triton kernels, while the torch_compile one uses only PyTorch code compiled with torch.compile. I believe we are more interested in getting this implementation to work, since it is more aligned with the goal of the repo to provide PyTorch-native code. In addition, it will make it easier to swap the loss function if needed.
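In spirit, the torch_compile implementation boils down to something like the function below (a rough analogue of the idea, not the library's actual code):
import torch
import torch.nn.functional as F

@torch.compile
def linear_ce(embeddings, classifier_weight, targets):
    # the output projection and the loss are compiled together as one graph
    logits = embeddings @ classifier_weight.T
    return F.cross_entropy(logits.float(), targets)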
@Andrei-Aksionov , thanks again for all of the work on this! I am trying to think about the requirements to land it, but I think it will need some refactoring on the torchtune side. Requirements:
Let me refactor some things today; that should make it easier for you to land your PR. Do you mind working on items 0/3, i.e. making sure that we don't add new dependencies, if possible, and that it works with distributed?
@Andrei-Aksionov using the torch_compile impl didn't resolve the error. @felipemello1 neither did disabling reshard, although that was a good suggestion that I expected to work. However, I believe it would cause the entire model to accumulate in memory during the forward pass (correct me if I'm wrong on how that arg works), which seems counter to the benefit of the merged kernel.
On the license note: the Liger Kernel repo also has a linear cross entropy implementation. Their repo is BSD-2, so very permissive. I'm not sure if there are any performance differences. From what I've seen on the issue I posted, the Apple CCE maintainers are very responsive and super helpful, but the Liger implementation could be a backup.
Best of luck getting this going! Apologies for raising the hassle in the distributed case 😅
Hey @Andrei-Aksionov @nathan-az , I proposed a refactor here: #2531
@Andrei-Aksionov , this should help with trying Liger or Cut CE. I wonder if, with this setup, the user can directly instantiate the loss and use it, or if we need some sort of LossAdapter to avoid changing the recipe, all_gather the weights, and redirect the inputs (e.g. (weight, outputs, label) --> (output, weight, label)). In the config it could be something like:
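(A purely illustrative shape; LossAdapter and its arguments are hypothetical names, not existing torchtune components.)
loss:
  _component_: torchtune.modules.loss.LossAdapter
  loss:
    _component_: cut_cross_entropy.LinearCrossEntropy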
But I am afraid that this may be overkill, and I would rather not go down that path. I would prefer to have a torchtune version and/or work with the compile team to match the CE perf by doing torch.compile(loss).
Btw: someone used Liger with tune and wrote a blog post: https://pytorch.org/blog/peak-performance-minimized-memory/
Hi there 👋
This PR adds support for the Linear Cross Entropy loss (a.k.a. Cut Cross Entropy) proposed by Apple in the paper "Cut Your Losses in Large-Vocabulary Language Models", by integrating the official implementation.
So, from now on, we have:
- CCE: chunked cross entropy
- LCE: linear cross entropy
The benefit of LCE is that it reduces VRAM usage, especially for models with large vocab sizes.
This is done with 3 things:
1. Custom kernels.
LCE uses custom CUDA kernels to perform the matrix multiplications and the log-sum-exp reduction in shared memory.
Usually you first need to calculate the logits by multiplying the embeddings (B, T, hidden_dim) with the output layer (hidden_dim, vocab_size), and then calculate the loss from these logits and the target labels.
LCE doesn't materialize this logits matrix in global memory (it can be huge for modern LLMs with large vocab sizes); instead, it multiplies only a part of the output layer weights with the embeddings at a time in shared (fast) memory and immediately computes the loss contribution there (the flash attention playbook).
This avoids materializing the logits matrix (saves memory) and saves memory bandwidth (no need to move data between global and shared memory multiple times). A PyTorch-level sketch of this streaming idea follows this list.
2. Gradient filtering.
LCE leverages the sparsity of the softmax gradient to skip the parts of the gradient computation that have a negligible contribution, which improves throughput by avoiding unnecessary work.
3. Vocabulary sorting.
LCE sorts the vocabulary to group tokens with similar average logits together, which increases block-level sparsity and improves the effectiveness of gradient filtering.
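A PyTorch-level sketch of the streaming log-sum-exp idea from point 1 (the real implementation does this inside CUDA kernels in shared memory; this sketch only shows why the full logits matrix never needs to exist):
import torch

def streaming_linear_ce(emb, weight, targets, block=8192):
    # emb: (N, hidden_dim), weight: (vocab_size, hidden_dim), targets: (N,)
    # per-token loss: logsumexp over the vocab minus the logit of the target token
    lse = torch.full((emb.shape[0],), float("-inf"), device=emb.device)
    for start in range(0, weight.shape[0], block):
        # only an (N, block) slice of the logits exists at any point in time
        logits = (emb @ weight[start:start + block].T).float()
        lse = torch.logaddexp(lse, torch.logsumexp(logits, dim=-1))
    target_logit = (emb * weight[targets]).sum(dim=-1).float()
    return (lse - target_logit).mean()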
I ran a quick single-device LoRA recipe with the gemma 2 2b model. This is the best possible case for LCE, since the model is relatively small but has the same vocab size as the larger models in the Gemma 2 family.
I ran 4 experiments:
- CCE (non-compiled): chunked cross-entropy without compilation of the model and the loss
- CCE: chunked cross-entropy with compilation
- LCE: linear cross entropy with compilation
- LCE (torch_compile impl): an implementation for older GPUs and MPS devices
Note: to do this, one needs to change the loss section of the config.
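For example (impl is only needed for the PyTorch-only variant):
loss:
  _component_: cut_cross_entropy.LinearCrossEntropy
  # for older GPUs and MPS devices:
  # impl: torch_compile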
As one can see, the loss chart is identical in all cases, yet LCE significantly reduced reserved memory size and reduced time for training.