[PyTorch] Normalization ops (#1033)

* Add layer norm op

Signed-off-by: Tim Moon <[email protected]>

* Add FP8 cast op

Signed-off-by: Tim Moon <[email protected]>

* Add tests for linear and layernorm with FP8 output

Signed-off-by: Tim Moon <[email protected]>

* RMSNorm op

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix linter warnings

Signed-off-by: Tim Moon <[email protected]>

* Replace LayerNorm module with LayerNorm op

Signed-off-by: Tim Moon <[email protected]>

* Replace RMSNorm module with RMSNorm op

Signed-off-by: Tim Moon <[email protected]>

* Add AMP support

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Do not save autograd context if grad mode is disabled

Debugging ONNX export tests.

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Forward args in pre_forward func to base op class

Signed-off-by: Tim Moon <[email protected]>

* Update to use QuantizedTensor class

Signed-off-by: Tim Moon <[email protected]>

* Apply suggestions from code review

Co-authored-by: Przemyslaw Tredak <[email protected]>
Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Review suggestions from @ptrendx

Rename "CastFloat8" op to "Quantize". Add more fine-grained control for SM margin. Add docs for legacy sequence_parallel kwarg.

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix linter warnings

Signed-off-by: Tim Moon <[email protected]>

* Use weight dtype as default compute dtype

Signed-off-by: Tim Moon <[email protected]>

* Fix linter warnings

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <[email protected]>
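As a usage sketch (not code from this commit), the per-stage SM-margin control and AMP support called out above would look roughly like the following, assuming the constructor documented in the layernorm.py diff below; the sizes and margin values are illustrative:

import torch
import transformer_engine.pytorch as te

# Per-stage SM margins ("forward"/"backward"/"inference") reserve SMs so the
# norm kernels can overlap with, e.g., communication kernels; a plain int
# applies the same margin to every stage.
layer_norm = te.LayerNorm(
    1024,                      # normalized_shape: inner dimension(s) of the input
    eps=1e-5,
    zero_centered_gamma=True,  # gamma starts at 0; formula uses (1 + gamma)
    sm_margin={"forward": 4, "backward": 4, "inference": 0},
)

x = torch.randn(32, 1024, device="cuda")
y = layer_norm(x)              # FP32 compute

# AMP support added in this PR: under autocast the op is expected to compute
# in the autocast dtype.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y_bf16 = layer_norm(x)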
3 people authored Nov 5, 2024
1 parent f20d3dd commit 77c37d4
Showing 12 changed files with 1,416 additions and 511 deletions.
515 changes: 411 additions & 104 deletions tests/pytorch/test_fusible_ops.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions transformer_engine/pytorch/__init__.py
@@ -82,6 +82,7 @@ def _load_library():
from transformer_engine.pytorch.distributed import checkpoint
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context
from transformer_engine.pytorch import ops
from transformer_engine.pytorch import optimizers

# Register custom op symbolic ONNX functions
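The added import re-exports the fusible-ops package at the top level. A minimal sketch of what that enables (the positional `normalized_shape` argument for the RMSNorm op is an assumption, mirrored from the LayerNorm op used in the layernorm.py diff below):

import transformer_engine.pytorch as te

# The ops submodule now ships with the top-level package.
ln_op = te.ops.LayerNorm(768)   # op that module.LayerNorm now wraps
rms_op = te.ops.RMSNorm(768)    # RMSNorm op added in this PR (per the commit message)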
294 changes: 112 additions & 182 deletions transformer_engine/pytorch/module/layernorm.py
@@ -3,158 +3,90 @@
# See LICENSE for license information.

"""LayerNorm API"""
import os
import warnings
from typing import Union, Tuple, Optional
from typing import Iterable, Optional, Union

import torch
from torch.nn.parameter import Parameter
from torch.nn import init

import transformer_engine_torch as tex
from ..cpp_extensions import (
layernorm_fwd_inf,
)
from ..jit import no_torch_dynamo
from ..utils import cast_if_needed
from transformer_engine.pytorch.ops import LayerNorm as _LayerNormOp

__all__ = ["LayerNorm"]


class _LayerNorm(torch.autograd.Function):
"""functional LayerNorm"""

@staticmethod
def forward(
ctx,
inp: torch.Tensor,
ln_weight: torch.Tensor,
ln_bias: torch.Tensor,
eps: float,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
inf_ln_sm_margin: int,
zero_centered_gamma: bool,
is_grad_enabled: bool,
activation_dtype: torch.dtype,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.is_cuda, "TransformerEngine needs CUDA."
assert inp.shape[-1] == in_features, "LayerNorm not possible"
inputmat = inp.view((-1, in_features))

# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
ln_weight = cast_if_needed(ln_weight, activation_dtype)
ln_bias = cast_if_needed(ln_bias, activation_dtype)

if is_grad_enabled:
ln_out, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
ctx.save_for_backward(inputmat, ln_weight, mu, rsigma)
ctx.inp_shape = inp.shape
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
else:
ln_out, mu, rsigma = (
layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps, inf_ln_sm_margin, zero_centered_gamma
),
None,
None,
)
return ln_out.view_as(inp)

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
inputmat, ln_weight, mu, rsigma = ctx.saved_tensors
grad_output = grad_output.contiguous()
d_ln_out = grad_output.view(inputmat.shape)
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None, None, None

class LayerNorm(_LayerNormOp):
r"""Layer Normalization
class LayerNorm(torch.nn.Module):
r"""
Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta
:math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
size :attr:`hidden_size`
:math:`\gamma` and :math:`\beta` are learnable affine transform
parameters that match the inner-most dimensions of the input
tensor.
Parameters
----------
hidden_size : int
size of each input sample.
normalized_shape: int or iterable of int
Inner dimensions of input tensor
eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability.
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
params_dtype : torch.dtype, default = `torch.get_default_dtype()`
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
A value added to the denominator of layer normalization for
numerical stability
device: torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
Tensor datatype
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
If `True`, the :math:`\gamma` parameter is initialized to zero
and the calculation changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta
sm_margin: int or dict, default = 0
Number of SMs to exclude when launching CUDA kernels. This
helps overlap with other kernels, e.g. communication kernels.
For more fine-grained control, provide a dict with the SM
margin at each compute stage ("forward", "backward",
"inference").
Legacy
------
sequence_parallel: bool
Set a bool attr named `sequence_parallel` in the parameters.
This is custom logic for Megatron-LM integration.
"""

def __init__(
self,
hidden_size: int,
normalized_shape: Union[Iterable[int], int],
eps: float = 1e-5,
sequence_parallel: bool = False,
params_dtype: Optional[torch.dtype] = None,
sequence_parallel: Optional[bool] = None, # legacy
params_dtype: Optional[torch.dtype] = None, # deprecated
zero_centered_gamma: bool = False,
device: Union[torch.device, str] = "cuda",
**kwargs,
) -> None:
super().__init__()
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.eps = eps
self.zero_centered_gamma = zero_centered_gamma
self.weight = Parameter(
torch.empty(
hidden_size,
device=device,
dtype=params_dtype,
)
)
self.bias = Parameter(
torch.empty(
hidden_size,
device=device,
dtype=params_dtype,
)
)
self.sequence_parallel = sequence_parallel
self.activation_dtype: Optional[torch.dtype] = None

self.reset_parameters(defer_init=device == "meta")
# Handle deprecated options
if params_dtype is not None:
if "dtype" in kwargs:
raise RuntimeError(
"Both `dtype` and `params_dtype` (deprecated) kwargs are provided"
)
kwargs["dtype"] = params_dtype

# Initialize layer norm operation
super().__init__(
normalized_shape,
eps=eps,
zero_centered_gamma=zero_centered_gamma,
**kwargs,
)

# These many SMs are subtracted from the total SM count when calling forward
# and backward LayerNorm C APIs. These envvars can be used to prevent the LN
# kernels from using all SMs in the device. This is useful for cases such as
# communication overlap with LN.
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
# Flag for sequence parallelism (custom Megatron-LM integration)
self.sequence_parallel: Optional[bool] = sequence_parallel

def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
@@ -164,64 +96,62 @@ def reset_layer_norm_parameters(self) -> None:
DeprecationWarning,
stacklevel=2,
)
if not self.zero_centered_gamma:
init.ones_(self.weight)
else:
init.zeros_(self.weight)
init.zeros_(self.bias)
self.reset_parameters()

def reset_parameters(self, defer_init=False) -> None:
def reset_parameters(self, defer_init: Optional[bool] = None) -> None:
"""Init LayerNorm parameters"""
if defer_init:
return

if self.weight.device == torch.device("meta"):
self.weight = torch.nn.Parameter(torch.empty_like(self.weight, device="cuda"))
setattr(self.weight, "sequence_parallel", self.sequence_parallel)
init.constant_(self.weight, float(not self.zero_centered_gamma))

if self.bias.device == torch.device("meta"):
self.bias = torch.nn.Parameter(torch.empty_like(self.bias, device="cuda"))
setattr(self.bias, "sequence_parallel", self.sequence_parallel)
init.zeros_(self.bias)

@no_torch_dynamo()
def forward(self, inp: torch.Tensor) -> torch.Tensor:
# pylint: disable=missing-function-docstring

# Set the activation type for AMP.
# Note: This will soon be deprecated with
# https://github.com/NVIDIA/TransformerEngine/pull/1033
if torch.is_autocast_enabled():
self.activation_dtype = torch.get_autocast_gpu_dtype()
elif self.activation_dtype != inp.dtype:
dtype = inp.dtype
for name, param in self.named_parameters():
if param is not None:
assert dtype == param.dtype, (
"Data types for parameters must match when outside of autocasted region. "
f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
)
self.activation_dtype = dtype

if torch.is_grad_enabled():
fwd_fn = _LayerNorm.apply
args = []
else:
fwd_fn = _LayerNorm.forward
args = [None]

args += (
inp,
self.weight,
self.bias,
self.eps,
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.inf_ln_sm_margin,
self.zero_centered_gamma,
torch.is_grad_enabled(),
self.activation_dtype,
)

return fwd_fn(*args)
# Check whether to defer init (deprecated)
if defer_init is not None:
warnings.warn(
"defer_init argument to reset_parameters function is deprecated. Set device to"
' "meta" instead.',
DeprecationWarning,
stacklevel=2,
)
if defer_init:
return

# Reset parameters
super().reset_parameters()

# Set flag for sequence parallelism (custom Megatron-LM integration)
if getattr(self, "sequence_parallel", None) is not None:
self.weight.sequence_parallel = self.sequence_parallel
self.bias.sequence_parallel = self.sequence_parallel

@property
def fwd_ln_sm_margin(self) -> int:
"""Shim for backward compatibility"""
warnings.warn("fwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
return self._sm_margins["forward"]

@fwd_ln_sm_margin.setter
def fwd_ln_sm_margin(self, val: int) -> None:
"""Shim for backward compatibility"""
warnings.warn("fwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
self._sm_margins["forward"] = val

@property
def bwd_ln_sm_margin(self) -> int:
"""Shim for backward compatibility"""
warnings.warn("bwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
return self._sm_margins["backward"]

@bwd_ln_sm_margin.setter
def bwd_ln_sm_margin(self, val: int) -> None:
"""Shim for backward compatibility"""
warnings.warn("bwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
self._sm_margins["backward"] = val

@property
def inf_ln_sm_margin(self) -> int:
"""Shim for backward compatibility"""
warnings.warn("inf_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
return self._sm_margins["inference"]

@inf_ln_sm_margin.setter
def inf_ln_sm_margin(self, val: int) -> None:
"""Shim for backward compatibility"""
warnings.warn("inf_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
self._sm_margins["inference"] = val