[PyTorch] Replace views with reshapes and update PyTorch autocast API #1250
Conversation
Force-pushed from 90f2f7e to f22e963.
Overall I agree that `reshape` is preferable to `view`. `sed 's/view/reshape/g'` is a little overkill though.
```diff
     data=tensor._data.reshape(*shape),
 )
-return tensor.view(*shape)
+return tensor.reshape(*shape)
```
This logic is for `Float8Tensor.view`, so using `reshape` could cause correctness problems:
```diff
-    data=tensor._data.reshape(*shape),
-)
-return tensor.reshape(*shape)
+    data=tensor._data.view(*shape),
+)
+return tensor.view(*shape)
```
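For context, here is a minimal sketch in plain PyTorch (not Transformer Engine code) of why the two calls are not interchangeable: `view` never copies and fails on incompatible strides, while `reshape` may silently return a copy, breaking callers that rely on `view`'s aliasing guarantee.

```python
import torch

x = torch.arange(6.0).reshape(2, 3)
xt = x.t()                      # non-contiguous transpose of x

v = x.view(3, 2)                # OK: x is contiguous, v shares storage with x
r = xt.reshape(6)               # OK: reshape copies because xt is non-contiguous
try:
    xt.view(6)                  # view cannot handle xt's strides
except RuntimeError as err:
    print("view failed:", err)

print(v.data_ptr() == x.data_ptr())   # True: view aliases the original storage
print(r.data_ptr() == xt.data_ptr())  # False: reshape had to make a copy
```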
I see. I have made the changes for Float8Tensor.
Force-pushed from 93ba06e to 2ac3da8.
Force-pushed from 2ac3da8 to 9001081.
Force-pushed from 873637c to 617e1de.
```diff
 if isinstance(grad, Float8Tensor):
     dgrad = Float8Tensor.make_like(
         grad,
-        data=grad._data.reshape(ctx.shape),
+        data=grad._data.view(ctx.shape),
     )
     return dgrad, None
-return grad.reshape(ctx.shape), None
+return grad.view(ctx.shape), None
```
This logic is for `Float8Tensor.reshape`:
```diff
 if isinstance(grad, Float8Tensor):
     dgrad = Float8Tensor.make_like(
         grad,
-        data=grad._data.view(ctx.shape),
+        data=grad._data.reshape(ctx.shape),
     )
     return dgrad, None
-return grad.view(ctx.shape), None
+return grad.reshape(ctx.shape), None
```
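To illustrate the point about the backward pass, here is a hypothetical, simplified stand-in for a reshape-style autograd function (plain PyTorch, not the actual `Float8Tensor` implementation): the incoming gradient is not guaranteed to be contiguous, so `view(ctx.shape)` could raise there, while `reshape` is safe.

```python
import torch

class _Reshape(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, shape):
        ctx.shape = tensor.shape          # remember original shape for backward
        return tensor.reshape(*shape)

    @staticmethod
    def backward(ctx, grad):
        # grad may be non-contiguous, so restore the original shape with
        # reshape rather than view.
        return grad.reshape(ctx.shape), None

x = torch.randn(4, 6, requires_grad=True)
y = _Reshape.apply(x, (6, 4))
g = torch.ones(4, 6).t()                  # deliberately non-contiguous gradient
y.backward(g)
print(x.grad.shape)                       # torch.Size([4, 6])
```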
```diff
 if isinstance(tensor, Float8Tensor):
     return Float8Tensor.make_like(
         tensor,
-        data=tensor._data.reshape(*shape),
+        data=tensor._data.view(*shape),
     )
-return tensor.reshape(*shape)
+return tensor.view(*shape)
```
This logic is for `Float8Tensor.reshape`:
```diff
 if isinstance(tensor, Float8Tensor):
     return Float8Tensor.make_like(
         tensor,
-        data=tensor._data.view(*shape),
+        data=tensor._data.reshape(*shape),
     )
-return tensor.view(*shape)
+return tensor.reshape(*shape)
```
Force-pushed from 274d4e3 to 99313bf.
/te-ci pytorch
Overall LGTM. `reshape` is generally preferable to `view`, although many of these changes are redundant since the tensors are already contiguous. Please sign your commits to pass the DCO check, and we'll merge if there are no concerning test failures.
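A quick illustration of why the redundant cases are harmless (plain PyTorch, assuming ordinary dense tensors): when the input is already contiguous, `reshape` does not copy and behaves exactly like `view`.

```python
import torch

x = torch.randn(4, 6)                 # contiguous tensor
y = x.reshape(24)

print(x.is_contiguous())              # True
print(y.data_ptr() == x.data_ptr())   # True: reshape returned a view, no copy made
```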
The changes to `transformer_engine/pytorch/attention.py` have snuck back in. They can be reverted with:
```sh
git checkout main -- transformer_engine/pytorch/attention.py
git commit -m "Revert changes to transformer_engine/pytorch/attention.py" --signoff
```
The commit history is mangled and there are still some unsigned commits. At this point it's better to squash your commits:
```sh
# Merge main branch
git checkout main
git pull origin main
git checkout fix_layernorm_fsdp
git merge main
# Squash all changes into a single commit
git reset main
git commit -a -m 'Use reshape instead of view and update PyTorch autocast API' --signoff
# Force-push to your GitHub
git push eljandoubi fix_layernorm_fsdp --force
```
Just in case, I've made a copy of your branch: https://github.com/timmoon10/TransformerEngine/tree/eljandoubi/fix_layernorm_fsdp
Force-pushed from 5cb5e2e to b53d398.
Signed-off-by: eljandoubi <[email protected]>
Force-pushed from b53d398 to 8ce50b3.
The commit history was still mangled (the original branch is at https://github.com/timmoon10/TransformerEngine/tree/eljandoubi/fix_layernorm_fsdp-20241018), so I've manually squashed the commits. However, bugs have crept back in (#1250 (comment), #1250 (comment), #1250 (comment)). At the moment I think this PR is riskier and more of a hassle than it is worth. If this is important for your use case, I suggest breaking this up into two more manageable PRs:
Description
Update `torch.get_autocast_gpu_dtype()` to `torch.get_autocast_dtype("cuda")` throughout the PyTorch code (see the sketch below).
Migrate from `torch.Tensor.view` to `torch.Tensor.reshape`. `reshape` works with non-contiguous tensors, copying data when necessary, which is handy for distributed training.
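A minimal sketch of the autocast accessor change, assuming a recent PyTorch release where the device-generic `torch.get_autocast_dtype` is available:

```python
import torch

# Old, GPU-specific accessor (deprecated in recent PyTorch releases):
# dtype = torch.get_autocast_gpu_dtype()

# New, device-generic accessor:
dtype = torch.get_autocast_dtype("cuda")
print(dtype)  # torch.float16 unless autocast was configured otherwise
```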
@timmoon10 @ksivaman @ptrendx @cyanguwa
Fixes #1247
Type of change