
Refactor losses instantiation and chunked CE #2531


Open · wants to merge 20 commits into base: main

Conversation

@felipemello1 (Contributor) commented Mar 27, 2025

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

IMPORTANT: Recipes NO LONGER work with the older version of ChunkedCrossEntropy, because we no longer expect the transformer to chunk the outputs.

Problem:

  1. We have seen many chunked losses being added to torchtune. The current setup puts the chunking burden on the model.
  2. Users are interested in using losses that take model.output.weight as an input, e.g. Liger losses.

Solution:

  1. Enable the recipe to call loss(weight, input, targets)
  2. Reimplement ChunkedCE so that chunking and projection happen in the loss.
  3. Add a protocol so that new losses can follow the same pattern (see the sketch below).
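
A minimal sketch of the intended call pattern (hedged: the protocol and class names below appear in the PR, but the bodies are illustrative assumptions, not the PR's exact code):

from typing import Protocol

import torch
import torch.nn.functional as F
from torch import nn


class SFTLossWithProjection(Protocol):
    """Losses that receive the output projection weight and perform the projection themselves."""

    use_output_proj_in_loss: bool

    def forward(
        self, weight: torch.Tensor, outputs: torch.Tensor, targets: torch.Tensor
    ) -> torch.Tensor:
        ...


class ChunkedCrossEntropyLoss(nn.Module, SFTLossWithProjection):
    """Sketch: project hidden states to logits chunk by chunk inside the loss."""

    use_output_proj_in_loss: bool = True

    def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100):
        super().__init__()
        self.num_output_chunks = num_output_chunks
        self.ignore_index = ignore_index

    def forward(self, weight, outputs, targets):
        # outputs: [b, s, d] hidden states, weight: [v, d], targets: [b, s]
        hidden = outputs.reshape(-1, outputs.size(-1))
        targets = targets.reshape(-1)
        total_loss = torch.zeros((), device=hidden.device)
        total_elements = (targets != self.ignore_index).sum()
        for h_chunk, t_chunk in zip(
            hidden.chunk(self.num_output_chunks),
            targets.chunk(self.num_output_chunks),
        ):
            # project only this chunk; its logits are freed after the iteration
            logits = F.linear(h_chunk, weight)
            total_loss = total_loss + F.cross_entropy(
                logits.float(), t_chunk, reduction="sum", ignore_index=self.ignore_index
            )
        return total_loss / total_elements


# In the recipe, roughly:
#   weight = self._model.get_output_weight()
#   loss = self._loss_fn(weight, hidden_states, labels)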

PROFILING: https://drive.google.com/drive/folders/1jHOCuOF74F9lmmJv7wxbcK-i_wtB2stf?usp=sharing

Changelog

  • Updated full_distributed and lora_distributed
  • Tested with lora llama 3.2 distributed (TiedLinear)
  • Implemented new ChunkedCE

TODO: once approved, roll this out to the other recipes and losses, and update the configs.

Test

ChunkedCrossEntropyLoss

tune run --nproc_per_node 2 lora_finetune_distributed --config /data/users/felipemello/torchtune/recipes/configs/llama3_2/1B_lora.yaml \
metric_logger=torchtune.training.metric_logging.WandBLogger \
dataset.packed=True \
dataset.split=train[:50%] \
tokenizer.max_seq_len=4096 \
gradient_accumulation_steps=1 \
batch_size=4 \
max_steps_per_epoch=20 \
compile=True \
use_output_weight_in_loss=True \
loss=torchtune.modules.loss.sft_losses.ChunkedCrossEntropyLoss


To reproduce

# Fork https://github.com/pytorch/torchtune first, then:
git clone https://github.com/<YOUR_GITHUB_USER>/torchtune.git

cd torchtune
conda create -n torchtune python=3.11
conda activate torchtune
pip install --pre --upgrade torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu124
pip install -e ".[dev]"
pre-commit install

git remote add felipemello1 https://github.com/felipemello1/torchtune.git
git checkout -b loss_refactor felipemello1/loss_refactor


pytorch-bot bot commented Mar 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2531

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Mar 27, 2025
@felipemello1 mentioned this pull request on Mar 27, 2025
@felipemello1 changed the title from "Refactor losses installation and chunked CE" to "Refactor losses instantiation and chunked CE" on Mar 27, 2025
@felipemello1 marked this pull request as draft on March 31, 2025 14:39
@felipemello1 marked this pull request as ready for review on March 31, 2025 22:16
return total_loss / total_elements


class ChunkedCrossEntropywithAutograd(torch.autograd.Function):
Contributor

Why did you want to add these Autograd versions? How does this help you test?

Contributor Author

This version is based on Horace's code from a few months back. In this implementation, the chunks are not held in memory; he wrote it to show that you don't need Triton.

I don't want to keep it in torchtune, because it would be hard to reuse for KD/RL losses. This is more of a reference for the compile folks: they are working on enabling the chunking under compile to match the autograd version's memory performance.
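
For reference, a minimal sketch of that idea (an assumed reconstruction of the general approach, not the PR's actual ChunkedCrossEntropywithAutograd): per-chunk gradients are computed inside forward so each chunk's logits can be freed immediately, and backward only rescales the saved gradients.

import torch
import torch.nn.functional as F


class ChunkedCEFunction(torch.autograd.Function):
    # Sketch: the full [num_tokens, vocab] logits tensor is never materialized.
    # Each chunk's gradient is computed inside forward, so its logits can be
    # freed immediately; backward only rescales the saved gradients.

    @staticmethod
    def forward(ctx, hidden, weight, targets, num_chunks=8, ignore_index=-100):
        # hidden: [N, d], weight: [vocab, d], targets: [N]
        grad_hidden = torch.zeros_like(hidden)
        grad_weight = torch.zeros_like(weight)
        total_loss = torch.zeros((), device=hidden.device)
        n_valid = (targets != ignore_index).sum().clamp(min=1)

        chunk_size = (hidden.size(0) + num_chunks - 1) // num_chunks
        for h, t, gh in zip(
            hidden.split(chunk_size),
            targets.split(chunk_size),
            grad_hidden.split(chunk_size),  # views, so writing gh fills grad_hidden
        ):
            with torch.enable_grad():
                h_ = h.detach().requires_grad_(True)
                w_ = weight.detach().requires_grad_(True)
                logits = h_ @ w_.T  # only this chunk's logits exist at a time
                loss = F.cross_entropy(
                    logits.float(), t, reduction="sum", ignore_index=ignore_index
                )
                dh, dw = torch.autograd.grad(loss, (h_, w_))
            gh.copy_(dh)
            grad_weight += dw
            total_loss += loss.detach()
            # the chunk's logits and graph go out of scope here and can be freed

        ctx.save_for_backward(grad_hidden, grad_weight)
        ctx.n_valid = n_valid
        return total_loss / n_valid

    @staticmethod
    def backward(ctx, grad_output):
        grad_hidden, grad_weight = ctx.saved_tensors
        scale = grad_output / ctx.n_valid
        # grads for (hidden, weight, targets, num_chunks, ignore_index)
        return grad_hidden * scale, grad_weight * scale, None, None, None


# Usage, roughly: loss = ChunkedCEFunction.apply(hidden.view(-1, d), weight, labels.view(-1))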

Contributor Author

Without autograd: [screenshot]

Contributor Author

With autograd: [screenshot]

Contributor

Can you put a comment in the code to that effect?

@SalmanMohammadi (Collaborator)

@felipemello1 this is awesome. Out of curiosity, did you happen to benchmark against the existing CEWithChunkedOutputLoss?

I wonder if we could simplify the configuration further by removing the need for the user to also specify use_output_weight_in_loss? Could we define a

class BaseLoss(Protocol):
    is_chunked: bool

and do

-                if self.use_output_weight_in_loss:
+                if self.loss_fn.is_chunked:
                    weight = self._model.get_output_weight()
                    current_loss = self._loss_fn(weight, outputs, labels)
                else:
                    labels = labels.reshape(-1)
                    logits = logits.reshape(-1, logits.size(-1))
                    outputs = outputs.reshape(-1, outputs.size(-1))
                    current_loss = self._loss_fn(outputs, labels)

It would require either 1) requiring that all losses use this protocol (which, tbh, I wouldn't be opposed to as we start to support more custom losses without needing to modify recipes), or 2) doing a hasattr check on self._loss_fn and relying on an identifying field on just the chunked losses.

wdyt?

@felipemello1 (Contributor Author) commented Apr 4, 2025

I wonder if we could simplify the configuration further by removing the need for the user to also specify use_output_weight_in_loss?

@SalmanMohammadi, I thought about it and even implemented it, but then realized that it would be hard to support 3rd-party libraries unless we create some sort of loss adapter, which we may need to do anyway, because not all libraries follow the pattern (weight, input, label). They may follow (label, weight, input), for example.

The loss adapter could be something like:

config.yaml

loss:
  _component_: torchtune.loss.lossadapter
  loss: path.to.loss
  requires_weight_input: True
  input_order: ["label", "weight", "input"]
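
A rough sketch of what such an adapter could look like (hypothetical: torchtune.loss.lossadapter and these parameter names are illustrative proposals, not something that exists in this PR):

from typing import Sequence

import torch
from torch import nn


class LossAdapter(nn.Module):
    # Hypothetical wrapper that adapts a third-party loss to the
    # (weight, input, label) calling convention used by the recipe.
    def __init__(
        self,
        loss: nn.Module,
        requires_weight_input: bool = True,
        input_order: Sequence[str] = ("weight", "input", "label"),
    ):
        super().__init__()
        self.loss = loss
        self.use_output_proj_in_loss = requires_weight_input
        self.input_order = list(input_order)

    def forward(
        self, weight: torch.Tensor, input: torch.Tensor, label: torch.Tensor
    ) -> torch.Tensor:
        # Reorder the recipe's (weight, input, label) into whatever order the
        # wrapped loss expects, e.g. ["label", "weight", "input"] for some libraries.
        named = {"weight": weight, "input": input, "label": label}
        return self.loss(*(named[name] for name in self.input_order))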

@SalmanMohammadi (Collaborator)

I wonder if we could simplify the configuration further by removing the need for the user to also specify use_output_weight_in_loss?

@SalmanMohammadi, I thought about it and even implemented it, but then realized that it would be hard to support 3rd-party libraries unless we create some sort of loss adapter, which we may need to do anyway, because not all libraries follow the pattern (weight, input, label). They may follow (label, weight, input), for example.

The loss adapter could be something like:

config.yaml

loss:
  _component_: torchtune.loss.lossadapter
  loss: path.to.loss
  requires_weight_input: True
  input_order: ["label", "weight", "input"]

I can't think of any 3rd party losses which we claim to support which would fall into this category - do you have any examples? I would say that having a stricter contract about which losses we do support would make interoperability more straightforward - i.e. a user would know exactly how to define a

class MyTorchtuneLoss(TorchtuneLossProtocol):
    def __init__(self):
        self.loss = ThirdPartyLossChunkedLoss(...)
        self.is_chunked = True

    def forward(self, weight, input, label):
        return self.loss(input, weight, label)

If we're getting too in the weeds here I'm happy with how you've implemented it in this PR and leaving this discussion as a follow up : )

@felipemello1 (Contributor Author) commented Apr 4, 2025

I can't think of any 3rd party losses which we claim to support which would fall into this category - do you have any examples?

Yes, Liger and Apple:
https://github.com/linkedin/Liger-Kernel/tree/main/src/liger_kernel/chunked_loss
https://github.com/apple/ml-cross-entropy/tree/main/cut_cross_entropy

If we're getting too in the weeds here I'm happy with how you've implemented it in this PR and leaving this discussion as a follow up : )

I think the time is now, so we don't have to refactor it again :P

@joecummings (Contributor) left a comment

Approving to unblock, but there are a few things related to consistency and documentation that should be cleaned up.

import torch


class SFTLossWithProjection(Protocol):
Contributor

Probably needs to be SFTLossWithOutputProj or something. Projection is too vague.

Contributor

I agree that this name is confusing. I think we should just standardize on "fused" or "linear", or "chunked". All the names have issues which we've discussed but if we're consistent at least people should be able to learn the term quickly.

from .loss_protocols import SFTLossWithProjection


class ChunkedCrossEntropyLoss(nn.Module, SFTLossWithProjection):
Contributor

Can you add this to the docs? Also maybe include a slightly longer description of why we might want to use this. And how we might use this in a generic training loop.

Collaborator

nit:

Suggested change
class ChunkedCrossEntropyLoss(nn.Module, SFTLossWithProjection):
class MegaProjChunkyLossinator(nn.Module, SFTLossWithProjection):

@@ -114,12 +114,6 @@ def trace_handler(
# Memory timeline sometimes fails to export
if prof.profile_memory and torch.cuda.is_available():
if rank == 0:
try:
Contributor

Why was this removed?

@@ -26,4 +26,6 @@
"get_torch_device_namespace",
"DeviceSupport",
"log_rank_zero",
"deprecated",
Contributor

Are these in the docs?

Contributor Author

good catch. I forgot to check

# set num_output_chunks for model
self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
# The loss may handle the output projection. If true, the model should skip it.
self.use_output_weight_in_loss = getattr(
Contributor

output weight & output proj? Should stay consistent, no?

Contributor Author

we should

@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Collaborator

Our protocols for tokenizers live in tokenizers/_utils.py. Do you think it's worth keeping things consistent and renaming this to _utils.py?

Contributor Author

I could move it, but it doesn't feel intuitive that protocols live in utils. Do you think it's a better choice, or is it just to keep things consistent?

Collaborator

Yeah I do like protocols better. Since we're already in the loss module maybe just protocols.py?

Contributor

^ this

def apply_compile_strategy(self, *args, **kwargs):
    """Torch compiles the loss function. Can be useful when greater control is needed,
    for example when only compiling a portion of the loss calculation."""
    self.forward = torch.compile(self.forward, *args, **kwargs)
Collaborator

Should we be doing this here? I'd vote to add this into the docstring as an example
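
Something along these lines could go in the docstring as an example (illustrative only: compute_cross_entropy stands in for whatever per-chunk helper a given loss exposes):

# Hypothetical override: compile only the per-chunk CE computation and keep
# the chunking loop eager, instead of compiling the whole forward.
def apply_compile_strategy(self, *args, **kwargs):
    self.compute_cross_entropy = torch.compile(self.compute_cross_entropy, *args, **kwargs)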


use_output_proj_in_loss: bool = False

def apply_compile_strategy(self, *args, **kwargs):
Collaborator

similar comment to above

Union[torch.Tensor, List[torch.Tensor]]: output tensor with shape ``[b x s x v]`` or a list of layer
output tensors defined by ``output_hidden_states`` with the
final output tensor appended to the list.
Union[torch.Tensor, List[torch.Tensor]]: output tensor with shape ``[b x s x v]`` if `self.skip_output_projection=False`
Collaborator

Suggested change
Union[torch.Tensor, List[torch.Tensor]]: output tensor with shape ``[b x s x v]`` if `self.skip_output_projection=False`
Union[torch.Tensor, List[torch.Tensor]]: output tensor with shape ``[b x s x v]`` if ``self.skip_output_projection=False``

my bad


use_output_proj_in_loss: bool = True

def apply_compile_strategy(self, *args, **kwargs):
Collaborator

Thoughts on naming this def compile? Is that too vague?

Contributor Author

I believe it would override the Module.compile method. We probably don't want that.

Comment on lines +60 to +61
outputs (torch.Tensor): Logits of the model. Shape [bsz, seq_len, vocab_size]
targets (torch.Tensor): Labels for the model. Shape [bsz, seq_len]
Collaborator

Suggested change
outputs (torch.Tensor): Logits of the model. Shape [bsz, seq_len, vocab_size]
targets (torch.Tensor): Labels for the model. Shape [bsz, seq_len]
outputs (torch.Tensor): Logits of the model. Shape ``[bsz, seq_len, vocab_size]``
targets (torch.Tensor): Labels for the model. Shape ``[bsz, seq_len]``

# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
# But this way we don't need to slice the logits. We just add an ignore index to labels.
labels = torch.hstack(
    (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
)
if not isinstance(logits, list):

if self.use_output_weight_in_loss:
Collaborator

very nice

# set num_output_chunks for model
self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
# The loss may handle the output projection. If true, the model should skip it.
self.use_output_weight_in_loss = getattr(
Collaborator

tangential point: if the contract is that SFT losses follow the protocols defined in loss_protocols, do we need to make this check?

Contributor Author

Someone may try to use a loss that is not from torchtune, e.g. vanilla F.cross_entropy.



class SFTLoss(Protocol):
"""Protocol for loss functions in torchtune used in sft recipes."""
Collaborator

Suggested change
"""Protocol for loss functions in torchtune used in sft recipes."""
"""Protocol for loss functions in torchtune used in SFT recipes."""

Contributor Author

I don't know if I like "SFT" here, since it may not be obvious to a new reader what it means.



class SFTLossWithProjection(Protocol):
"""Protocol for loss functions in torchtune used in Supervised Finetune recipes and that require
Collaborator

Suggested change
"""Protocol for loss functions in torchtune used in Supervised Finetune recipes and that require
"""Protocol for loss functions in torchtune used in SFT recipes and that require

Contributor Author

I prefer "SFTI dont know if i like "SFT" here, since it may not be obvious for a new reader what it means

@SalmanMohammadi (Collaborator) left a comment

real nice

@pbontrager (Contributor) left a comment

Thanks for this big effort. This looks good and I'm happy to approve it now. Please finish going through and resolving the open comments before landing.

Comment on lines 347 to 349
# skip final projection, since the loss takes hidden input instead of logits
self.skip_unembedding = cfg.get("loss_takes_embeddings", False)
self._model.set_skip_unembedding(self.skip_unembedding)
Contributor

nit: skip_output_layer


target_chunks[idx],
)

return total_loss / total_elements
Contributor

nit: it'd be nice to offer the same 'reduction' option as most pytorch losses to control returning the mean, sum, or no reduction
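
A possible shape for that (illustrative only, mirroring torch.nn.CrossEntropyLoss semantics; not part of this PR):

# Hypothetical: accept reduction="mean" | "sum" in the constructor and branch on it here.
if self.reduction == "sum":
    return total_loss
elif self.reduction == "mean":
    return total_loss / total_elements
else:
    # "none" would require collecting per-token losses instead of a running sum
    raise NotImplementedError(f"Unsupported reduction: {self.reduction}")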

@@ -301,9 +301,12 @@ def setup(self, cfg: DictConfig) -> None:
if self._compile:
Contributor

What's the plan for rolling this out to the other sft recipes?

Contributor Author

  1. Recipes that are NOT updated should still work with configs that are NOT updated.
  2. Recipes that ARE updated will NOT work anymore with the old ce_with_chunked_outputs_loss.
  3. So any recipe that is changed also requires its configs to be updated with the new loss.

TODO: need to check if the deprecation warnings work fine. This can be checked by running a recipe/config that has not been updated.

@@ -396,6 +400,7 @@ def __init__(
self.head_dim = head_dim
self.causal_mask = None
self.num_output_chunks = 0
self._skip_output_projection = False

Contributor

You should enforce in init that the output module has the "weight" property
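
For example, a minimal guard along those lines (hypothetical; the exact attribute names depend on the model class):

# In __init__, fail fast if the output module can't expose its weight to the loss
if not hasattr(self.output, "weight"):
    raise AttributeError(
        "Skipping the output projection requires the output module to expose a "
        f"'weight' attribute, but got {type(self.output).__name__}"
    )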
