# See LICENSE for license information.

"""LayerNorm API"""
- import os
import warnings
- from typing import Union, Tuple, Optional
+ from typing import Iterable, Optional, Union

import torch
- from torch.nn.parameter import Parameter
- from torch.nn import init

- import transformer_engine_torch as tex
- from ..cpp_extensions import (
-     layernorm_fwd_inf,
- )
- from ..jit import no_torch_dynamo
- from ..utils import cast_if_needed
+ from transformer_engine.pytorch.ops import LayerNorm as _LayerNormOp


__all__ = ["LayerNorm"]


- class _LayerNorm(torch.autograd.Function):
-     """functional LayerNorm"""
-
-     @staticmethod
-     def forward(
-         ctx,
-         inp: torch.Tensor,
-         ln_weight: torch.Tensor,
-         ln_bias: torch.Tensor,
-         eps: float,
-         fwd_ln_sm_margin: int,
-         bwd_ln_sm_margin: int,
-         inf_ln_sm_margin: int,
-         zero_centered_gamma: bool,
-         is_grad_enabled: bool,
-         activation_dtype: torch.dtype,
-     ) -> torch.Tensor:
-         # pylint: disable=missing-function-docstring
-         # Make sure input dimensions are compatible
-         in_features = ln_weight.numel()
-         assert inp.is_cuda, "TransformerEngine needs CUDA."
-         assert inp.shape[-1] == in_features, "LayerNorm not possible"
-         inputmat = inp.view((-1, in_features))
-
-         # Cast for native AMP
-         inputmat = cast_if_needed(inputmat, activation_dtype)
-         ln_weight = cast_if_needed(ln_weight, activation_dtype)
-         ln_bias = cast_if_needed(ln_bias, activation_dtype)
-
-         if is_grad_enabled:
-             ln_out, mu, rsigma = tex.layernorm_fwd(
-                 inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
-             )
-             ctx.save_for_backward(inputmat, ln_weight, mu, rsigma)
-             ctx.inp_shape = inp.shape
-             ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
-             ctx.zero_centered_gamma = zero_centered_gamma
-         else:
-             ln_out, mu, rsigma = (
-                 layernorm_fwd_inf(
-                     inputmat, ln_weight, ln_bias, eps, inf_ln_sm_margin, zero_centered_gamma
-                 ),
-                 None,
-                 None,
-             )
-         return ln_out.view_as(inp)
-
-     @staticmethod
-     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
-         # pylint: disable=missing-function-docstring
-         inputmat, ln_weight, mu, rsigma = ctx.saved_tensors
-         grad_output = grad_output.contiguous()
-         d_ln_out = grad_output.view(inputmat.shape)
-         dxmat, dgamma, dbeta = tex.layernorm_bwd(
-             d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
-         )
-         return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None, None, None
-
+ class LayerNorm(_LayerNormOp):
+     r"""Layer Normalization

- class LayerNorm(torch.nn.Module):
-     r"""
    Applies Layer Normalization over a mini-batch of inputs as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
-         y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta
+         y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta

-     :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
-     size :attr:`hidden_size`
+     :math:`\gamma` and :math:`\beta` are learnable affine transform
+     parameters that match the inner-most dimensions of the input
+     tensor.

    Parameters
    ----------
-     hidden_size : int
-         size of each input sample.
+     normalized_shape: int or iterable of int
+         Inner dimensions of input tensor
    eps : float, default = 1e-5
-         a value added to the denominator of layer normalization for numerical stability.
-     sequence_parallel : bool, default = `False`
-         if set to `True`, uses sequence parallelism.
-     params_dtype : torch.dtype, default = `torch.get_default_dtype()`
-         it controls the type used to allocate the initial parameters. Useful when
-         the model is trained with lower precision and the original FP32 parameters
-         would not fit in GPU memory.
+         A value added to the denominator of layer normalization for
+         numerical stability
+     device: torch.device, default = default CUDA device
+         Tensor device
+     dtype: torch.dtype, default = default dtype
+         Tensor datatype
    zero_centered_gamma : bool, default = 'False'
-         if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
-         the LayerNorm formula changes to
-
-         .. math::
-             y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
-                 (1 + \gamma) + \beta
-     device : Union[torch.device, str], default = "cuda"
-         The device on which the parameters of the model will be allocated. It is the user's
-         responsibility to ensure all parameters are moved to the GPU before running the
-         forward pass.
+         If `True`, the :math:`\gamma` parameter is initialized to zero
+         and the calculation changes to
+
+         .. math::
+             y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta
+
+     sm_margin: int or dict, default = 0
+         Number of SMs to exclude when launching CUDA kernels. This
+         helps overlap with other kernels, e.g. communication kernels.
+         For more fine-grained control, provide a dict with the SM
+         margin at each compute stage ("forward", "backward",
+         "inference").
+
+     Legacy
+     ------
+     sequence_parallel: bool
+         Set a bool attr named `sequence_parallel` in the parameters.
+         This is custom logic for Megatron-LM integration.
+
    """

    def __init__(
        self,
-         hidden_size: int,
+         normalized_shape: Union[Iterable[int], int],
        eps: float = 1e-5,
-         sequence_parallel: bool = False,
-         params_dtype: Optional[torch.dtype] = None,
+         sequence_parallel: Optional[bool] = None,  # legacy
+         params_dtype: Optional[torch.dtype] = None,  # deprecated
        zero_centered_gamma: bool = False,
-         device: Union[torch.device, str] = "cuda",
+         **kwargs,
    ) -> None:
-         super().__init__()
-         params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
-         self.eps = eps
-         self.zero_centered_gamma = zero_centered_gamma
-         self.weight = Parameter(
-             torch.empty(
-                 hidden_size,
-                 device=device,
-                 dtype=params_dtype,
-             )
-         )
-         self.bias = Parameter(
-             torch.empty(
-                 hidden_size,
-                 device=device,
-                 dtype=params_dtype,
-             )
-         )
-         self.sequence_parallel = sequence_parallel
-         self.activation_dtype: Optional[torch.dtype] = None

-         self.reset_parameters(defer_init=device == "meta")
+         # Handle deprecated options
+         if params_dtype is not None:
+             if "dtype" in kwargs:
+                 raise RuntimeError(
+                     "Both `dtype` and `params_dtype` (deprecated) kwargs are provided"
+                 )
+             kwargs["dtype"] = params_dtype
+
+         # Initialize layer norm operation
+         super().__init__(
+             normalized_shape,
+             eps=eps,
+             zero_centered_gamma=zero_centered_gamma,
+             **kwargs,
+         )

-         # These many SMs are subtracted from the total SM count when calling forward
-         # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
-         # kernels from using all SMs in the device. This is useful for cases such as
-         # communication overlap with LN.
-         self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
-         self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
-         self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
+         # Flag for sequence parallelism (custom Megatron-LM integration)
+         self.sequence_parallel: Optional[bool] = sequence_parallel

    def reset_layer_norm_parameters(self) -> None:
        """Init LN params"""
@@ -164,64 +96,62 @@ def reset_layer_norm_parameters(self) -> None:
            DeprecationWarning,
            stacklevel=2,
        )
-         if not self.zero_centered_gamma:
-             init.ones_(self.weight)
-         else:
-             init.zeros_(self.weight)
-         init.zeros_(self.bias)
+         self.reset_parameters()

-     def reset_parameters(self, defer_init=False) -> None:
+     def reset_parameters(self, defer_init: Optional[bool] = None) -> None:
        """Init LayerNorm parameters"""
-         if defer_init:
-             return
-
-         if self.weight.device == torch.device("meta"):
-             self.weight = torch.nn.Parameter(torch.empty_like(self.weight, device="cuda"))
-         setattr(self.weight, "sequence_parallel", self.sequence_parallel)
-         init.constant_(self.weight, float(not self.zero_centered_gamma))
-
-         if self.bias.device == torch.device("meta"):
-             self.bias = torch.nn.Parameter(torch.empty_like(self.bias, device="cuda"))
-         setattr(self.bias, "sequence_parallel", self.sequence_parallel)
-         init.zeros_(self.bias)
-
-     @no_torch_dynamo()
-     def forward(self, inp: torch.Tensor) -> torch.Tensor:
-         # pylint: disable=missing-function-docstring
-
-         # Set the activation type for AMP.
-         # Note: This will soon be deprecated with
-         # https://github.com/NVIDIA/TransformerEngine/pull/1033
-         if torch.is_autocast_enabled():
-             self.activation_dtype = torch.get_autocast_gpu_dtype()
-         elif self.activation_dtype != inp.dtype:
-             dtype = inp.dtype
-             for name, param in self.named_parameters():
-                 if param is not None:
-                     assert dtype == param.dtype, (
-                         "Data types for parameters must match when outside of autocasted region. "
-                         f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
-                     )
-             self.activation_dtype = dtype
-
-         if torch.is_grad_enabled():
-             fwd_fn = _LayerNorm.apply
-             args = []
-         else:
-             fwd_fn = _LayerNorm.forward
-             args = [None]
-
-         args += (
-             inp,
-             self.weight,
-             self.bias,
-             self.eps,
-             self.fwd_ln_sm_margin,
-             self.bwd_ln_sm_margin,
-             self.inf_ln_sm_margin,
-             self.zero_centered_gamma,
-             torch.is_grad_enabled(),
-             self.activation_dtype,
-         )

-         return fwd_fn(*args)
+         # Check whether to defer init (deprecated)
+         if defer_init is not None:
+             warnings.warn(
+                 "defer_init argument to reset_parameters function is deprecated. Set device to"
+                 ' "meta" instead.',
+                 DeprecationWarning,
+                 stacklevel=2,
+             )
+             if defer_init:
+                 return
+
+         # Reset parameters
+         super().reset_parameters()
+
+         # Set flag for sequence parallelism (custom Megatron-LM integration)
+         if getattr(self, "sequence_parallel", None) is not None:
+             self.weight.sequence_parallel = self.sequence_parallel
+             self.bias.sequence_parallel = self.sequence_parallel
+
+     @property
+     def fwd_ln_sm_margin(self) -> int:
+         """Shim for backward compatibility"""
+         warnings.warn("fwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
+         return self._sm_margins["forward"]
+
+     @fwd_ln_sm_margin.setter
+     def fwd_ln_sm_margin(self, val: int) -> None:
+         """Shim for backward compatibility"""
+         warnings.warn("fwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
+         self._sm_margins["forward"] = val
+
+     @property
+     def bwd_ln_sm_margin(self) -> int:
+         """Shim for backward compatibility"""
+         warnings.warn("bwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
+         return self._sm_margins["backward"]
+
+     @bwd_ln_sm_margin.setter
+     def bwd_ln_sm_margin(self, val: int) -> None:
+         """Shim for backward compatibility"""
+         warnings.warn("bwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
+         self._sm_margins["backward"] = val
+
+     @property
+     def inf_ln_sm_margin(self) -> int:
+         """Shim for backward compatibility"""
+         warnings.warn("inf_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
+         return self._sm_margins["inference"]
+
+     @inf_ln_sm_margin.setter
+     def inf_ln_sm_margin(self, val: int) -> None:
+         """Shim for backward compatibility"""
+         warnings.warn("inf_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
+         self._sm_margins["inference"] = val
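
For reference, a minimal usage sketch of the refactored module follows. It assumes the class is re-exported as `transformer_engine.pytorch.LayerNorm` (the usual public import path) and that a CUDA device is available; the `sm_margin` dict keys mirror the docstring above, and the deprecated spellings in the trailing comments are illustrative only.

# Minimal usage sketch (assumes transformer_engine.pytorch re-exports this
# LayerNorm and that a CUDA device is available).
import torch
import transformer_engine.pytorch as te

hidden_size = 1024
layer_norm = te.LayerNorm(
    hidden_size,               # normalized_shape: int or iterable of int
    eps=1e-5,
    zero_centered_gamma=True,  # gamma starts at 0; formula uses (1 + gamma)
    sm_margin={"forward": 4, "backward": 4, "inference": 0},
    device="cuda",
    dtype=torch.bfloat16,
)

x = torch.randn(8, 128, hidden_size, device="cuda", dtype=torch.bfloat16)
y = layer_norm(x)  # output has the same shape and dtype as x

# Deprecated spellings still work but emit DeprecationWarning:
#   te.LayerNorm(hidden_size, params_dtype=torch.bfloat16)  # prefer dtype=
#   layer_norm.fwd_ln_sm_margin = 2                         # prefer sm_margin=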