[PyTorch] Replace views with reshapes and update PyTorch autocast API #1250

Closed · wants to merge 1 commit

Conversation


@eljandoubi commented on Oct 14, 2024

Description

Update torch.get_autocast_gpu_dtype() to torch.get_autocast_dtype("cuda") throughout the PyTorch code.
Migrate from Tensor.view to Tensor.reshape. reshape works with non-contiguous tensors, copying data when necessary, which is useful for distributed training.
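
For reference, a minimal sketch (not part of the PR) of the behavioral difference between the two calls: view requires the tensor's memory layout to match the requested shape and fails on non-contiguous inputs, while reshape falls back to copying the data.

import torch

x = torch.arange(6).reshape(2, 3)
xt = x.t()                  # transpose -> non-contiguous

try:
    xt.view(6)              # view needs a compatible memory layout
except RuntimeError:
    print("view failed on the non-contiguous tensor")

y = xt.reshape(6)           # reshape returns a view when possible, otherwise copies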

@timmoon10 @ksivaman @ptrendx @cyanguwa

Fixes #1247

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • Code refactor

@eljandoubi force-pushed the fix_layernorm_fsdp branch 2 times, most recently from 90f2f7e to f22e963 on October 15, 2024 16:42
@timmoon10 (Collaborator) left a comment

Overall I agree that reshape is preferable to view. sed 's/view/reshape/g' is a little overkill though.

Comment on lines +237 to +246:

+        data=tensor._data.reshape(*shape),
     )
-    return tensor.view(*shape)
+    return tensor.reshape(*shape)
@timmoon10 (Collaborator):

This logic is for Float8Tensor.view, so using reshape could cause correctness problems:

Suggested change (restore view):

        data=tensor._data.view(*shape),
    )
    return tensor.view(*shape)
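
To illustrate the concern (a hypothetical example, not Transformer Engine code): an operation with view semantics must return a tensor that aliases the input's storage, whereas reshape may silently copy, in which case writes through the result no longer reach the original tensor.

import torch

x = torch.zeros(2, 3)

v = x.view(6)               # guaranteed to share x's storage
v[0] = 1.0
assert x[0, 0] == 1.0       # the write is visible through x

r = x.t().reshape(6)        # non-contiguous input -> reshape returns a copy
r[0] = 2.0
assert x[0, 0] == 1.0       # x is unchanged: the write went to the copy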

@eljandoubi (Author):

I see. I have made the changes for Float8Tensor.

Review threads on transformer_engine/pytorch/csrc/extensions/attention.cu and tests/paddle/test_operators.py were marked outdated and resolved.
@eljandoubi force-pushed the fix_layernorm_fsdp branch 2 times, most recently from 93ba06e to 2ac3da8 on October 16, 2024 21:09
@eljandoubi closed this on Oct 16, 2024
@eljandoubi deleted the fix_layernorm_fsdp branch on October 16, 2024 22:02
@eljandoubi restored the fix_layernorm_fsdp branch on October 16, 2024 22:03
@eljandoubi reopened this on Oct 16, 2024
Comment on lines 299 to +305:

 if isinstance(grad, Float8Tensor):
     dgrad = Float8Tensor.make_like(
         grad,
-        data=grad._data.reshape(ctx.shape),
+        data=grad._data.view(ctx.shape),
     )
     return dgrad, None
-return grad.reshape(ctx.shape), None
+return grad.view(ctx.shape), None
@timmoon10 (Collaborator):

This logic is for Float8Tensor.reshape:

Suggested change (use reshape):

 if isinstance(grad, Float8Tensor):
     dgrad = Float8Tensor.make_like(
         grad,
         data=grad._data.reshape(ctx.shape),
     )
     return dgrad, None
 return grad.reshape(ctx.shape), None

Comment on lines 285 to +290:

 if isinstance(tensor, Float8Tensor):
     return Float8Tensor.make_like(
         tensor,
-        data=tensor._data.reshape(*shape),
+        data=tensor._data.view(*shape),
     )
-return tensor.reshape(*shape)
+return tensor.view(*shape)
@timmoon10 (Collaborator):

This logic is for Float8Tensor.reshape:

Suggested change (use reshape):

 if isinstance(tensor, Float8Tensor):
     return Float8Tensor.make_like(
         tensor,
         data=tensor._data.reshape(*shape),
     )
 return tensor.reshape(*shape)

A review thread on transformer_engine/pytorch/attention.py was marked outdated and resolved.
@timmoon10 self-requested a review on October 17, 2024 19:47
@timmoon10 (Collaborator) commented:
/te-ci pytorch

@timmoon10 (Collaborator) left a comment

Overall LGTM. reshape is generally preferable to view, although many of these changes are redundant since the tensors are already contiguous.

Please sign your commits to pass the DCO check, and we'll merge if there are no concerning test failures.
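
As a quick illustration of the redundancy point (an assumed example, not from the PR): when the input is already contiguous, reshape returns a view that shares the input's storage, so swapping view for reshape in those places changes nothing at runtime.

import torch

x = torch.randn(4, 8)                 # contiguous tensor
y = x.reshape(32)
print(y.data_ptr() == x.data_ptr())   # True: no copy was made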

@timmoon10 self-requested a review on October 18, 2024 01:28
@timmoon10 changed the title from "Fix layernorm fsdp" to "[PyTorch] Replace views with reshapes and update PyTorch autocast API" on Oct 18, 2024
@timmoon10 (Collaborator) left a comment

The changes to transformer_engine/pytorch/attention.py have snuck back in. They can be reverted with:

git checkout main -- transformer_engine/pytorch/attention.py
git commit -m "Revert changes to transformer_engine/pytorch/attention.py" --signoff

The commit history is mangled and there are still some unsigned commits. At this point it's better to squash your commits:

# Merge main branch
git checkout main
git pull origin main
git checkout fix_layernorm_fsdp
git merge main

# Squash all changes into a single commit
git reset main
git commit -a -m 'Use reshape instead of view and update PyTorch autocast API' --signoff

# Force-push to your GitHub
git push eljandoubi fix_layernorm_fsdp --force

Just in case, I've made a copy of your branch: https://github.com/timmoon10/TransformerEngine/tree/eljandoubi/fix_layernorm_fsdp

@eljandoubi force-pushed the fix_layernorm_fsdp branch 5 times, most recently from 5cb5e2e to b53d398 on October 18, 2024 07:10
@timmoon10 (Collaborator) commented on Oct 18, 2024

The commit history was still mangled (original branch is at https://github.com/timmoon10/TransformerEngine/tree/eljandoubi/fix_layernorm_fsdp-20241018), so I've manually squashed the commit. However, bugs have crept back in (#1250 (comment), #1250 (comment), #1250 (comment)).

At the moment I think this PR is riskier and more of a hassle than it is worth. If this is important for your use-case, I suggest breaking this up into two more manageable PRs:

  • Update the PyTorch autocast API. This should be a fairly easy and safe change (see the compatibility sketch after this list).
  • Thoughtfully replace view with reshape. The current approach (blindly replacing all views with reshape and changing back where necessary) was dangerous, mostly redundant, and exposed a large surface area for merge conflicts. It would be much better to take the opposite approach and only make the view-to-reshape replacements that you know are safe and effective.
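
A minimal compatibility sketch for the first item, assuming the goal is to support both older and newer PyTorch releases (the helper name is hypothetical, not part of the PR):

import torch

def get_autocast_cuda_dtype() -> torch.dtype:
    # Newer PyTorch exposes a device-generic accessor; older releases
    # only provide the GPU-specific one.
    if hasattr(torch, "get_autocast_dtype"):
        return torch.get_autocast_dtype("cuda")
    return torch.get_autocast_gpu_dtype()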

@timmoon10 closed this on Oct 18, 2024
Successfully merging this pull request may close this issue: FSDP with FP8 is not working