
Fix incorrect dtype in LayerNormLinear #483

Merged — 2 commits merged into NVIDIA:main on Oct 20, 2023
Conversation

timmoon10 (Collaborator)

We've encountered a runtime error in LayerNormLinear when training LLaMA: RMSNorm writes its output to a buffer with an invalid dtype. In particular, the case where the RMSNorm output is returned in bf16 is not handled properly.

Note that LayerNormMLP handles this correctly:

ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype
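The rule in that one-liner can be sketched as a standalone helper (a minimal sketch with string stand-ins for torch dtypes so it runs without PyTorch; the function name is illustrative, not part of Transformer Engine's API):

```python
def select_ln_out_dtype(fp8: bool, return_layernorm_output: bool,
                        input_dtype: str) -> str:
    """Pick the buffer dtype for the normalization output."""
    # FP8 path, output consumed only internally: the buffer holds
    # quantized values, so raw bytes (uint8) are fine.
    if fp8 and not return_layernorm_output:
        return "uint8"
    # Otherwise the output is handed back to the caller and must match
    # the input dtype (e.g. "bfloat16" when training in bf16).
    return input_dtype

# The bf16 case from the bug report: the output is returned to the
# caller, so the buffer must stay bfloat16 rather than uint8.
print(select_ln_out_dtype(fp8=True, return_layernorm_output=True,
                          input_dtype="bfloat16"))
```

The key point is the second branch: even with FP8 enabled, asking for the layernorm output back forces the buffer to the input dtype, which is the case LayerNormLinear was mishandling.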

@timmoon10 timmoon10 added the bug Something isn't working label Oct 19, 2023
@timmoon10 timmoon10 requested review from ptrendx and ksivaman October 19, 2023 18:04
timmoon10 (Collaborator, Author)

/te-ci pytorch

ksivaman (Member) left a comment


LGTM

timmoon10 added a commit to timmoon10/TransformerEngine that referenced this pull request Oct 19, 2023
ptrendx (Member) commented Oct 19, 2023

I opened #485 to fix the errors in the fused attention test.

ksivaman (Member)

/te-ci pytorch

@ksivaman ksivaman merged commit 1afb625 into NVIDIA:main Oct 20, 2023
denera pushed a commit to denera/TransformerEngine that referenced this pull request Oct 23, 2023
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
ptrendx pushed a commit that referenced this pull request Oct 23, 2023
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
mingxu1067 pushed a commit to mingxu1067/TransformerEngine that referenced this pull request Nov 3, 2023
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Ming Huang <[email protected]>
cyanguwa pushed a commit to cyanguwa/TransformerEngine that referenced this pull request Nov 13, 2023
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
@timmoon10 timmoon10 deleted the debug-llama branch November 15, 2023 21:25
Labels: bug (Something isn't working)
3 participants