[PyTorch] Integration test for Megatron-LM #1329
Conversation
Pipeline 20338114
Pipeline 20444324 is green
```python
if requires_grad != x.requires_grad:
    if requires_grad:
        x.requires_grad_()
    else:
        x = x.detach()
```
This fixes a `te.Sequential` bug that was exposed by Mcore. When running in eval mode, we want `x.requires_grad=False` so that the op knows that it doesn't need to prepare for that grad. However, PyTorch sometimes complains if you change a tensor's `requires_grad` from `True` to `False` (i.e. when the tensor is not a leaf in the autograd graph). Detaching the tensor works around this case.
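A minimal standalone sketch of the failure mode described above (plain PyTorch with made-up tensor names, not code from this PR's diff):

```python
import torch

x = torch.ones(4, requires_grad=True)
y = x * 2  # y is a non-leaf tensor in the autograd graph

try:
    y.requires_grad_(False)  # PyTorch refuses to clear requires_grad on a non-leaf
except RuntimeError as err:
    print(err)  # clearing requires_grad is only allowed on leaf tensors

y = y.detach()  # the workaround: the detached tensor has requires_grad=False
assert not y.requires_grad
```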
```
@@ -184,17 +185,21 @@ def op_forward(
    ) -> torch.Tensor:

        # Check tensor dims
        weight = self.weight
        weight_dims = tuple(weight.size())
        input_dims = tuple(input_.size())
```
Apparently `torch.Size` is a subclass of `tuple`, so the `tuple(...)` creation is probably not needed.
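A quick check (not part of the PR diff) illustrating the point:

```python
import torch

size = torch.empty(2, 3).size()
assert isinstance(size, tuple)  # torch.Size subclasses tuple
assert size == (2, 3)           # compares equal to a plain tuple
assert size[1] == 3             # supports indexing like any tuple
```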
```
@@ -165,17 +168,21 @@ def op_forward(
    ) -> torch.Tensor:

        # Check tensor dims
        weight = self.weight
        weight_dims = tuple(weight.size())
```
No need to `tuple`-ize.
* Handle deprecated `hidden_size` arg in norm modules
* Support initializing norm ops on CPU
* Add integration test for Megatron-LM
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Rename Mcore integration test
* Handle case in RMSNorm where hidden dim is not provided

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Description
#1033 broke Megatron-LM's wrappers for the LayerNorm and RMSNorm modules: it renamed the `hidden_size` arg to `normalized_shape` in order to match `torch.nn.LayerNorm`, but Megatron-LM treats `hidden_size` as a kwarg: https://github.com/NVIDIA/Megatron-LM/blob/aded519cfb1de2abf96f36ca059f992294b7876f/megatron/core/extensions/transformer_engine.py#L65.

This PR adds logic to handle the `hidden_size` arg and print a deprecation warning. To help detect these issues in the future, I've also added an integration test that runs Megatron-LM to train a very small GPT model.
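For illustration only, a hypothetical sketch (not the actual Transformer Engine code; `RMSNormShim` and its signature are made up) of how a norm module can accept the deprecated `hidden_size` kwarg while preferring `normalized_shape`:

```python
import warnings
import torch

class RMSNormShim(torch.nn.Module):
    """Toy RMSNorm that maps the legacy `hidden_size` kwarg to `normalized_shape`."""

    def __init__(self, normalized_shape=None, eps=1e-5, **kwargs):
        super().__init__()
        hidden_size = kwargs.pop("hidden_size", None)
        if hidden_size is not None:
            warnings.warn(
                "`hidden_size` is deprecated, use `normalized_shape` instead",
                DeprecationWarning,
            )
            normalized_shape = hidden_size
        if normalized_shape is None:
            raise TypeError("missing required argument: normalized_shape")
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(normalized_shape))

    def forward(self, x):
        # Standard RMSNorm: scale by the reciprocal RMS over the last dim
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight

# Legacy Megatron-LM-style call still works and emits a DeprecationWarning:
layer = RMSNormShim(hidden_size=16)
out = layer(torch.randn(2, 16))
```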
Type of change
Changes
- Handle deprecated `hidden_size` arg in LayerNorm and RMSNorm modules

Checklist: