diff --git a/qa/L1_pytorch_mcore_integration/test.sh b/qa/L1_pytorch_mcore_integration/test.sh
new file mode 100644
index 0000000000..01c9e14eb1
--- /dev/null
+++ b/qa/L1_pytorch_mcore_integration/test.sh
@@ -0,0 +1,58 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+set -e
+
+# Paths
+: ${TE_PATH:=/opt/transformerengine}
+: ${MCORE_PATH:=${TE_PATH}/qa/L1_pytorch_mcore_integration/Megatron-LM}
+
+# Download Megatron-LM if needed
+if [ ! -d "${MCORE_PATH}" ]; then
+    pushd $(dirname ${MCORE_PATH})
+    git clone -b core_r0.9.0 https://github.com/NVIDIA/Megatron-LM.git Megatron-LM
+    popd
+fi
+
+# Megatron-LM invocation
+COMMAND="
+NVTE_TORCH_COMPILE=0
+NVTE_ALLOW_NONDETERMINISTIC_ALGO=0
+NVTE_FLASH_ATTN=1
+NVTE_FWD_LAYERNORM_SM_MARGIN=0
+NVTE_BWD_LAYERNORM_SM_MARGIN=0
+CUDA_DEVICE_MAX_CONNECTIONS=1
+NVTE_BIAS_GELU_NVFUSION=0
+NVTE_BIAS_DROPOUT_FUSION=0
+
+python
+-m torch.distributed.launch
+--use_env
+--nnodes=1
+--nproc_per_node=1
+
+${MCORE_PATH}/pretrain_gpt.py
+--tensor-model-parallel-size 1
+--pipeline-model-parallel-size 1
+--use-cpu-initialization
+--num-layers 2
+--hidden-size 128
+--num-attention-heads 8
+--seq-length 128
+--max-position-embeddings 2048
+--micro-batch-size 1
+--global-batch-size 8
+--train-iters 10
+--eval-iters 10
+--lr 1e-4
+--mock-data
+--vocab-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-vocab.json
+--merge-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-merges.txt
+--transformer-impl transformer_engine
+--fp8-format hybrid
+"
+COMMAND=$(echo "${COMMAND}" | tr '\n' ' ')
+
+# Launch Megatron-LM
+bash -c "${COMMAND}"
diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py
index 32142cf48c..b42079d299 100644
--- a/transformer_engine/pytorch/module/layernorm.py
+++ b/transformer_engine/pytorch/module/layernorm.py
@@ -61,15 +61,32 @@ class LayerNorm(_LayerNormOp):

     def __init__(
         self,
-        normalized_shape: Union[Iterable[int], int],
+        normalized_shape: Union[Iterable[int], int, None] = None,
         eps: float = 1e-5,
         sequence_parallel: Optional[bool] = None,  # legacy
         params_dtype: Optional[torch.dtype] = None,  # deprecated
         zero_centered_gamma: bool = False,
+        hidden_size: Optional[int] = None,  # deprecated
         **kwargs,
     ) -> None:

         # Handle deprecated options
+        if normalized_shape is None:
+            if hidden_size is None:
+                raise RuntimeError(
+                    "Neither `normalized_shape` nor `hidden_size` (deprecated) args are provided"
+                )
+            warnings.warn(
+                "`hidden_size` arg has been renamed to `normalized_shape` "
+                "for compatibility with `torch.nn.LayerNorm`.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            normalized_shape = hidden_size
+        elif hidden_size is not None:
+            raise RuntimeError(
+                "Both `normalized_shape` and `hidden_size` (deprecated) args are provided"
+            )
         if params_dtype is not None:
             if "dtype" in kwargs:
                 raise RuntimeError(
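The `LayerNorm` hunk above renames the first constructor argument from `hidden_size` to `normalized_shape` for parity with `torch.nn.LayerNorm`, keeping the old keyword as a deprecated alias. A minimal usage sketch (not part of the patch), assuming the module is exported as `transformer_engine.pytorch.LayerNorm` and run on a CUDA machine:

    import warnings
    import transformer_engine.pytorch as te

    # New-style call, matching the torch.nn.LayerNorm signature
    ln = te.LayerNorm(normalized_shape=1024)   # or positionally: te.LayerNorm(1024)

    # Old-style call still works, but now emits a DeprecationWarning
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        ln_legacy = te.LayerNorm(hidden_size=1024)
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)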
diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py
index f3651ecc19..bd7db1f775 100644
--- a/transformer_engine/pytorch/module/rmsnorm.py
+++ b/transformer_engine/pytorch/module/rmsnorm.py
@@ -65,15 +65,32 @@ class RMSNorm(_RMSNormOp):

     def __init__(
         self,
-        normalized_shape: Union[Iterable[int], int],
+        normalized_shape: Union[Iterable[int], int, None] = None,
         eps: float = 1e-5,
         sequence_parallel: Optional[bool] = None,  # legacy
         params_dtype: Optional[torch.dtype] = None,  # deprecated
         zero_centered_gamma: bool = False,
+        hidden_size: Optional[int] = None,  # deprecated
         **kwargs,
     ) -> None:

         # Handle deprecated options
+        if normalized_shape is None:
+            if hidden_size is None:
+                raise RuntimeError(
+                    "Neither `normalized_shape` nor `hidden_size` (deprecated) args are provided"
+                )
+            warnings.warn(
+                "`hidden_size` arg has been renamed to `normalized_shape` "
+                "for compatibility with `torch.nn.LayerNorm`.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            normalized_shape = hidden_size
+        elif hidden_size is not None:
+            raise RuntimeError(
+                "Both `normalized_shape` and `hidden_size` (deprecated) args are provided"
+            )
         if params_dtype is not None:
             if "dtype" in kwargs:
                 raise RuntimeError(
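`RMSNorm` gets the identical treatment, including the two guard rails: passing neither argument or both arguments raises a `RuntimeError`. A hedged sketch of those error paths (`pytest` is only used here as a convenient assertion harness; it is not part of the patch):

    import pytest
    import transformer_engine.pytorch as te

    # Neither `normalized_shape` nor `hidden_size`: rejected
    with pytest.raises(RuntimeError):
        te.RMSNorm()

    # Both at once: also rejected, to avoid silently ignoring one of them
    with pytest.raises(RuntimeError):
        te.RMSNorm(normalized_shape=1024, hidden_size=1024)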
diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py
index 99c9c493db..710f838581 100644
--- a/transformer_engine/pytorch/ops/basic/layer_norm.py
+++ b/transformer_engine/pytorch/ops/basic/layer_norm.py
@@ -20,7 +20,12 @@
 )
 from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
 from ...tensor import Float8Tensor, QuantizedTensor
-from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
+from ...utils import (
+    canonicalize_device,
+    canonicalize_dtype,
+    clear_tensor_data,
+    devices_match,
+)
 from ..op import BasicOperation, OperationContext
 from .._common import maybe_autocast_dtype, reshape

@@ -84,28 +89,23 @@
             normalized_shape = (normalized_shape,)
         else:
             normalized_shape = tuple(normalized_shape)
-        self._shape: tuple[int, ...] = normalized_shape

         # Parameter device
         defer_param_init = False
         device = canonicalize_device(device)
         if device.type == "meta":
             defer_param_init = True
-            device = canonicalize_device(None)
-        if device.type != "cuda":
-            raise ValueError(f"Only CUDA devices are supported (got {device})")
-        self.device: torch.device = device

         # Initialize parameters if needed
         dtype = canonicalize_dtype(dtype)
         weight = torch.empty(
-            self._shape,
-            device="meta",
+            normalized_shape,
+            device=device,
             dtype=dtype,
         )
         bias = torch.empty(
-            self._shape,
-            device="meta",
+            normalized_shape,
+            device=device,
             dtype=dtype,
         )
         weight = torch.nn.Parameter(weight)
@@ -143,17 +143,18 @@ def getenv(name: str) -> int:
     def reset_parameters(self) -> None:
         """Initialize parameter buffers and values"""

-        # Make sure parameter is initialized
+        # Parameter device
         weight = self.weight
         bias = self.bias
-        if weight.device.type != "cuda":
-            weight = torch.empty_like(weight, device=self.device)
-        else:
-            weight = weight.to(device=self.device)
-        if bias.device.type != "cuda":
-            bias = torch.empty_like(bias, device=self.device)
-        else:
-            bias = bias.to(device=self.device)
+        device = weight.device
+        if device.type == "meta":
+            device = canonicalize_device(None)
+
+        # Initialize param buffers
+        if not devices_match(weight.device, device):
+            weight = torch.empty_like(weight, device=device)
+        if not devices_match(bias.device, device):
+            bias = torch.empty_like(bias, device=device)

         # Initialize values
         if self.zero_centered_gamma:
@@ -184,17 +185,21 @@ def op_forward(
     ) -> torch.Tensor:

         # Check tensor dims
+        weight = self.weight
+        weight_dims = tuple(weight.size())
         input_dims = tuple(input_.size())
-        if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape:
+        if len(input_dims) < len(weight_dims) or input_dims[-len(weight_dims) :] != weight_dims:
             raise ValueError(
                 f"Input tensor (shape={input_dims}) "
-                f"and weight tensor (shape={self._shape}) are not compatible"
+                f"and weight tensor (shape={weight_dims}) are not compatible"
             )

         # Check input tensors
-        inner_dim = math.prod(self._shape)
-        device = self.device
-        dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype)
+        inner_dim = math.prod(weight_dims)
+        device = weight.device
+        if device.type != "cuda":
+            device = canonicalize_device(None)
+        dtype = maybe_autocast_dtype(default_dtype=weight.dtype)
         x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype)
         w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
         b = reshape(self.bias, (inner_dim,), device=device, dtype=dtype)
@@ -266,6 +271,7 @@
         # Save state for backward pass
         if requires_grad:
             ctx.save_for_backward(x, means, rstdevs)
+            ctx.device = device
             ctx.dtype = dtype
             ctx.has_prev_op = prev_op is not None

@@ -282,9 +288,12 @@ def op_backward(
         # Saved tensors from forward pass
         x, means, rstdevs = ctx.saved_tensors

+        # Tensor dims
+        weight_dims = self.weight.size()
+        inner_dim = math.prod(weight_dims)
+
         # Check input tensors
-        inner_dim = x.size(-1)
-        device = self.device
+        device = ctx.device
         dtype = ctx.dtype
         dy = reshape(grad_output, x.size(), device=device, dtype=dtype)
         w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
@@ -312,6 +321,6 @@

         # Reshape results
         grad_input = reshape(dx, grad_output.size())
-        grad_weight = reshape(dw, self._shape)
-        grad_bias = reshape(db, self._shape)
+        grad_weight = reshape(dw, weight_dims)
+        grad_bias = reshape(db, weight_dims)
         return grad_input, (grad_weight, grad_bias)
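The operation-level change drops the cached `self.device` and `self._shape` attributes and lets the parameters themselves carry shape and device. In particular, `device="meta"` construction now defers allocation until `reset_parameters()` materializes the buffers on the default CUDA device. A sketch of that flow; the import path `transformer_engine.pytorch.ops` is an assumption, not something the patch itself shows:

    import torch
    from transformer_engine.pytorch.ops import LayerNorm  # assumed export path

    op = LayerNorm(1024, device="meta")   # no device memory allocated yet
    assert op.weight.device.type == "meta"

    op.reset_parameters()                 # materializes on the default CUDA device
    assert op.weight.device.type == "cuda"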
diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py
index 4f0e2ddc22..84f05ce713 100644
--- a/transformer_engine/pytorch/ops/basic/rmsnorm.py
+++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py
@@ -20,7 +20,12 @@
 )
 from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
 from ...tensor import Float8Tensor, QuantizedTensor
-from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
+from ...utils import (
+    canonicalize_device,
+    canonicalize_dtype,
+    clear_tensor_data,
+    devices_match,
+)
 from ..op import BasicOperation, OperationContext
 from .._common import maybe_autocast_dtype, reshape

@@ -83,22 +88,17 @@
             normalized_shape = (normalized_shape,)
         else:
             normalized_shape = tuple(normalized_shape)
-        self._shape: tuple[int, ...] = normalized_shape

         # Parameter device
         defer_param_init = False
         device = canonicalize_device(device)
         if device.type == "meta":
             defer_param_init = True
-            device = canonicalize_device(None)
-        if device.type != "cuda":
-            raise ValueError(f"Only CUDA devices are supported (got {device})")
-        self.device: torch.device = device

         # Initialize parameters if needed
         weight = torch.empty(
-            self._shape,
-            device="meta",
+            normalized_shape,
+            device=device,
             dtype=canonicalize_dtype(dtype),
         )
         weight = torch.nn.Parameter(weight)
@@ -133,12 +133,15 @@ def getenv(name: str) -> int:
     def reset_parameters(self) -> None:
         """Initialize parameter buffers and values"""

-        # Make sure parameter is initialized
+        # Parameter device
         weight = self.weight
-        if weight.device.type != "cuda":
-            weight = torch.empty_like(weight, device=self.device)
-        else:
-            weight = weight.to(device=self.device)
+        device = weight.device
+        if device.type == "meta":
+            device = canonicalize_device(None)
+
+        # Initialize param buffers
+        if not devices_match(weight.device, device):
+            weight = torch.empty_like(weight, device=device)

         # Initialize values
         if self.zero_centered_gamma:
@@ -165,17 +168,21 @@
     ) -> torch.Tensor:

         # Check tensor dims
+        weight = self.weight
+        weight_dims = tuple(weight.size())
         input_dims = tuple(input_.size())
-        if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape:
+        if len(input_dims) < len(weight_dims) or input_dims[-len(weight_dims) :] != weight_dims:
             raise ValueError(
                 f"Input tensor (shape={input_dims}) "
-                f"and weight tensor (shape={self._shape}) are not compatible"
+                f"and weight tensor (shape={weight_dims}) are not compatible"
             )

         # Check input tensors
-        inner_dim = math.prod(self._shape)
-        device = self.device
-        dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype)
+        inner_dim = math.prod(weight_dims)
+        device = weight.device
+        if device.type != "cuda":
+            device = canonicalize_device(None)
+        dtype = maybe_autocast_dtype(default_dtype=weight.dtype)
         x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype)
         w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
         if isinstance(x, QuantizedTensor):
@@ -241,6 +248,7 @@
         # Save state for backward pass
         if requires_grad:
             ctx.save_for_backward(x, rstdevs)
+            ctx.device = device
             ctx.dtype = dtype
             ctx.has_prev_op = prev_op is not None

@@ -257,9 +265,12 @@
         # Saved tensors from forward pass
         x, rstdevs = ctx.saved_tensors

+        # Tensor dims
+        weight_dims = self.weight.size()
+        inner_dim = math.prod(weight_dims)
+
         # Check input tensors
-        inner_dim = x.size(-1)
-        device = self.device
+        device = ctx.device
         dtype = ctx.dtype
         dy = reshape(grad_output, x.size(), device=device, dtype=dtype)
         w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
@@ -285,5 +296,5 @@

         # Reshape results
         grad_input = reshape(dx, grad_output.size())
-        grad_weight = reshape(dw, self._shape)
+        grad_weight = reshape(dw, weight_dims)
         return grad_input, (grad_weight,)
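The `RMSNorm` op mirrors the same refactor. Note the switch from `weight.device.type != "cuda"` checks to `devices_match`: plain `==` on `torch.device` is too strict, since `torch.device("cuda")` and `torch.device("cuda:0")` compare unequal even when they name the same physical device. Below is a minimal sketch of the comparison such a helper needs to perform; the real `devices_match` lives in `transformer_engine/pytorch/utils.py` and may differ in detail:

    import torch

    def devices_match_sketch(a: torch.device, b: torch.device) -> bool:
        # Hypothetical stand-in for transformer_engine's devices_match helper
        if a.type != b.type:
            return False
        if a.type == "cuda":
            # Treat a missing index as the current CUDA device
            ai = a.index if a.index is not None else torch.cuda.current_device()
            bi = b.index if b.index is not None else torch.cuda.current_device()
            return ai == bi
        return a == b

    assert torch.device("cuda") != torch.device("cuda:0")   # why plain == is not enough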
diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py
index 6fcb435e5c..8b2a04cff8 100644
--- a/transformer_engine/pytorch/ops/fuser.py
+++ b/transformer_engine/pytorch/ops/fuser.py
@@ -135,7 +135,11 @@ def forward(
             requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs)
             for idx in basic_op_idxs:
                 basic_op_ctxs[idx].requires_grad = requires_grad
-            x.requires_grad_(requires_grad=requires_grad)
+            if requires_grad != x.requires_grad:
+                if requires_grad:
+                    x.requires_grad_()
+                else:
+                    x = x.detach()

             # Forward op
             extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
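The fuser fix addresses a real autograd constraint: `Tensor.requires_grad_(False)` is only legal on leaf tensors, so when the fused input `x` is a non-leaf activation the code must `detach()` it instead of toggling the flag in place. A standalone reproduction of the failure mode the new branch avoids, using plain PyTorch:

    import torch

    leaf = torch.ones(2, requires_grad=True)
    non_leaf = leaf * 2                  # has a grad_fn, so it is not a leaf

    try:
        non_leaf.requires_grad_(False)   # RuntimeError: only allowed on leaves
    except RuntimeError:
        non_leaf = non_leaf.detach()     # what the patched fuser does instead

    assert not non_leaf.requires_grad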