Refactor losses instantiation and chunked CE #2531
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2531
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
torchtune/modules/loss/sft_losses.py
Outdated
```python
        return total_loss / total_elements


class ChunkedCrossEntropywithAutograd(torch.autograd.Function):
```
Why did you want to add these Autograd versions? How does this help you test?
This version is based off Horace's code from a few months back. In this implementation, the chunks are not held in memory. He coded it to show that you don't need Triton.
I don't want to keep it in torchtune, because it would be hard to use for KD/RL losses. This is more a reference for the compile folks. They are working on enabling the chunking in compile to match the autograd memory perf.
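For context, a minimal sketch of that general idea (not the code in this PR or Horace's exact implementation; the chunking scheme, names, and sum-then-normalize reduction are assumptions). Per-chunk gradients are computed eagerly during forward, so only one chunk of logits is alive at a time, and backward just rescales the accumulated gradients.

```python
import torch
import torch.nn.functional as F


class ChunkedLinearCrossEntropy(torch.autograd.Function):
    """Fuses the output projection with cross-entropy chunk by chunk, so the
    full [num_tokens, vocab] logits tensor is never materialized."""

    @staticmethod
    def forward(ctx, hidden, weight, targets, num_chunks, ignore_index):
        # hidden: [num_tokens, dim], weight: [vocab, dim], targets: [num_tokens]
        grad_hidden = torch.zeros_like(hidden)
        grad_weight = torch.zeros_like(weight)
        total_loss = torch.zeros((), device=hidden.device, dtype=torch.float32)
        num_valid = (targets != ignore_index).sum().clamp(min=1)

        for h, t, gh in zip(
            hidden.chunk(num_chunks),
            targets.chunk(num_chunks),
            grad_hidden.chunk(num_chunks),
        ):
            with torch.enable_grad():
                h = h.detach().requires_grad_(True)
                w = weight.detach().requires_grad_(True)
                logits = h @ w.T  # only this chunk's logits exist at a time
                loss = F.cross_entropy(
                    logits.float(), t, ignore_index=ignore_index, reduction="sum"
                )
                loss.backward()
            gh.copy_(h.grad)  # chunk() returns views into grad_hidden
            grad_weight += w.grad
            total_loss += loss.detach()

        ctx.save_for_backward(grad_hidden, grad_weight)
        ctx.num_valid = num_valid
        return total_loss / num_valid

    @staticmethod
    def backward(ctx, grad_output):
        grad_hidden, grad_weight = ctx.saved_tensors
        scale = grad_output / ctx.num_valid
        return (
            grad_hidden * scale.to(grad_hidden.dtype),
            grad_weight * scale.to(grad_weight.dtype),
            None,
            None,
            None,
        )


# Usage (shapes illustrative): hidden [num_tokens, dim], weight [vocab, dim]
# loss = ChunkedLinearCrossEntropy.apply(hidden, weight, targets, 8, -100)
```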
Can you put a comment in the code to that effect?
@felipemello1 this is awesome. Out of curiosity, did you happen to benchmark against the existing chunked CE loss? I wonder if we could simplify the configuration further by removing the need for the user to also specify this in the config, e.g.

```python
class BaseLoss(Protocol):
    is_chunked: bool
```

and do

```diff
- if self.use_output_weight_in_loss:
+ if self.loss_fn.is_chunked:
      weight = self._model.get_output_weight()
      current_loss = self._loss_fn(weight, outputs, labels)
  else:
      labels = labels.reshape(-1)
      logits = logits.reshape(-1, logits.size(-1))
      outputs = outputs.reshape(-1, outputs.size(-1))
      current_loss = self._loss_fn(outputs, labels)
```

It would require either 1) requiring that all losses use this protocol (which tbh I wouldn't be opposed to as we start to support more custom losses without needing to modify recipes), or 2) keeping a fallback check like the current getattr. wdyt?
@SalmanMohammadi, I thought about it and even implemented it, but then realized that it would be hard to support 3rd-party libraries, unless we create some sort of loss adapter, which we may need to do anyway, because not all libraries follow the pattern (weight, input, label). They may follow (label, weight, input), for example. The loss adapter could be something like (config.yaml):
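Purely as a hypothetical illustration of that adapter idea (the class name, argument names, and the `arg_order` mechanism below are invented for illustration, not the author's actual config.yaml proposal), it might look something like:

```python
from typing import Callable, Tuple

import torch
import torch.nn as nn


class LossAdapter(nn.Module):
    """Adapts a third-party loss whose signature differs from the
    (weight, outputs, targets) order assumed by the recipe."""

    def __init__(
        self,
        loss_fn: Callable[..., torch.Tensor],
        arg_order: Tuple[str, ...] = ("weight", "outputs", "targets"),
    ):
        super().__init__()
        self.loss_fn = loss_fn
        self.arg_order = arg_order
        # The recipe checks this flag to decide whether to pass the output weight.
        self.use_output_weight_in_loss = "weight" in arg_order

    def forward(self, weight, outputs, targets):
        kwargs = {"weight": weight, "outputs": outputs, "targets": targets}
        return self.loss_fn(*(kwargs[name] for name in self.arg_order))
```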
I can't think of any 3rd-party losses which we claim to support which would fall into this category - do you have any examples? I would say that having a stricter contract about which losses we do support would make interoperability more straightforward - i.e. a user would know exactly how to define a custom loss.
If we're getting too in the weeds here, I'm happy with how you've implemented it in this PR and leaving this discussion as a follow-up : )
Yes, Liger and Apple:
I think that the time is now, so we don't have to refactor it again :P
Approving to unblock, but there are a few things related to consistency and documentation that should be cleaned up.
```python
import torch


class SFTLossWithProjection(Protocol):
```
Probably needs to be SFTLossWithOutputProj or something. Projection is too vague.
I agree that this name is confusing. I think we should just standardize on "fused" or "linear", or "chunked". All the names have issues which we've discussed but if we're consistent at least people should be able to learn the term quickly.
```python
from .loss_protocols import SFTLossWithProjection


class ChunkedCrossEntropyLoss(nn.Module, SFTLossWithProjection):
```
Can you add this to the docs? Also maybe include a slightly longer description of why we might want to use this. And how we might use this in a generic training loop.
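For the docs, a rough sketch of what generic training-loop usage could look like (method names such as set_skip_output_projection and get_output_weight follow the diff and may change; the call signature follows the recipe snippet above):

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


def train_epoch(model: nn.Module, loss_fn, dataloader: DataLoader, optimizer) -> None:
    # The model returns hidden states instead of logits; the loss applies the
    # output projection chunk by chunk, so full logits are never materialized.
    model.set_skip_output_projection(True)  # name taken from the diff, may change

    for batch in dataloader:
        tokens, labels = batch["tokens"], batch["labels"]
        hidden = model(tokens)              # [b, s, dim] rather than [b, s, vocab]
        weight = model.get_output_weight()  # weight of the skipped output projection
        loss = loss_fn(weight, hidden, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```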
nit:
```diff
- class ChunkedCrossEntropyLoss(nn.Module, SFTLossWithProjection):
+ class MegaProjChunkyLossinator(nn.Module, SFTLossWithProjection):
```
```diff
@@ -114,12 +114,6 @@ def trace_handler(
     # Memory timeline sometimes fails to export
     if prof.profile_memory and torch.cuda.is_available():
         if rank == 0:
             try:
```
Why was this removed?
```diff
@@ -26,4 +26,6 @@
     "get_torch_device_namespace",
     "DeviceSupport",
     "log_rank_zero",
     "deprecated",
```
Are these in the docs?
good catch. I forgot to check
```python
# set num_output_chunks for model
self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
# The loss may handle the output projection. If true, the model should skip it.
self.use_output_weight_in_loss = getattr(
```
output weight & output proj? Should stay consistent, no?
we should
```diff
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
```
Our protocols for tokenizers live in tokenizers/_utils.py. Do you think it's worth keeping things consistent and renaming this to _utils.py?
I could move it, but it doesn't feel intuitive that protocols live in utils. Do you think it's a better choice, or is it just to keep things consistent?
Yeah, I do like protocols better. Since we're already in the loss module, maybe just protocols.py?
^ this
```python
def apply_compile_strategy(self, *args, **kwargs):
    """Torch compiles the loss function. Can be useful when greater control is needed,
    for example when only compiling a portion of the loss calculation."""
    self.forward = torch.compile(self.forward, *args, **kwargs)
```
Should we be doing this here? I'd vote to add this into the docstring as an example
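One possible shape for that docstring (illustrative only; the Example block below is not part of the PR, and the call signature follows the recipe snippet earlier in the thread):

```python
def apply_compile_strategy(self, *args, **kwargs):
    """Torch compiles the loss function. Can be useful when greater control is
    needed, for example when only compiling a portion of the loss calculation.

    Example:
        >>> loss_fn = ChunkedCrossEntropyLoss()
        >>> loss_fn.apply_compile_strategy(mode="max-autotune")
        >>> loss = loss_fn(weight, outputs, labels)  # forward is now compiled
    """
    self.forward = torch.compile(self.forward, *args, **kwargs)
```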
```python
use_output_proj_in_loss: bool = False

def apply_compile_strategy(self, *args, **kwargs):
```
similar comment to above
```python
    Union[torch.Tensor, List[torch.Tensor]]: output tensor with shape ``[b x s x v]`` or a list of layer
        output tensors defined by ``output_hidden_states`` with the
        final output tensor appended to the list.
    Union[torch.Tensor, List[torch.Tensor]]: output tensor with shape ``[b x s x v]`` if `self.skip_output_projection=False`
```
```diff
- Union[torch.Tensor, List[torch.Tensor]]: output tensor with shape ``[b x s x v]`` if `self.skip_output_projection=False`
+ Union[torch.Tensor, List[torch.Tensor]]: output tensor with shape ``[b x s x v]`` if ``self.skip_output_projection=False``
```
my bad
```python
use_output_proj_in_loss: bool = True

def apply_compile_strategy(self, *args, **kwargs):
```
Thoughts on naming this `def compile`? Is that too vague?
I believe it would override the module.compile method. We probably don't want that.
```python
    outputs (torch.Tensor): Logits of the model. Shape [bsz, seq_len, vocab_size]
    targets (torch.Tensor): Labels for the model. Shape [bsz, seq_len]
```
```diff
- outputs (torch.Tensor): Logits of the model. Shape [bsz, seq_len, vocab_size]
- targets (torch.Tensor): Labels for the model. Shape [bsz, seq_len]
+ outputs (torch.Tensor): Logits of the model. Shape ``[bsz, seq_len, vocab_size]``
+ targets (torch.Tensor): Labels for the model. Shape ``[bsz, seq_len]``
```
```python
# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
# But this way we don't need to slice the logits. We just add an ignore index to labels.
labels = torch.hstack(
    (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
)
if not isinstance(logits, list):

if self.use_output_weight_in_loss:
```
very nice
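As a quick sanity check of the label-shifting trick in the snippet above, a small self-contained sketch (shapes and the ignore index value are assumptions):

```python
import torch
import torch.nn.functional as F

ignore_index = -100
b, s, v = 2, 5, 11
logits = torch.randn(b, s, v)
labels = torch.randint(0, v, (b, s))

# Standard shift: drop the last logit position and the first label.
ref = F.cross_entropy(logits[:, :-1].reshape(-1, v), labels[:, 1:].reshape(-1))

# Trick used here: keep all logit positions, shift labels left and pad the last
# position with ignore_index so it contributes nothing to the loss.
ignore_col = torch.full((b, 1), ignore_index)
shifted = torch.hstack((labels[:, 1:], ignore_col))
alt = F.cross_entropy(
    logits.reshape(-1, v), shifted.reshape(-1), ignore_index=ignore_index
)

assert torch.allclose(ref, alt)
```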
```python
# set num_output_chunks for model
self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
# The loss may handle the output projection. If true, the model should skip it.
self.use_output_weight_in_loss = getattr(
```
Tangential point: if the contract is that SFT losses follow the protocols defined in loss_protocols, do we need to make this check?
Someone may try to use a loss that is not from torchtune, e.g. vanilla F.cross_entropy.
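A tiny illustration of why the getattr default matters (the attribute name is taken from the diff and may change):

```python
import torch.nn.functional as F

# A plain loss like F.cross_entropy has no torchtune-specific attributes, so a
# getattr with a default safely falls back to the standard logits path.
loss_fn = F.cross_entropy
use_output_proj_in_loss = getattr(loss_fn, "use_output_proj_in_loss", False)
assert use_output_proj_in_loss is False
```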
```python
class SFTLoss(Protocol):
    """Protocol for loss functions in torchtune used in sft recipes."""
```
"""Protocol for loss functions in torchtune used in sft recipes.""" | |
"""Protocol for loss functions in torchtune used in SFT recipes.""" |
I don't know if I like "SFT" here, since it may not be obvious to a new reader what it means.
```python
class SFTLossWithProjection(Protocol):
    """Protocol for loss functions in torchtune used in Supervised Finetune recipes and that require
```
"""Protocol for loss functions in torchtune used in Supervised Finetune recipes and that require | |
"""Protocol for loss functions in torchtune used in SFT recipes and that require |
I prefer "SFTI dont know if i like "SFT" here, since it may not be obvious for a new reader what it means
real nice
Thanks for this big effort. This looks good and I'm happy to approve it now. Please finish going through and resolving the open comments before landing.
recipes/full_finetune_distributed.py
Outdated
```python
# skip final projection, since the loss takes hidden input instead of logits
self.skip_unembedding = cfg.get("loss_takes_embeddings", False)
self._model.set_skip_unembedding(self.skip_unembedding)
```
nit: skip_output_layer
```python
            target_chunks[idx],
        )

        return total_loss / total_elements
```
nit: it'd be nice to offer the same 'reduction' option as most pytorch losses to control returning the mean, sum, or no reduction
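One way that could slot in once the per-chunk sums and valid-element counts are accumulated (a sketch; names are illustrative, and "none" would additionally require keeping per-token losses from each chunk):

```python
import torch


def reduce_loss(
    total_loss: torch.Tensor, total_elements: torch.Tensor, reduction: str = "mean"
) -> torch.Tensor:
    # Mimics the reduction semantics of most torch.nn losses for the chunked case.
    if reduction == "mean":
        return total_loss / total_elements
    if reduction == "sum":
        return total_loss
    raise ValueError(f"Unsupported reduction: {reduction!r}")
```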
```diff
@@ -301,9 +301,12 @@ def setup(self, cfg: DictConfig) -> None:
     if self._compile:
```
What's the plan for rolling this out to the other sft recipes?
- Recipes NOT being updated should still work with configs NOT being updated
- Recipes being updated should NOT work anymore with old ce_with_chunked_outputs_loss
- So any recipe that is changed also requires the configs to be updated with the new loss
TODO: need to check if the deprecation warnings work fine. This can be checked by running a recipe/config that has not been updated.
```diff
@@ -396,6 +400,7 @@ def __init__(
     self.head_dim = head_dim
     self.causal_mask = None
     self.num_output_chunks = 0
+    self._skip_output_projection = False
```
You should enforce in init that the output module has the "weight" property
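For example, something along these lines could be checked at init time (a sketch; the helper name and wording are assumptions):

```python
from torch import nn


def _validate_output_module(output: nn.Module) -> None:
    # Fail fast if the output projection doesn't expose a weight the loss can
    # reuse (e.g. an arbitrary wrapper module might not).
    if not hasattr(output, "weight"):
        raise AttributeError(
            "Expected the output module to expose a `weight` attribute so the "
            "loss can apply the projection itself."
        )
```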
Context
What is the purpose of this PR?
IMPORTANT: Recipes do NOT work with the older version of ChunkedCrossEntropy anymore, because we no longer expect the transformer to chunk the outputs.
Problem:
Solution:
PROFILING: https://drive.google.com/drive/folders/1jHOCuOF74F9lmmJv7wxbcK-i_wtB2stf?usp=sharing
Changelog
TODO: once approved, apply it to the other recipes/losses and update the configs
Test
ChunkedCrossEntropyLoss
To reproduce