
[PyTorch] Integration test for Megatron-LM #1329

Merged
merged 7 commits into NVIDIA:main
Nov 21, 2024

Conversation

timmoon10
Collaborator

Description

#1033 broke Megatron-LM's wrappers for the LayerNorm and RMSNorm modules.

To help detect these issues in the future, I've also added an integration test that runs Megatron-LM to train a very small GPT model.
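
For reference, a minimal sketch of what such a smoke test could look like. The script path, flags, and model sizes below are illustrative assumptions, not the actual test added in this PR:

```python
# Hypothetical smoke test: launch Megatron-LM GPT pretraining for a few steps
# and fail if the process exits with an error. Paths, flags, and sizes here
# are assumptions for illustration only.
import subprocess

def test_megatron_gpt_smoke(megatron_path: str = "/path/to/Megatron-LM") -> None:
    cmd = [
        "python", f"{megatron_path}/pretrain_gpt.py",
        "--num-layers", "2",
        "--hidden-size", "128",
        "--num-attention-heads", "8",
        "--seq-length", "128",
        "--max-position-embeddings", "128",
        "--micro-batch-size", "2",
        "--train-iters", "10",
        "--transformer-impl", "transformer_engine",
    ]
    # check=True raises CalledProcessError on a non-zero exit code
    subprocess.run(cmd, check=True)
```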

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor
  • Testing

Changes

  • Handle deprecated hidden_size arg in LayerNorm and RMSNorm modules (see the sketch after this list)
  • Allow LayerNorm and RMSNorm operations to be initialized on CPU
  • Add Megatron-LM integration test
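
As a rough illustration of the first two items (not the actual Transformer Engine implementation), a deprecated keyword can be remapped with a warning, and a device argument can allow CPU initialization. Argument names below are assumptions:

```python
import warnings
import torch

class RMSNorm(torch.nn.Module):
    """Simplified sketch; not the Transformer Engine module."""

    def __init__(self, normalized_shape=None, *, hidden_size=None, eps=1e-5, device=None):
        super().__init__()
        if hidden_size is not None:
            # Deprecated spelling of the argument; keep accepting it but warn
            warnings.warn(
                "`hidden_size` is deprecated, use `normalized_shape` instead",
                DeprecationWarning,
            )
            if normalized_shape is None:
                normalized_shape = hidden_size
        if normalized_shape is None:
            raise TypeError("normalized_shape (or deprecated hidden_size) is required")
        self.eps = eps
        # device may be "cpu", so the module can be constructed before moving to a GPU
        self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return norm * self.weight
```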

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@timmoon10 timmoon10 added the bug Something isn't working label Nov 13, 2024
@timmoon10 timmoon10 requested a review from ksivaman November 13, 2024 00:00
@timmoon10 timmoon10 mentioned this pull request Nov 13, 2024
@timmoon10
Collaborator Author

Pipeline 20338114

@timmoon10
Collaborator Author

timmoon10 commented Nov 15, 2024

Pipeline 20444324 is green

Comment on lines +138 to +142
if requires_grad != x.requires_grad:
    if requires_grad:
        x.requires_grad_()
    else:
        x = x.detach()
Collaborator Author

This fixes a te.Sequential bug that was exposed by Mcore. When running in eval mode, we want x.requires_grad=False so that the op knows that it doesn't need to prepare for that grad. However, PyTorch sometimes complains if you change a tensor's requires_grad from True to False (i.e. when the tensor is not a leaf in the autograd graph). Detaching the tensor works around this case.
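
As a standalone illustration of the PyTorch behavior described above (not code from this PR):

```python
import torch

x = torch.ones(4, requires_grad=True)
y = x * 2  # y is a non-leaf tensor in the autograd graph

try:
    # PyTorch only allows toggling requires_grad on leaf tensors,
    # so this raises a RuntimeError for the non-leaf y
    y.requires_grad_(False)
except RuntimeError as err:
    print("requires_grad_ failed:", err)

# Detaching instead yields a tensor with requires_grad=False without complaint
y = y.detach()
print(y.requires_grad)  # False
```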

transformer_engine/pytorch/module/rmsnorm.py
@@ -184,17 +185,21 @@ def op_forward(
    ) -> torch.Tensor:

        # Check tensor dims
        weight = self.weight
        weight_dims = tuple(weight.size())
        input_dims = tuple(input_.size())
Collaborator

Apparently torch.Size is a subclass of tuple, so the tuple conversion probably isn't needed.
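
A quick check of that claim (not part of the diff):

```python
import torch

s = torch.ones(2, 3).size()
print(type(s))               # <class 'torch.Size'>
print(isinstance(s, tuple))  # True: torch.Size subclasses tuple
print(s == (2, 3))           # True: compares like a plain tuple
```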

@@ -165,17 +168,21 @@ def op_forward(
    ) -> torch.Tensor:

        # Check tensor dims
        weight = self.weight
        weight_dims = tuple(weight.size())
Collaborator

no need to tupleize

@timmoon10 timmoon10 merged commit 6b98768 into NVIDIA:main Nov 21, 2024
14 checks passed
timmoon10 added a commit that referenced this pull request Nov 21, 2024
* Handle deprecated `hidden_size` arg in norm modules

Signed-off-by: Tim Moon <[email protected]>

* Support initializing norm ops on CPU

Signed-off-by: Tim Moon <[email protected]>

* Add integration test for Megatron-LM

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Rename Mcore integration test

Signed-off-by: Tim Moon <[email protected]>

* Handle case in RMSNorm where hidden dim is not provided

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Labels
1.13.0, bug (Something isn't working)