Linear Cross Entropy #2507

Draft: wants to merge 1 commit into main

Conversation

Andrei-Aksionov

Hi there 👋

This PR adds support for Linear Cross Entropy loss (a.k.a. Cut Cross Entropy), proposed by Apple in the paper "Cut Your Losses in Large-Vocabulary Language Models", by integrating the official implementation.

So, from now on, we have:

  • CCE: chunked cross entropy
  • LCE: linear cross entropy

The benefit of LCE is that it reduces VRAM usage, especially for models with large vocab sizes.

This is achieved with three techniques:

  1. Custom Kernels.
    LCE uses custom Triton kernels to perform the matrix multiplications and the log-sum-exp reduction in shared memory.
    Normally you first compute logits by multiplying the embeddings (B, T, hidden_dim) with the output layer (hidden_dim, vocab_size), and then compute the loss from these logits and the target labels.
    LCE never materializes this logits matrix in global memory (it can be huge for modern LLMs with large vocab sizes); instead it multiplies one block of the output-layer weights with one block of the embeddings at a time in shared (fast) memory and computes the loss contribution right there (the flash attention playbook). A rough memory estimate is sketched right after this list.
    This avoids materializing the logits matrix (saving memory) and saves memory bandwidth (no need to move data between global and shared memory multiple times).

  2. Gradient filtering.
    LCE leverages the sparsity of the softmax gradient to skip elements of the gradient computation that have a negligible contribution. This improves the throughput of LCE by reducing unnecessary computation.

  3. Vocabulary sorting.
    LCE uses vocabulary sorting to group tokens with similar average logits together. This increases the block-level sparsity and improves the effectiveness of gradient filtering.
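
To make the "huge logits matrix" point in item 1 concrete, here is a rough back-of-the-envelope sketch. The shapes are assumptions for a Gemma-2-2B-like setup (hidden_dim 2304, vocab 256k), not measurements:

batch, seq_len, hidden_dim, vocab_size = 4, 2048, 2304, 256_000  # assumed shapes
bytes_per_elem = 2  # bf16

logits_bytes = batch * seq_len * vocab_size * bytes_per_elem
hidden_bytes = batch * seq_len * hidden_dim * bytes_per_elem

print(f"logits:        {logits_bytes / 2**30:.2f} GiB")   # ~3.91 GiB, materialized just for the loss
print(f"hidden states: {hidden_bytes / 2**30:.3f} GiB")   # ~0.035 GiB

The logits tensor is two orders of magnitude larger than the hidden states it is computed from, and that is exactly the allocation LCE avoids.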


I ran a quick single-device LoRA recipe with the Gemma 2 2B model.
This is the best possible case for LCE, since the model is relatively small but has the same vocab size as the larger models in the Gemma 2 family.

I ran 4 experiments:

  • CCE (non-compiled): chunked cross-entropy without compilation of model and loss
  • CCE: chunked cross-entropy with compilation
  • LCE: linear cross entropy with compilation
  • LCE (torch_compile impl): an implementation for older GPUs and MPS devices

Note

To enable this, one needs to change the loss section of a config:

loss:
    _component_: cut_cross_entropy.LinearCrossEntropy
    impl: torch_compile # for GPUs older than Ampere, or for MPS devices
Name                      Peak Memory Reserved (GB)  Time
------------------------  -------------------------  -----
CCE (non-compiled)        21.6                       10:15
CCE                       12.3                       09:18
LCE                       7.73                       04:51
LCE (torch_compile impl)  9.1                        04:43
[Screenshot 2025-03-17: loss curves for all four runs]

As one can see, the loss curves are identical in all cases, yet LCE significantly reduces reserved memory and training time.


Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Mar 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2507

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot

Hi @Andrei-Aksionov!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@Andrei-Aksionov
Author

There are a couple of things that prevent this PR from moving out of the draft stage.

  1. What to do with docs? Do I need to create a wrapper class for LinearCrossEntropy with a docstring, so this info is automatically populated in the docs?
  2. Due to a lack of compute resources (and no access to multi-GPU machines), I tested only the single-device LoRA recipe.
    I haven't found any mention that LCE doesn't work in multi-GPU setups, yet I cannot confirm that.
    What should I do in this case? And what about the other recipes?

@facebook-github-bot

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 17, 2025
@felipemello1
Contributor

felipemello1 commented Mar 17, 2025

hey @Andrei-Aksionov, thanks for the PR! Impressive numbers. Do you have any intuition for why tokens per second would double? From their paper, torch.compile is actually the fastest. I don't know whether, when they compared against torchtune in their paper, they used compile=True.

But in any case, I wouldn't have imagined that changing the output_layer + CE would have such a TPS impact (though my intuition here may be wrong: since it's a 2B model + LoRA, maybe computing the output_layer for a large vocab is the most expensive thing).

https://arxiv.org/pdf/2411.09009
[Figure from the paper]

If you have bandwidth, could you try with llama 8b for a few steps CCE+compile vs cutloss compile?

@Andrei-Aksionov
Author

Hello @felipemello1

Do you have any intuition for why tokens per second would double?

Well, I guess that's because it's Gemma 2 with only 2B parameters, but the vocab size is the same as for the larger models in the Gemma 2 family, so applying the classification head and calculating the loss has a significant effect on training speed. Plus it's a LoRA variant, which makes the effect even more pronounced.
I would imagine the effect is smaller for bigger models with smaller vocab sizes.

As I understand it, in the case of regular cross-entropy (chunked or not), the process is:

  1. Calculate logits by multiplying embeddings with the output layer. If the vocab size is large, this involves a lot of copying between shared memory (SM) and global memory (HBM).
  2. We need to convert those logits to probabilities, so we need to:
    a) load logits block-by-block from HBM into SM to exponentiate the values, then write them back
    b) calculate the sum of the exponentiated logits
    c) load the exponentiated logits back from HBM into SM and divide them by that sum, converting them to probabilities
  3. Load them again to calculate the loss.

As you can see, that is a lot of copying relative to the number of floating-point operations.

LCE, on the other hand:

  1. Copies a portion of the embeddings and output weights (one that fits into a thread block) from HBM to SM.
  2. Multiplies them to get logits, but keeps the result in SM.
  3. Converts logits into probabilities without needing the sum of all exponentiated values, thanks to the online softmax normalization trick (the one used in flash attention); see the sketch below.
  4. Probably calculates the loss there too, but I'm not sure.
  5. Copies the result back from SM to HBM.

As you can see, the ratio of floating-point operations to bytes copied is higher than for the regular CE calculation.
That's exactly what you want when writing a custom kernel to make everything faster (given that the default path is memory-bound).
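
To make the online softmax point concrete, here is a hedged pure-PyTorch sketch of the same idea (my own illustration, not the library's kernel). Cross-entropy only needs the target logit and a log-sum-exp over the whole vocabulary, and the log-sum-exp can be accumulated chunk by chunk, so the full (N, vocab_size) logits matrix never exists at once:

import torch

def streaming_ce(h, w, targets, chunk=8192):
    # h: (N, D) hidden states, w: (V, D) output-layer weight, targets: (N,) token ids
    running_lse = torch.full((h.shape[0],), float("-inf"), device=h.device)
    target_logit = (h * w[targets]).sum(-1).float()  # logit of the correct token only
    for start in range(0, w.shape[0], chunk):
        block = (h @ w[start:start + chunk].T).float()  # only an (N, chunk) slice exists at a time
        running_lse = torch.logaddexp(running_lse, block.logsumexp(dim=-1))
    # CE = logsumexp(all logits) - logit of the target class
    return (running_lse - target_logit).mean()

The real kernel does this block-wise reduction in shared memory with Triton, but the algorithmic trick is the same one flash attention uses.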


The most interesting part to me is that the torch_compile impl version is actually not far off.
It's just as fast, but consumes a bit more memory (though still less than CCE).
And it's just regular PyTorch code: a single function that does embeddings @ output_layer plus the CE calculation, all wrapped in a torch.compile decorator. No custom Triton kernels.
I would imagine torch.compile can fuse these operations together and, for the softmax, use the online softmax normalization trick, which is already implemented in SDPA anyway. So there is a good chance of that.

That said, it could be the better option, since it requires no additional dependencies and no custom Triton kernels.
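
For reference, a minimal sketch of what I mean; this is a simplified stand-in, not the exact code from the cut_cross_entropy repo:

import torch
import torch.nn.functional as F

@torch.compile
def fused_linear_ce(embeddings, classifier_weight, targets):
    # embeddings: (N, hidden_dim), classifier_weight: (vocab_size, hidden_dim)
    logits = embeddings @ classifier_weight.T         # last-layer matmul
    return F.cross_entropy(logits.float(), targets)   # loss, in the same compiled region

The whole benefit comes from letting the compiler see both operations together.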


I don't know whether, when they compared against torchtune in their paper, they used compile=True.

I initially thought it might be unfair to compare a regular, non-compiled model with LCE. To address this, I included a compiled version in the comparison as well.
For consistency, I also decided to compile the LCE variant. I'm planning to add a non-compiled LCE version for completeness, although I expect the loss function to be compiled automatically anyway (only the model itself will stay non-compiled).

If you have bandwidth, could you try with llama 8b for a few steps CCE+compile vs cutloss compile?

I'll try to.

@Andrei-Aksionov
Author

Hey @felipemello1

I've rerun the tests without compilation for both LCE implementations (for completeness).
(I forgot to mention that all of this was done on a single L4.)

Loss Name                 Peak Memory Reserved (GB)  Time   Compiled
------------------------  -------------------------  -----  --------
CCE                       21.6                       10:15  False
CCE                       12.3                       09:18  True
LCE                       8.2                        05:53  False
LCE                       7.73                       04:51  True
LCE (torch_compile impl)  8.7                        05:40  False
LCE (torch_compile impl)  9.1                        04:43  True

As you can see, even without compilation it's still ~2x faster than non-compiled CCE.


I wasn't able to run Llama 3.1 8B, since I don't have access to the repo. I tried to request access a couple of times and now I'm blocked forever.
So I tried Qwen 2.5 7B instead. It has a larger vocab size than Llama (152k vs 128k) and 1B fewer parameters, but at least it's somewhat similar (kinda 🙃).

Loss Name                 Peak Memory Reserved (GB)  Time
------------------------  -------------------------  -----
CCE                       17.5                       05:14
LCE                       16.2                       04:40
LCE (torch_compile impl)  16.9                       04:26
[Screenshot 2025-03-19: loss curves for the Qwen 2.5 7B runs]

Here the difference is less pronounced, since the vocab size no longer dwarfs everything else, so it doesn't have such a significant impact.

Again, what's interesting is the torch_compile impl:

  1. It's as fast as the LCE implemented with custom Triton kernels.
  2. It provides some memory savings, though smaller than the proper Triton variant (since it doesn't support gradient filtering).

With larger and larger models the memory savings become negligible, but the improvement in training speed is a nice thing.
I recommend spending some time researching the impact of a function that, apparently, makes it easier to fuse the last-layer matmul and the loss calculation. Again, look at the code for the torch_compile impl: it's just a function with two operations in it. The benefit is that you could replace F.cross_entropy with any of the custom loss functions from this repo.

What do you think, @felipemello1?
And @ebsmothers, of course 😊

@felipemello1
Contributor

hey @Andrei-Aksionov, this is amazing. Thanks for all the insightful experiments and comments.

I think that, as a rule of thumb, we would be much more comfortable adding a torch-only implementation to torchtune than taking a dependency on a new repo. I don't know if their license allows copying it, but Horace implemented an early version of PyTorch-only CCE and shared it on Twitter. Maybe we could use that?

We are a bit swamped with some other tasks right now, but I don't want to keep you waiting. Can I get back to you next week, after I meet with the team and check if/how we want to add this loss?

The way you did it works, but a) it adds some extra logic to the recipe (which is already a bit bloated), and b) it modifies the model.
What I originally thought about doing was adding a "skip_output_layer" flag to the transformer, but that is also bad, as it adds yet another flag to the model.

So let me check what others think and we can get back to it. Does that sound like a plan?

@felipemello1
Contributor

Also, food for thought: many other losses in torchtune ended up with chunked implementations (e.g. knowledge distillation and GRPO). Although the fused version is more efficient, I wonder how easy it would be for someone to repurpose it. Maybe we would need to keep the current chunked version around.

@Andrei-Aksionov
Author

What I originally thought about doing was adding a "skip_output_layer" flag to the transformer, but that is also bad, as it adds yet another flag to the model.

That was my first idea too, but I decided not to go that route since it's specific to the training regime, and it wasn't clear to me why the model architecture should care about it.


So let me check what others think and we can get back to it. Does that sound like a plan?

Yes, it does. 🙂
Evan mentioned a couple of things that I can take a look at, so I can keep myself busy.

self._loss_fn._orig_mod.__class__.__name__
if hasattr(self._loss_fn, "_orig_mod")
else self._loss_fn.__class__.__name__
)
Contributor

Noob question: what's the issue with always using self._loss_fn.__class__.__name__?

from torchtune import config

# Define a config with two different loss components
short_dict = {
    "loss1":
        {
            "_component_": "torchtune.modules.loss.CEWithChunkedOutputLoss",
        },
    "loss2":
        {
            "_component_": "cut_cross_entropy.LinearCrossEntropy",
        }
}

# Convert the dictionary to an OmegaConf DictConfig object
from omegaconf import OmegaConf

cfg = OmegaConf.create(short_dict)
loss1 = config.instantiate(cfg.loss1)
loss2 = config.instantiate(cfg.loss2)

print(loss1.__class__.__name__)
print(loss2.__class__.__name__)

The above toy example prints

CEWithChunkedOutputLoss
LinearCrossEntropy

Author

Hello, @nathan-az
Unfortunately, I didn't fully understand the question 🙃.
Are you asking why we need to know the name of the loss function in the first place?
If so, we need it because different loss functions require specific changes to the model.
If not, a clarification would be nice 😊.

Contributor

I guess the question is why we need "_orig_mod". Is that right? I have no clue.

Btw, we could have a nicer abstraction for this. Maybe have the losses follow a protocol, and then check something like: if hasattr(loss, "module_to_compile"), compile(loss.module_to_compile), else compile(loss).

Author

_orig_mod is needed because if I compile the loss with torch.compile, the class gets wrapped into OptimizedModule (for LinearCrossEntropy), so I need to go one level deeper to retrieve the proper name.
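
A minimal demo of what I mean, using nn.CrossEntropyLoss as a stand-in for LinearCrossEntropy:

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
compiled = torch.compile(loss_fn)

print(type(compiled).__name__)                  # OptimizedModule, not the original class name
print(compiled._orig_mod.__class__.__name__)    # CrossEntropyLoss, recovered via _orig_mod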

There are many ways to make this PR nicer.
The whole logic could be wrapped in a new loss class, and in the training recipe only something like this would be added:

if loss_class_name == "LinearCrossEntropy":
    self.loss_fn.prepare_model(model)

and the rest would stay the same, including the loss_step function.

Or, if the core team decides to keep a fused version of the loss (like the torch_compile impl variant), we could have a FusedLoss class that wraps any loss function and does the trick in its forward call; the YAML file would then have a fused argument:

loss:
    _component_: ...
    fused: True
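
A rough sketch of what such a FusedLoss class could look like; the name, signature, and compile placement are my assumptions, not an agreed design:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedLoss(nn.Module):
    # Hypothetical wrapper: receives hidden states and the output weight instead of logits,
    # so the final projection and the wrapped loss can be fused by torch.compile.
    def __init__(self, loss_fn=F.cross_entropy, fused=True):
        super().__init__()
        self.loss_fn = loss_fn
        self._forward = torch.compile(self._project_and_loss) if fused else self._project_and_loss

    def _project_and_loss(self, hidden, output_weight, targets):
        logits = hidden @ output_weight.T
        return self.loss_fn(logits.float(), targets)

    def forward(self, hidden, output_weight, targets):
        return self._forward(hidden, output_weight, targets)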

I just need to know what the decision is.

Contributor

Thanks for clarifying @felipemello1 and @Andrei-Aksionov, that answers my question.

Agreed that a nicer abstraction is desired here. The manual handling of the output layer, logits and losses is an unfortunate side-effect too. I wonder if there's a nicer pattern for handling that.

@nathan-az
Contributor

nathan-az commented Mar 26, 2025

FYI: I pulled your branch to test whether the gains were significant with larger models, and made the same adjustments as this PR to the full_finetune_distributed recipe, since I currently only have access to multi-GPU nodes. I consistently got illegal memory access errors. This PR only proposes changes to the single-device recipe, but I think it's worth flagging that there appear to be issues in the multi-device case.

I don't know how to look into this further, but details are below.

Docker: ghcr.io/pytorch/pytorch-nightly:2.7.0.dev20250221-cuda12.4-cudnn9-devel (I have found this version otherwise stable)

Hardware: 8x Nvidia H100 80GB

Error (rank0)

[rank0]:   File "/app/cce_test.py", line 984, in recipe_main
[rank0]:     recipe.train()
[rank0]:   File "/app/cce_test.py", line 835, in train
[rank0]:     current_loss = self._loss_fn(logits, self._output_weights, labels)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/cut_cross_entropy/linear_cross_entropy.py", line 126, in forward
[rank0]:     return linear_cross_entropy(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/cut_cross_entropy/linear_cross_entropy.py", line 71, in linear_cross_entropy
[rank0]:     return cce_linear_cross_entropy(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/cut_cross_entropy/cce.py", line 196, in cce_linear_cross_entropy
[rank0]:     return linear_cross_entropy_apply(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/cut_cross_entropy/cce.py", line 146, in linear_cross_entropy_apply
[rank0]:     loss = LinearCrossEntropyFunction.apply(e, c, bias, params)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/autograd/function.py", line 575, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/cut_cross_entropy/cce.py", line 67, in forward
[rank0]:     neg_dot = indexed_neg_dot_forward_kernel(
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/cut_cross_entropy/indexed_dot.py", line 128, in indexed_neg_dot_forward_kernel
[rank0]:     _indexed_neg_dot_forward_kernel[grid](
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/triton/runtime/jit.py", line 330, in <lambda>
[rank0]:     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
[rank0]:                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 385, in run
[rank0]:     return self.fn.run(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 385, in run
[rank0]:     return self.fn.run(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/triton/runtime/jit.py", line 653, in run
[rank0]:     kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
[rank0]:     ^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/triton/compiler/compiler.py", line 395, in __getattribute__
[rank0]:     self._init_handles()
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/triton/compiler/compiler.py", line 390, in _init_handles
[rank0]:     self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
[rank0]:                                                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered

Config

batch_size: 1
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: '00004'
  model_type: LLAMA3
  output_dir: ${output_dir}
  recipe_checkpoint: null
  checkpoint_dir: inputs/model
clip_grad_norm: null
compile: true
cudnn_deterministic_mode: false
custom_sharded_layers: []
dataset:
  _component_: torchtune.datasets.chat_dataset
  source: parquet
  conversation_column: messages
  conversation_style: openai
  split: train
  packed: true
  train_on_input: true
  data_dir: inputs/dataset
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: true
epochs: 1
fsdp_cpu_offload: false
gradient_accumulation_steps: 1
log_every_n_steps: 1
log_peak_memory_stats: true
loss:
  _component_: cut_cross_entropy.LinearCrossEntropy
max_steps_per_epoch: 20
metric_logger:
  _component_: torchtune.training.metric_logging.MLFlowLogger
optimizer:
  _component_: torchao.prototype.low_bit_optim.AdamW8bit
  lr: 5.0e-05
optimizer_in_bwd: false
output_dir: outputs
resume_from_checkpoint: false
seed: 100
shuffle: true
tokenizer:
  max_seq_len: 2048
  path: inputs/model/original/tokenizer.model
  _component_: torchtune.models.llama3.llama3_tokenizer
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b

This may be a side-effect of the multi-device setting that I don't understand. I recall that unsloth has had memory-related issues in DDP/FSDP multi-GPU settings which I believe are related to Triton (although it puzzles me why that would be the case, given the data-parallel part).

@Andrei-Aksionov
Author

Hey @nathan-az
Thanks for checking it.
In the PR I only proposed changes to the single-device recipe, since I only have access to a single device.
So it's great to know what the implications are in a multi-device setting.

From the error log I can see that the issue is in the indexed_neg_dot_forward_kernel function, which is a part of a custom Triton kernel.

Could you change the config to the following

loss:
    _component_: cut_cross_entropy.LinearCrossEntropy
    impl: torch_compile

and try it one more time?

With this implementation, a plain PyTorch function wrapped in torch.compile is used instead of the Triton kernels, so it has a better chance of working properly.

@nathan-az
Contributor

nathan-az commented Mar 26, 2025

@Andrei-Aksionov I was able to make a pretty minimal repro, so I created apple/ml-cross-entropy#31. I also tried compiling the loss, but there was no change. Maybe it'll be obvious to the maintainers what's going wrong!

The above even fails with --nproc_per_node=1 (so either my repro is very wrong, or something funky happens with FSDP). This also means you can test with a single GPU :)

I even wonder if this is a PyTorch Core issue, caused by some funky interaction between FSDP and triton.

EDIT: In the repro, this is "fixed" by calling full_tensor on the output module, i.e. output.weight.full_tensor(). I'm not sure if this breaks anything or if there is a cleaner pattern (or if the repro is not a good parallel). Would love for an FSDP SME to weigh in 😓
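
For context, the workaround in the repro looks roughly like this. A hedged sketch, assuming fully_shard (FSDP2) has turned the output weight into a DTensor; the helper name is mine, not from the recipe:

import torch.nn as nn
from torch.distributed.tensor import DTensor

def gather_output_weight(output_module: nn.Module):
    # All-gather the sharded weight into a plain torch.Tensor before handing it
    # to the Triton kernel, which does not understand DTensor.
    weight = output_module.weight
    return weight.full_tensor() if isinstance(weight, DTensor) else weight

The downside is that the full vocab_size x hidden_dim weight then lives on every rank for the duration of the loss call, which is part of why an FSDP expert's opinion would be useful here.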

@felipemello1
Contributor

EDIT: In the repro, this is "fixed" by calling full_tensor on the output module, i.e. output.weight.full_tensor()

@nathan-az , could you try setting reshard_after_forward=False? If you could create a simple repro with just the loss and some dummy tensor, that would make it easier to ping someone from torch distributed. Thanks for investigating!

@Andrei-Aksionov
Author

Thanks @nathan-az for investigating this 🎉
The only thing to note is that

loss:
    _component_: cut_cross_entropy.LinearCrossEntropy
    impl: torch_compile

is basically

loss = LinearCrossEntropy(impl="torch_compile")

and not

loss = torch.compile(LinearCrossEntropy)

The difference is that the default implementation uses custom Triton kernels, while the torch_compile implementation just fuses the logits calculation and the loss calculation in a single function wrapped in a torch.compile decorator.

I believe we are more interested in getting this implementation to work, since it is more aligned with the repo's goal of providing PyTorch-native code. In addition, it will make it easier to swap the loss function if needed.

@felipemello1
Contributor

felipemello1 commented Mar 26, 2025

@Andrei-Aksionov, thanks again for all of the work on this! I am trying to think about the requirements to land it, and I think it will need some refactoring on the torchtune side.

Requirements:

  0. We should NOT need to add a new dependency (let's use pure PyTorch if possible and avoid Triton), without violating licenses.
  1. Recipes cannot be bloated with if/else. The complexity has to be mitigated, maybe by checking whether the loss has an attribute, e.g. "is_chunked" or "is_fused", or handled by some utility.
  2. The model definition cannot be changed, e.g. model.output = new_output. This makes it harder to debug, because you read the code and it's one thing, but it's executing another.
  3. It should work with distributed.

Let me refactor some things today; that should make it easier for you to land your PR. Do you mind working on items 0 and 3, i.e. making sure that we don't add new dependencies (if possible) and that it works with distributed?

@nathan-az
Contributor

@Andrei-Aksionov using the torch_compile impl didn't fix the issue unfortunately, although it did produce a more useful error, explicitly stating that the issue came from mixing DTensor and Tensor.

@felipemello1 neither did disabling reshard, although this was a good suggestion that I expected to work. However, I believe it would cause the entire model to accumulate in memory during the forward pass (correct me if I'm wrong on how that arg works), which seems counter to the benefit of the merged kernel.

On the license note: the Liger Kernel repo also has a linear cross entropy implementation. Their repo is BSD-2, so very permissive. I'm not sure if there are any performance differences. From what I've seen on the issue I posted, the Apple CCE maintainers are very responsive and super helpful, but the Liger implementation could be a backup.

Best of luck getting this going! Apologies for raising the hassle in the distributed case 😅

@felipemello1
Contributor

hey @Andrei-Aksionov @nathan-az, I proposed a refactor here: #2531

@Andrei-Aksionov, this should help with trying Liger or Cut CE. I wonder if, with this setup, the user can directly instantiate the loss and use it, or if we need some sort of LossAdapter to avoid changing the recipe, all_gather the weights, and redirect the inputs, e.g. (weight, outputs, label) --> (output, weight, label).

In the config it could be something like:

loss:
    _component_: torchtune.LossAdapter
    module: Liger.my_function
    new_order: [output, weight, label]
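
For what it's worth, a tiny sketch of what that adapter could be; LossAdapter does not exist in torchtune today, and the argument-remapping scheme here is just an illustration:

import torch.nn as nn

class LossAdapter(nn.Module):
    # Hypothetical adapter around an external fused loss (e.g. from Liger or
    # cut_cross_entropy): remaps the recipe's named inputs into the positional
    # order the external function expects.
    def __init__(self, loss_fn, new_order):
        super().__init__()
        self.loss_fn = loss_fn
        self.new_order = new_order  # e.g. ["output", "weight", "label"]

    def forward(self, **named_inputs):
        return self.loss_fn(*(named_inputs[name] for name in self.new_order))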

But I am afraid that this may be overkill and would rather not go down that path. I would prefer to have a torchtune version and/or work with the compile team to match the CCE perf by doing torch.compile(loss).

btw: someone used Liger with torchtune and wrote a blog post: https://pytorch.org/blog/peak-performance-minimized-memory/
