
Commit 722a89b

timmoon10, pre-commit-ci[bot], and ptrendx authored and committed
[PyTorch] Normalization ops (NVIDIA#1033)
* Add layer norm op
* Add FP8 cast op
* Add tests for linear and layernorm with FP8 output
* RMSNorm op
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Fix linter warnings
* Replace LayerNorm module with LayerNorm op
* Replace RMSNorm module with RMSNorm op
* Add AMP support
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Do not save autograd context if grad mode is disabled (debugging ONNX export tests)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Forward args in pre_forward func to base op class
* Update to use QuantizedTensor class
* Apply suggestions from code review (co-authored by Przemyslaw Tredak)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Review suggestions from @ptrendx: rename "CastFloat8" op to "Quantize", add more fine-grained control for SM margin, add docs for legacy sequence_parallel kwarg
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Fix linter warnings
* Use weight dtype as default compute dtype
* Fix linter warnings
* [pre-commit.ci] auto fixes from pre-commit.com hooks

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <[email protected]>
1 parent 5ac4946 commit 722a89b
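The commit message above introduces new fusible operations under transformer_engine.pytorch.ops. Below is a minimal sketch of how the new LayerNorm op might be used, based on the constructor arguments documented in the layernorm.py diff further down; the tensor shapes, SM-margin values, and the standalone forward/backward usage pattern are illustrative assumptions, not taken from this commit.

# Sketch only: exercises the te.ops.LayerNorm operation added by this commit.
# Shapes and margin values are illustrative; a CUDA device is assumed.
import torch
import transformer_engine.pytorch as te

ln_op = te.ops.LayerNorm(
    1024,                      # normalized_shape: inner dimension(s) of the input
    eps=1e-5,
    zero_centered_gamma=True,  # y = norm(x) * (1 + gamma) + beta
    sm_margin={"forward": 4, "backward": 4, "inference": 0},
)
x = torch.randn(32, 1024, device="cuda", requires_grad=True)
y = ln_op(x)        # forward pass
y.sum().backward()  # backward pass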

File tree

12 files changed: +1416 / -511 lines

tests/pytorch/test_fusible_ops.py

Lines changed: 411 additions & 104 deletions
Large diffs are not rendered by default.

transformer_engine/pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@ def _load_library():
 from transformer_engine.pytorch.distributed import checkpoint
 from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
 from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context
+from transformer_engine.pytorch import ops
 from transformer_engine.pytorch import optimizers

 # Register custom op symbolic ONNX functions
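This one-line change exposes the fusible-ops namespace from the top-level PyTorch package. A small smoke test of what it enables (illustrative only; the attribute access simply mirrors the import added above):

import transformer_engine.pytorch as te

# The ops namespace is now reachable from the top-level package,
# including the LayerNorm op that module/layernorm.py wraps below.
print(te.ops.LayerNorm)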

transformer_engine/pytorch/module/layernorm.py

Lines changed: 112 additions & 182 deletions
@@ -3,158 +3,90 @@
 # See LICENSE for license information.

 """LayerNorm API"""
-import os
 import warnings
-from typing import Union, Tuple, Optional
+from typing import Iterable, Optional, Union

 import torch
-from torch.nn.parameter import Parameter
-from torch.nn import init

-import transformer_engine_torch as tex
-from ..cpp_extensions import (
-    layernorm_fwd_inf,
-)
-from ..jit import no_torch_dynamo
-from ..utils import cast_if_needed
+from transformer_engine.pytorch.ops import LayerNorm as _LayerNormOp

 __all__ = ["LayerNorm"]


-class _LayerNorm(torch.autograd.Function):
-    """functional LayerNorm"""
-
-    @staticmethod
-    def forward(
-        ctx,
-        inp: torch.Tensor,
-        ln_weight: torch.Tensor,
-        ln_bias: torch.Tensor,
-        eps: float,
-        fwd_ln_sm_margin: int,
-        bwd_ln_sm_margin: int,
-        inf_ln_sm_margin: int,
-        zero_centered_gamma: bool,
-        is_grad_enabled: bool,
-        activation_dtype: torch.dtype,
-    ) -> torch.Tensor:
-        # pylint: disable=missing-function-docstring
-        # Make sure input dimensions are compatible
-        in_features = ln_weight.numel()
-        assert inp.is_cuda, "TransformerEngine needs CUDA."
-        assert inp.shape[-1] == in_features, "LayerNorm not possible"
-        inputmat = inp.view((-1, in_features))
-
-        # Cast for native AMP
-        inputmat = cast_if_needed(inputmat, activation_dtype)
-        ln_weight = cast_if_needed(ln_weight, activation_dtype)
-        ln_bias = cast_if_needed(ln_bias, activation_dtype)
-
-        if is_grad_enabled:
-            ln_out, mu, rsigma = tex.layernorm_fwd(
-                inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
-            )
-            ctx.save_for_backward(inputmat, ln_weight, mu, rsigma)
-            ctx.inp_shape = inp.shape
-            ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
-            ctx.zero_centered_gamma = zero_centered_gamma
-        else:
-            ln_out, mu, rsigma = (
-                layernorm_fwd_inf(
-                    inputmat, ln_weight, ln_bias, eps, inf_ln_sm_margin, zero_centered_gamma
-                ),
-                None,
-                None,
-            )
-        return ln_out.view_as(inp)
-
-    @staticmethod
-    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
-        # pylint: disable=missing-function-docstring
-        inputmat, ln_weight, mu, rsigma = ctx.saved_tensors
-        grad_output = grad_output.contiguous()
-        d_ln_out = grad_output.view(inputmat.shape)
-        dxmat, dgamma, dbeta = tex.layernorm_bwd(
-            d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
-        )
-        return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None, None, None
-
+class LayerNorm(_LayerNormOp):
+    r"""Layer Normalization

-class LayerNorm(torch.nn.Module):
-    r"""
     Applies Layer Normalization over a mini-batch of inputs as described in
     the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

     .. math::
-        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta
+        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta

-    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
-    size :attr:`hidden_size`
+    :math:`\gamma` and :math:`\beta` are learnable affine transform
+    parameters that match the inner-most dimensions of the input
+    tensor.

     Parameters
     ----------
-    hidden_size : int
-        size of each input sample.
+    normalized_shape: int or iterable of int
+        Inner dimensions of input tensor
     eps : float, default = 1e-5
-        a value added to the denominator of layer normalization for numerical stability.
-    sequence_parallel : bool, default = `False`
-        if set to `True`, uses sequence parallelism.
-    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
-        it controls the type used to allocate the initial parameters. Useful when
-        the model is trained with lower precision and the original FP32 parameters
-        would not fit in GPU memory.
+        A value added to the denominator of layer normalization for
+        numerical stability
+    device: torch.device, default = default CUDA device
+        Tensor device
+    dtype: torch.dtype, default = default dtype
+        Tensor datatype
     zero_centered_gamma : bool, default = 'False'
-        if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
-        the LayerNorm formula changes to
-
-        .. math::
-            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
-                (1 + \gamma) + \beta
-    device : Union[torch.device, str], default = "cuda"
-        The device on which the parameters of the model will be allocated. It is the user's
-        responsibility to ensure all parameters are moved to the GPU before running the
-        forward pass.
+        If `True`, the :math:`\gamma` parameter is initialized to zero
+        and the calculation changes to
+
+        .. math::
+            y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta
+
+    sm_margin: int or dict, default = 0
+        Number of SMs to exclude when launching CUDA kernels. This
+        helps overlap with other kernels, e.g. communication kernels.
+        For more fine-grained control, provide a dict with the SM
+        margin at each compute stage ("forward", "backward",
+        "inference").
+
+    Legacy
+    ------
+    sequence_parallel: bool
+        Set a bool attr named `sequence_parallel` in the parameters.
+        This is custom logic for Megatron-LM integration.
+
     """

     def __init__(
         self,
-        hidden_size: int,
+        normalized_shape: Union[Iterable[int], int],
         eps: float = 1e-5,
-        sequence_parallel: bool = False,
-        params_dtype: Optional[torch.dtype] = None,
+        sequence_parallel: Optional[bool] = None,  # legacy
+        params_dtype: Optional[torch.dtype] = None,  # deprecated
         zero_centered_gamma: bool = False,
-        device: Union[torch.device, str] = "cuda",
+        **kwargs,
     ) -> None:
-        super().__init__()
-        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
-        self.eps = eps
-        self.zero_centered_gamma = zero_centered_gamma
-        self.weight = Parameter(
-            torch.empty(
-                hidden_size,
-                device=device,
-                dtype=params_dtype,
-            )
-        )
-        self.bias = Parameter(
-            torch.empty(
-                hidden_size,
-                device=device,
-                dtype=params_dtype,
-            )
-        )
-        self.sequence_parallel = sequence_parallel
-        self.activation_dtype: Optional[torch.dtype] = None

-        self.reset_parameters(defer_init=device == "meta")
+        # Handle deprecated options
+        if params_dtype is not None:
+            if "dtype" in kwargs:
+                raise RuntimeError(
+                    "Both `dtype` and `params_dtype` (deprecated) kwargs are provided"
+                )
+            kwargs["dtype"] = params_dtype
+
+        # Initialize layer norm operation
+        super().__init__(
+            normalized_shape,
+            eps=eps,
+            zero_centered_gamma=zero_centered_gamma,
+            **kwargs,
+        )

-        # These many SMs are subtracted from the total SM count when calling forward
-        # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
-        # kernels from using all SMs in the device. This is useful for cases such as
-        # communication overlap with LN.
-        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
-        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
-        self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
+        # Flag for sequence parallelism (custom Megatron-LM integration)
+        self.sequence_parallel: Optional[bool] = sequence_parallel

     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
@@ -164,64 +96,62 @@ def reset_layer_norm_parameters(self) -> None:
             DeprecationWarning,
             stacklevel=2,
         )
-        if not self.zero_centered_gamma:
-            init.ones_(self.weight)
-        else:
-            init.zeros_(self.weight)
-        init.zeros_(self.bias)
+        self.reset_parameters()

-    def reset_parameters(self, defer_init=False) -> None:
+    def reset_parameters(self, defer_init: Optional[bool] = None) -> None:
         """Init LayerNorm parameters"""
-        if defer_init:
-            return
-
-        if self.weight.device == torch.device("meta"):
-            self.weight = torch.nn.Parameter(torch.empty_like(self.weight, device="cuda"))
-        setattr(self.weight, "sequence_parallel", self.sequence_parallel)
-        init.constant_(self.weight, float(not self.zero_centered_gamma))
-
-        if self.bias.device == torch.device("meta"):
-            self.bias = torch.nn.Parameter(torch.empty_like(self.bias, device="cuda"))
-        setattr(self.bias, "sequence_parallel", self.sequence_parallel)
-        init.zeros_(self.bias)
-
-    @no_torch_dynamo()
-    def forward(self, inp: torch.Tensor) -> torch.Tensor:
-        # pylint: disable=missing-function-docstring
-
-        # Set the activation type for AMP.
-        # Note: This will soon be deprecated with
-        # https://github.com/NVIDIA/TransformerEngine/pull/1033
-        if torch.is_autocast_enabled():
-            self.activation_dtype = torch.get_autocast_gpu_dtype()
-        elif self.activation_dtype != inp.dtype:
-            dtype = inp.dtype
-            for name, param in self.named_parameters():
-                if param is not None:
-                    assert dtype == param.dtype, (
-                        "Data types for parameters must match when outside of autocasted region. "
-                        f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
-                    )
-            self.activation_dtype = dtype
-
-        if torch.is_grad_enabled():
-            fwd_fn = _LayerNorm.apply
-            args = []
-        else:
-            fwd_fn = _LayerNorm.forward
-            args = [None]
-
-        args += (
-            inp,
-            self.weight,
-            self.bias,
-            self.eps,
-            self.fwd_ln_sm_margin,
-            self.bwd_ln_sm_margin,
-            self.inf_ln_sm_margin,
-            self.zero_centered_gamma,
-            torch.is_grad_enabled(),
-            self.activation_dtype,
-        )

-        return fwd_fn(*args)
+        # Check whether to defer init (deprecated)
+        if defer_init is not None:
+            warnings.warn(
+                "defer_init argument to reset_parameters function is deprecated. Set device to"
+                ' "meta" instead.',
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            if defer_init:
+                return
+
+        # Reset parameters
+        super().reset_parameters()
+
+        # Set flag for sequence parallelism (custom Megatron-LM integration)
+        if getattr(self, "sequence_parallel", None) is not None:
+            self.weight.sequence_parallel = self.sequence_parallel
+            self.bias.sequence_parallel = self.sequence_parallel
+
+    @property
+    def fwd_ln_sm_margin(self) -> int:
+        """Shim for backward compatibility"""
+        warnings.warn("fwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
+        return self._sm_margins["forward"]
+
+    @fwd_ln_sm_margin.setter
+    def fwd_ln_sm_margin(self, val: int) -> None:
+        """Shim for backward compatibility"""
+        warnings.warn("fwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
+        self._sm_margins["forward"] = val
+
+    @property
+    def bwd_ln_sm_margin(self) -> int:
+        """Shim for backward compatibility"""
+        warnings.warn("bwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
+        return self._sm_margins["backward"]
+
+    @bwd_ln_sm_margin.setter
+    def bwd_ln_sm_margin(self, val: int) -> None:
+        """Shim for backward compatibility"""
+        warnings.warn("bwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
+        self._sm_margins["backward"] = val
+
+    @property
+    def inf_ln_sm_margin(self) -> int:
+        """Shim for backward compatibility"""
+        warnings.warn("inf_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
+        return self._sm_margins["inference"]
+
+    @inf_ln_sm_margin.setter
+    def inf_ln_sm_margin(self, val: int) -> None:
+        """Shim for backward compatibility"""
+        warnings.warn("inf_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
+        self._sm_margins["inference"] = val
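A rough sketch of the backward-compatibility behavior kept by the rewritten module above. The params_dtype-to-dtype mapping, the sequence_parallel parameter attribute, and the deprecated SM-margin shims come from the diff; the concrete sizes, dtype, and the warning-capture scaffolding are illustrative assumptions, and a CUDA device is assumed.

import warnings
import torch
from transformer_engine.pytorch import LayerNorm

# Deprecated `params_dtype` is forwarded to the op's `dtype` kwarg,
# and the legacy `sequence_parallel` flag is still accepted.
ln = LayerNorm(1024, params_dtype=torch.bfloat16, sequence_parallel=True)
ln.reset_parameters()  # also tags params with `sequence_parallel` (Megatron-LM integration)
assert ln.weight.dtype == torch.bfloat16
assert ln.weight.sequence_parallel is True

# The old per-stage SM-margin attributes remain as deprecated shims.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ln.fwd_ln_sm_margin = 8
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

y = ln(torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16))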

0 commit comments
