From 77c37d49f3f2354fb9edb13ae72fa01dd39ae4b5 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Tue, 5 Nov 2024 13:16:54 -0800 Subject: [PATCH] [PyTorch] Normalization ops (#1033) * Add layer norm op Signed-off-by: Tim Moon * Add FP8 cast op Signed-off-by: Tim Moon * Add tests for linear and layernorm with FP8 output Signed-off-by: Tim Moon * RMSNorm op Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix linter warnings Signed-off-by: Tim Moon * Replace LayerNorm module with LayerNorm op Signed-off-by: Tim Moon * Replace RMSNorm module with RMSNorm op Signed-off-by: Tim Moon * Add AMP support Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Do not save autograd context if grad mode is disabled Debugging ONNX export tests. Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Forward args in pre_forward func to base op class Signed-off-by: Tim Moon * Update to use QuantizedTensor class Signed-off-by: Tim Moon * Apply suggestions from code review Co-authored-by: Przemyslaw Tredak Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Review suggestions from @ptrendx Rename "CastFloat8" op to "Quantize". Add more fine-grained control for SM margin. Add docs for legacy sequence_parallel kwarg. Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix linter warnings Signed-off-by: Tim Moon * Use weight dtype as default compute dtype Signed-off-by: Tim Moon * Fix linter warnings Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak --- tests/pytorch/test_fusible_ops.py | 515 ++++++++++++++---- transformer_engine/pytorch/__init__.py | 1 + .../pytorch/module/layernorm.py | 294 ++++------ transformer_engine/pytorch/module/rmsnorm.py | 297 ++++------ transformer_engine/pytorch/ops/__init__.py | 12 +- transformer_engine/pytorch/ops/_common.py | 22 +- .../pytorch/ops/basic/__init__.py | 3 + .../pytorch/ops/basic/layer_norm.py | 317 +++++++++++ .../pytorch/ops/basic/quantize.py | 93 ++++ .../pytorch/ops/basic/rmsnorm.py | 289 ++++++++++ transformer_engine/pytorch/ops/fuser.py | 82 +-- transformer_engine/pytorch/ops/op.py | 2 +- 12 files changed, 1416 insertions(+), 511 deletions(-) create mode 100644 transformer_engine/pytorch/ops/basic/layer_norm.py create mode 100644 transformer_engine/pytorch/ops/basic/quantize.py create mode 100644 transformer_engine/pytorch/ops/basic/rmsnorm.py diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 1d91683ae4..29829ac4ac 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -10,6 +10,7 @@ import torch import transformer_engine +import transformer_engine.common.recipe import transformer_engine.pytorch as te from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.fp8 import FP8GlobalStateManager @@ 
-633,28 +634,78 @@ def test_bias( torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(db_test, b_ref.grad, **tols) - @pytest.mark.parametrize("weight_shape", ((48, 16), (3, 5))) - @pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (2, 2, 4, -1))) - @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("fp8_compute", (False, True)) - @pytest.mark.parametrize("fp8_input", (False, True)) - @pytest.mark.parametrize("fp8_weight", (False, True)) - @pytest.mark.parametrize("fp8_grad_output", (False, True)) - @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) - def test_basic_linear( + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) + @pytest.mark.parametrize("cast_forward", (False, True)) + @pytest.mark.parametrize("cast_backward", (False, True)) + def test_cast_float8( self, *, - weight_shape: tuple[int, int], - in_shape: Iterable[int], - dtype: torch.dtype, + in_shape: Iterable[int] = (1,), + dtype: torch.dtype = torch.bfloat16, device: torch.device = "cuda", - fp8_compute: bool, - fp8_input: bool, - fp8_weight: bool, - fp8_grad_output: bool, - accumulate_into_main_grad: bool, + cast_forward: bool, + cast_backward: bool, ) -> None: - """GEMM""" + """FP8 cast""" + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + test_is_fp8=True, + ) + x_test = x_test.from_float8().requires_grad_() + dy_ref, dy_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + test_is_fp8=True, + ) + dy_test = dy_test.from_float8() + + # Plain PyTorch implementation + y_ref = x_ref + dx_ref = dy_ref + + # Implementation with fusible operation + op = te_ops.Quantize(forward=cast_forward, backward=cast_backward) + recipe = transformer_engine.common.recipe.DelayedScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + with te.fp8_autocast(fp8_recipe=recipe): + y_test = op(x_test) + y_test.backward(dy_test) + + # Check tensor types + assert is_float8_tensor(y_test) == cast_forward + assert is_float8_tensor(x_test.grad) == cast_backward + + # Check values + tols = dict(rtol=0, atol=0) + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, dx_ref, **tols) + + def _test_basic_linear( + self, + *, + weight_shape: tuple[int, int] = (32, 32), + in_shape: Iterable[int] = (32, -1), + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + fp8_compute: bool = False, + fp8_input: bool = False, + fp8_weight: bool = False, + fp8_output: bool = False, + fp8_grad_output: bool = False, + fp8_grad_input: bool = False, + accumulate_into_main_grad: bool = False, + ) -> None: + """Helper function for tests with GEMM""" # Make input and weight shapes consistent out_features, in_features = weight_shape @@ -662,7 +713,7 @@ def test_basic_linear( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - if fp8_compute or fp8_input or fp8_weight or fp8_grad_output: + if fp8_compute or fp8_input or fp8_weight or fp8_output or fp8_grad_output: if not fp8_available: pytest.skip(reason_for_no_fp8) if torch.device(device).type != "cuda": @@ -674,6 +725,10 @@ def test_basic_linear( or out_features % 16 != 0 ): pytest.skip("FP8 GEMMs require dims that are divisible by 16") + if fp8_output and not fp8_compute: + 
pytest.skip("FP8 output is only supported with FP8 GEMMs") + if fp8_grad_input and not fp8_compute: + pytest.skip("FP8 grad input is only supported with FP8 GEMMs") # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -713,15 +768,23 @@ def test_basic_linear( op.weight.copy_(w_test) del w_test op.weight.main_grad = torch.full_like(op.weight, 0.5, dtype=torch.float32) - with te.fp8_autocast(enabled=fp8_compute): - y_test = op(x_test) + forward = te_ops.Sequential( + te_ops.Quantize(forward=fp8_input, backward=fp8_grad_input), + op, + te_ops.Quantize(forward=fp8_output, backward=fp8_grad_output), + ) + recipe = transformer_engine.common.recipe.DelayedScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + with te.fp8_autocast(enabled=fp8_compute, fp8_recipe=recipe): + y_test = forward(x_test) y_test.backward(dy_test) # Expected numerical error tols = dtype_tols(dtype) if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute: + if fp8_compute or fp8_output or fp8_grad_input: tols = dtype_tols( op.weight._fp8_dtype if is_float8_tensor(op.weight) else tex.DType.kFloat8E4M3 ) @@ -750,6 +813,57 @@ def test_basic_linear( ) torch.testing.assert_close(dw_test, w_ref.grad, **tols) + @pytest.mark.parametrize("weight_shape", ((48, 16), (3, 5))) + @pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (2, 2, 4, -1))) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("fp8_compute", (False, True)) + @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) + def test_basic_linear( + self, + *, + weight_shape: tuple[int, int], + in_shape: Iterable[int], + dtype: torch.dtype, + fp8_compute: bool, + accumulate_into_main_grad: bool, + ) -> None: + """GEMM""" + self._test_basic_linear( + weight_shape=weight_shape, + in_shape=in_shape, + dtype=dtype, + fp8_compute=fp8_compute, + accumulate_into_main_grad=accumulate_into_main_grad, + ) + + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) + @pytest.mark.parametrize("fp8_compute", (False, True)) + @pytest.mark.parametrize("fp8_input", (False, True)) + @pytest.mark.parametrize("fp8_weight", (False, True)) + @pytest.mark.parametrize("fp8_output", (False, True)) + @pytest.mark.parametrize("fp8_grad_output", (False, True)) + @pytest.mark.parametrize("fp8_grad_input", (False, True)) + def test_basic_linear_fp8( + self, + *, + fp8_compute: bool, + fp8_input: bool, + fp8_weight: bool, + fp8_output: bool, + fp8_grad_output: bool, + fp8_grad_input: bool, + ) -> None: + """GEMM with FP8 inputs and outputs""" + self._test_basic_linear( + dtype=torch.bfloat16, + fp8_compute=fp8_compute, + fp8_input=fp8_input, + fp8_weight=fp8_weight, + fp8_output=fp8_output, + fp8_grad_output=fp8_grad_output, + fp8_grad_input=fp8_grad_input, + ) + @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("fp8_compute", (False, True)) @pytest.mark.parametrize("fp8_weight", (False, True)) @@ -856,6 +970,271 @@ def test_linear( db_test = op.bias.grad.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(db_test, b_ref.grad, **tols) + @pytest.mark.parametrize("weight_shape", ((19,), (16, 4))) + @pytest.mark.parametrize("in_shape", ((-1,), (6, 8, -1))) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("zero_centered_gamma", (False, True)) + @pytest.mark.parametrize("fp8_input", (False, True)) + @pytest.mark.parametrize("fp8_output", (False, True)) + def test_layer_norm( + self, + *, + weight_shape: Iterable[int], + in_shape: 
Iterable[int], + dtype: torch.dtype, + device: torch.device = "cuda", + eps: float = 0.3, + zero_centered_gamma: bool, + fp8_input: bool, + fp8_output: bool, + ) -> None: + """Layer norm""" + + # Make input and weight shapes consistent + in_shape = list(in_shape)[:-1] + list(weight_shape) + + # Skip invalid configurations + if fp8_input or fp8_output: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8_input, + ) + w_ref, w_test = make_reference_and_test_tensors( + weight_shape, + test_dtype=dtype, + test_device=device, + ) + b_ref, b_test = make_reference_and_test_tensors( + weight_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.layer_norm( + x_ref, + weight_shape, + weight=(w_ref + 1 if zero_centered_gamma else w_ref), + bias=b_ref, + eps=eps, + ) + y_ref.backward(dy_ref) + + # Implementation with fusible operation + op = te_ops.LayerNorm( + weight_shape, + eps=eps, + device=device, + dtype=dtype, + zero_centered_gamma=zero_centered_gamma, + ) + with torch.no_grad(): + op.weight.copy_(w_test) + op.bias.copy_(b_test) + del w_test + del b_test + forward = te_ops.Sequential( + op, + te_ops.Quantize(forward=fp8_output, backward=False), + ) + with te.fp8_autocast(enabled=fp8_output): + y_test = forward(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if fp8_output: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu") + db_test = op.bias.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) + torch.testing.assert_close(db_test, b_ref.grad, **tols) + + def test_layer_norm_autocast( + self, + *, + weight_shape: Iterable[int] = (32,), + in_shape: Iterable[int] = (32,), + dtype: torch.dtype = torch.float16, + autocast_dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + eps: float = 0.3, + ) -> None: + """Layer norm with PyTorch autocast""" + + # Make input and weight shapes consistent + in_shape = list(in_shape)[:-1] + list(weight_shape) + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=autocast_dtype, + test_device=device, + ) + w_ref, w_test = make_reference_and_test_tensors( + weight_shape, + test_dtype=dtype, + test_device=device, + ) + b_ref, b_test = make_reference_and_test_tensors( + weight_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + in_shape, + test_dtype=autocast_dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.layer_norm( + x_ref, + weight_shape, + weight=w_ref, + bias=b_ref, + eps=eps, + ) + y_ref.backward(dy_ref) + + # Implementation with fusible operation + op = te_ops.LayerNorm( + weight_shape, + eps=eps, + device=device, + dtype=dtype, + ) + with 
torch.no_grad(): + op.weight.copy_(w_test) + op.bias.copy_(b_test) + del w_test + del b_test + with torch.autocast(device, dtype=autocast_dtype): + y_test = op(x_test) + y_test.backward(dy_test) + + # Check results + assert y_test.dtype == autocast_dtype + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu") + db_test = op.bias.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **dtype_tols(autocast_dtype)) + torch.testing.assert_close(dx_test, x_ref.grad, **dtype_tols(autocast_dtype)) + torch.testing.assert_close(dw_test, w_ref.grad, **dtype_tols(dtype)) + torch.testing.assert_close(db_test, b_ref.grad, **dtype_tols(dtype)) + + @pytest.mark.parametrize("weight_shape", ((19,), (16, 4))) + @pytest.mark.parametrize("in_shape", ((-1,), (6, 8, -1))) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("zero_centered_gamma", (False, True)) + @pytest.mark.parametrize("fp8_input", (False, True)) + @pytest.mark.parametrize("fp8_output", (False, True)) + def test_rmsnorm( + self, + *, + weight_shape: Iterable[int], + in_shape: Iterable[int], + dtype: torch.dtype, + device: torch.device = "cuda", + eps: float = 0.3, + zero_centered_gamma: bool, + fp8_input: bool, + fp8_output: bool, + ) -> None: + """Layer norm""" + + # Make input and weight shapes consistent + in_shape = list(in_shape)[:-1] + list(weight_shape) + + # Skip invalid configurations + if fp8_input or fp8_output: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8_input, + ) + w_ref, w_test = make_reference_and_test_tensors( + weight_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + inner_dims = tuple(range(len(in_shape) - len(weight_shape), len(in_shape))) + var_ref = x_ref.square().sum(dim=inner_dims, keepdim=True) / math.prod(weight_shape) + if zero_centered_gamma: + y_ref = x_ref / torch.sqrt(eps + var_ref) * (1 + w_ref) + else: + y_ref = x_ref / torch.sqrt(eps + var_ref) * w_ref + y_ref.backward(dy_ref) + + # Implementation with fusible operation + op = te_ops.RMSNorm( + weight_shape, + eps=eps, + device=device, + dtype=dtype, + zero_centered_gamma=zero_centered_gamma, + ) + with torch.no_grad(): + op.weight.copy_(w_test) + del w_test + forward = te_ops.Sequential( + op, + te_ops.Quantize(forward=fp8_output, backward=False), + ) + with te.fp8_autocast(enabled=fp8_output): + y_test = forward(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if fp8_output: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) + @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) @pytest.mark.parametrize("fp8", (False, 
True)) @@ -867,6 +1246,11 @@ def test_add_in_place( device: torch.device, fp8: bool, ) -> None: + """Add two tensors + + Join in compute graph. + + """ # Skip invalid configurations if fp8 and not fp8_available: @@ -927,6 +1311,11 @@ def test_make_extra_output( device: torch.device, fp8: bool, ) -> None: + """Output tensor twice + + Split in compute graph. + + """ # Skip invalid configurations if fp8 and not fp8_available: @@ -1106,88 +1495,6 @@ def test_forward_linear_bias_activation( db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(db_test, b_ref.grad, **tols) - @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) - def test_fp8_linear( - self, - *, - in_shape: Iterable[int] = (16, 16), - dtype: torch.dtype = torch.bfloat16, - device: torch.device = "cuda", - ) -> None: - """Adjacent linear ops with FP8 enabled""" - - # Make input and weight shapes consistent - in_shape = tuple(in_shape) - weight_shape = (in_shape[-1], in_shape[-1]) - - # Random data - x_ref, x_test = make_reference_and_test_tensors( - in_shape, - test_dtype=dtype, - test_device=device, - test_is_fp8=True, - ) - w0_ref, w0_test = make_reference_and_test_tensors( - weight_shape, - test_dtype=dtype, - test_device=device, - test_is_fp8=True, - ) - w1_ref, w1_test = make_reference_and_test_tensors( - weight_shape, - test_dtype=dtype, - test_device=device, - test_is_fp8=True, - ) - dy_ref, dy_test = make_reference_and_test_tensors( - in_shape, - test_dtype=dtype, - test_device=device, - requires_grad=False, - ) - - # Plain PyTorch implementation - y_ref = torch.nn.functional.linear(x_ref, w0_ref) - y_ref = torch.nn.functional.linear(y_ref, w1_ref) - y_ref.backward(dy_ref) - - # Implementation with fusible operations - with te.fp8_model_init(enabled=True): - model = te_ops.Sequential( - te_ops.BasicLinear( - in_shape[-1], - in_shape[-1], - device=device, - dtype=dtype, - ), - te_ops.BasicLinear( - in_shape[-1], - in_shape[-1], - device=device, - dtype=dtype, - ), - ) - with torch.no_grad(): - model[0].weight.copy_(w0_test) - model[1].weight.copy_(w1_test) - del w0_test, w1_test - with te.fp8_autocast(enabled=True): - y_test = model(x_test) - y_test.backward(dy_test) - - # Expected numerical error - tols = dtype_tols(model[0].weight._fp8_dtype) - - # Check results - y_test = y_test.to(dtype=torch.float64, device="cpu") - dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") - dw0_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu") - dw1_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(y_test, y_ref, **tols) - torch.testing.assert_close(dx_test, x_ref.grad, **tols) - torch.testing.assert_close(dw0_test, w0_ref.grad, **tols) - torch.testing.assert_close(dw1_test, w1_ref.grad, **tols) - @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("fp8_compute", (False, True)) diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index c4097333d3..781f9d42fd 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -82,6 +82,7 @@ def _load_library(): from transformer_engine.pytorch.distributed import checkpoint from transformer_engine.pytorch.distributed import CudaRNGStatesTracker from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context +from transformer_engine.pytorch import ops from transformer_engine.pytorch import optimizers # Register custom op symbolic 
ONNX functions diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index 0c439ac417..32142cf48c 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -3,158 +3,90 @@ # See LICENSE for license information. """LayerNorm API""" -import os import warnings -from typing import Union, Tuple, Optional +from typing import Iterable, Optional, Union import torch -from torch.nn.parameter import Parameter -from torch.nn import init -import transformer_engine_torch as tex -from ..cpp_extensions import ( - layernorm_fwd_inf, -) -from ..jit import no_torch_dynamo -from ..utils import cast_if_needed +from transformer_engine.pytorch.ops import LayerNorm as _LayerNormOp __all__ = ["LayerNorm"] -class _LayerNorm(torch.autograd.Function): - """functional LayerNorm""" - - @staticmethod - def forward( - ctx, - inp: torch.Tensor, - ln_weight: torch.Tensor, - ln_bias: torch.Tensor, - eps: float, - fwd_ln_sm_margin: int, - bwd_ln_sm_margin: int, - inf_ln_sm_margin: int, - zero_centered_gamma: bool, - is_grad_enabled: bool, - activation_dtype: torch.dtype, - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring - # Make sure input dimensions are compatible - in_features = ln_weight.numel() - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert inp.shape[-1] == in_features, "LayerNorm not possible" - inputmat = inp.view((-1, in_features)) - - # Cast for native AMP - inputmat = cast_if_needed(inputmat, activation_dtype) - ln_weight = cast_if_needed(ln_weight, activation_dtype) - ln_bias = cast_if_needed(ln_bias, activation_dtype) - - if is_grad_enabled: - ln_out, mu, rsigma = tex.layernorm_fwd( - inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma - ) - ctx.save_for_backward(inputmat, ln_weight, mu, rsigma) - ctx.inp_shape = inp.shape - ctx.bwd_ln_sm_margin = bwd_ln_sm_margin - ctx.zero_centered_gamma = zero_centered_gamma - else: - ln_out, mu, rsigma = ( - layernorm_fwd_inf( - inputmat, ln_weight, ln_bias, eps, inf_ln_sm_margin, zero_centered_gamma - ), - None, - None, - ) - return ln_out.view_as(inp) - - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: - # pylint: disable=missing-function-docstring - inputmat, ln_weight, mu, rsigma = ctx.saved_tensors - grad_output = grad_output.contiguous() - d_ln_out = grad_output.view(inputmat.shape) - dxmat, dgamma, dbeta = tex.layernorm_bwd( - d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma - ) - return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None, None, None - +class LayerNorm(_LayerNormOp): + r"""Layer Normalization -class LayerNorm(torch.nn.Module): - r""" Applies Layer Normalization over a mini-batch of inputs as described in the paper `Layer Normalization `__ .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta + y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta - :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of - size :attr:`hidden_size` + :math:`\gamma` and :math:`\beta` are learnable affine transform + parameters that match the inner-most dimensions of the input + tensor. Parameters ---------- - hidden_size : int - size of each input sample. 
+ normalized_shape: int or iterable of int + Inner dimensions of input tensor eps : float, default = 1e-5 - a value added to the denominator of layer normalization for numerical stability. - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - params_dtype : torch.dtype, default = `torch.get_default_dtype()` - it controls the type used to allocate the initial parameters. Useful when - the model is trained with lower precision and the original FP32 parameters - would not fit in GPU memory. + A value added to the denominator of layer normalization for + numerical stability + device: torch.device, default = default CUDA device + Tensor device + dtype: torch.dtype, default = default dtype + Tensor datatype zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in LayerNorm is initialized to 0 and - the LayerNorm formula changes to - - .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * - (1 + \gamma) + \beta - device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will be allocated. It is the user's - responsibility to ensure all parameters are moved to the GPU before running the - forward pass. + If `True`, the :math:`\gamma` parameter is initialized to zero + and the calculation changes to + + .. math:: + y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta + + sm_margin: int or dict, default = 0 + Number of SMs to exclude when launching CUDA kernels. This + helps overlap with other kernels, e.g. communication kernels. + For more fine-grained control, provide a dict with the SM + margin at each compute stage ("forward", "backward", + "inference"). + + Legacy + ------ + sequence_parallel: bool + Set a bool attr named `sequence_parallel` in the parameters. + This is custom logic for Megatron-LM integration. + """ def __init__( self, - hidden_size: int, + normalized_shape: Union[Iterable[int], int], eps: float = 1e-5, - sequence_parallel: bool = False, - params_dtype: Optional[torch.dtype] = None, + sequence_parallel: Optional[bool] = None, # legacy + params_dtype: Optional[torch.dtype] = None, # deprecated zero_centered_gamma: bool = False, - device: Union[torch.device, str] = "cuda", + **kwargs, ) -> None: - super().__init__() - params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype - self.eps = eps - self.zero_centered_gamma = zero_centered_gamma - self.weight = Parameter( - torch.empty( - hidden_size, - device=device, - dtype=params_dtype, - ) - ) - self.bias = Parameter( - torch.empty( - hidden_size, - device=device, - dtype=params_dtype, - ) - ) - self.sequence_parallel = sequence_parallel - self.activation_dtype: Optional[torch.dtype] = None - self.reset_parameters(defer_init=device == "meta") + # Handle deprecated options + if params_dtype is not None: + if "dtype" in kwargs: + raise RuntimeError( + "Both `dtype` and `params_dtype` (deprecated) kwargs are provided" + ) + kwargs["dtype"] = params_dtype + + # Initialize layer norm operation + super().__init__( + normalized_shape, + eps=eps, + zero_centered_gamma=zero_centered_gamma, + **kwargs, + ) - # These many SMs are subtracted from the total SM count when calling forward - # and backward LayerNorm C APIs. These envvars can be used to prevent the LN - # kernels from using all SMs in the device. This is useful for cases such as - # communication overlap with LN. 
- self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) - self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) - self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) + # Flag for sequence parallelism (custom Megatron-LM integration) + self.sequence_parallel: Optional[bool] = sequence_parallel def reset_layer_norm_parameters(self) -> None: """Init LN params""" @@ -164,64 +96,62 @@ def reset_layer_norm_parameters(self) -> None: DeprecationWarning, stacklevel=2, ) - if not self.zero_centered_gamma: - init.ones_(self.weight) - else: - init.zeros_(self.weight) - init.zeros_(self.bias) + self.reset_parameters() - def reset_parameters(self, defer_init=False) -> None: + def reset_parameters(self, defer_init: Optional[bool] = None) -> None: """Init LayerNorm parameters""" - if defer_init: - return - - if self.weight.device == torch.device("meta"): - self.weight = torch.nn.Parameter(torch.empty_like(self.weight, device="cuda")) - setattr(self.weight, "sequence_parallel", self.sequence_parallel) - init.constant_(self.weight, float(not self.zero_centered_gamma)) - - if self.bias.device == torch.device("meta"): - self.bias = torch.nn.Parameter(torch.empty_like(self.bias, device="cuda")) - setattr(self.bias, "sequence_parallel", self.sequence_parallel) - init.zeros_(self.bias) - - @no_torch_dynamo() - def forward(self, inp: torch.Tensor) -> torch.Tensor: - # pylint: disable=missing-function-docstring - - # Set the activation type for AMP. - # Note: This will soon be deprecated with - # https://github.com/NVIDIA/TransformerEngine/pull/1033 - if torch.is_autocast_enabled(): - self.activation_dtype = torch.get_autocast_gpu_dtype() - elif self.activation_dtype != inp.dtype: - dtype = inp.dtype - for name, param in self.named_parameters(): - if param is not None: - assert dtype == param.dtype, ( - "Data types for parameters must match when outside of autocasted region. " - f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" - ) - self.activation_dtype = dtype - - if torch.is_grad_enabled(): - fwd_fn = _LayerNorm.apply - args = [] - else: - fwd_fn = _LayerNorm.forward - args = [None] - - args += ( - inp, - self.weight, - self.bias, - self.eps, - self.fwd_ln_sm_margin, - self.bwd_ln_sm_margin, - self.inf_ln_sm_margin, - self.zero_centered_gamma, - torch.is_grad_enabled(), - self.activation_dtype, - ) - return fwd_fn(*args) + # Check whether to defer init (deprecated) + if defer_init is not None: + warnings.warn( + "defer_init argument to reset_parameters function is deprecated. 
Set device to" + ' "meta" instead.', + DeprecationWarning, + stacklevel=2, + ) + if defer_init: + return + + # Reset parameters + super().reset_parameters() + + # Set flag for sequence parallelism (custom Megatron-LM integration) + if getattr(self, "sequence_parallel", None) is not None: + self.weight.sequence_parallel = self.sequence_parallel + self.bias.sequence_parallel = self.sequence_parallel + + @property + def fwd_ln_sm_margin(self) -> int: + """Shim for backward compatibility""" + warnings.warn("fwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + return self._sm_margins["forward"] + + @fwd_ln_sm_margin.setter + def fwd_ln_sm_margin(self, val: int) -> None: + """Shim for backward compatibility""" + warnings.warn("fwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + self._sm_margins["forward"] = val + + @property + def bwd_ln_sm_margin(self) -> int: + """Shim for backward compatibility""" + warnings.warn("bwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + return self._sm_margins["backward"] + + @bwd_ln_sm_margin.setter + def bwd_ln_sm_margin(self, val: int) -> None: + """Shim for backward compatibility""" + warnings.warn("bwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + self._sm_margins["backward"] = val + + @property + def inf_ln_sm_margin(self) -> int: + """Shim for backward compatibility""" + warnings.warn("inf_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + return self._sm_margins["inference"] + + @inf_ln_sm_margin.setter + def inf_ln_sm_margin(self, val: int) -> None: + """Shim for backward compatibility""" + warnings.warn("inf_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + self._sm_margins["inference"] = val diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index fc6ec5746f..f3651ecc19 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -3,221 +3,158 @@ # See LICENSE for license information. """RMSNorm API""" -import os import warnings -from typing import Union, Tuple, Optional +from typing import Iterable, Optional, Union import torch -from torch.nn.parameter import Parameter -from torch.nn import init - -from .. import cpp_extensions as tex -from ..jit import no_torch_dynamo -from ..utils import cast_if_needed +from transformer_engine.pytorch.ops import RMSNorm as _RMSNormOp __all__ = ["RMSNorm"] -class _RMSNorm(torch.autograd.Function): - """functional RMSNorm""" - - @staticmethod - def forward( - ctx, - inp: torch.Tensor, - rmsnorm_weight: torch.Tensor, - eps: float, - fwd_rmsnorm_sm_margin: int, - bwd_rmsnorm_sm_margin: int, - inf_rmsnorm_sm_margin: int, - zero_centered_gamma: bool, - is_grad_enabled: bool, - activation_dtype: torch.dtype, - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring - # Make sure input dimensions are compatible - in_features = rmsnorm_weight.numel() - assert inp.is_cuda, "TransformerEngine needs CUDA." 
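
For orientation, a minimal usage sketch of the rewritten module-level LayerNorm wrapper shown above. This is illustrative only: the sizes are arbitrary, it assumes the usual top-level `te.LayerNorm` export and a CUDA device, and `params_dtype` is still accepted but deprecated in favor of `dtype`.

import torch
import transformer_engine.pytorch as te

# Illustrative sizes; extra kwargs (dtype, device, sm_margin) are forwarded to the underlying op.
ln = te.LayerNorm(1024, eps=1e-5, zero_centered_gamma=True, dtype=torch.float32)
x = torch.randn(16, 1024, device="cuda", requires_grad=True)
y = ln(x)  # computes in float32, matching the parameter dtype

# With AMP, the compute dtype follows the autocast dtype rather than the parameter dtype
# (see maybe_autocast_dtype in ops/_common.py).
with torch.autocast("cuda", dtype=torch.bfloat16):
    y_amp = ln(x)  # bfloat16 output
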
- assert inp.shape[-1] == in_features, "RMSNorm not possible" - inputmat = inp.view((-1, in_features)) - - # Cast for native AMP - inputmat = cast_if_needed(inputmat, activation_dtype) - rmsnorm_weight = cast_if_needed(rmsnorm_weight, activation_dtype) - - if is_grad_enabled: - rmsnorm_out, rsigma = tex.rmsnorm_fwd( - inputmat, rmsnorm_weight, eps, fwd_rmsnorm_sm_margin, zero_centered_gamma - ) - ctx.save_for_backward(inputmat, rmsnorm_weight, rsigma) - ctx.inp_shape = inp.shape - ctx.bwd_rmsnorm_sm_margin = bwd_rmsnorm_sm_margin - ctx.zero_centered_gamma = zero_centered_gamma - else: - rmsnorm_out = tex.rmsnorm_fwd_inf( - inputmat, rmsnorm_weight, eps, inf_rmsnorm_sm_margin, zero_centered_gamma - ) - return rmsnorm_out.view_as(inp) - - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: - # pylint: disable=missing-function-docstring - inputmat, rmsnorm_weight, rsigma = ctx.saved_tensors - grad_output = grad_output.contiguous() - d_rmsnorm_out = grad_output.view(inputmat.shape) - dxmat, dgamma = tex.rmsnorm_bwd( - d_rmsnorm_out, - inputmat, - rsigma, - rmsnorm_weight, - ctx.bwd_rmsnorm_sm_margin, - ctx.zero_centered_gamma, - ) - return ( - dxmat.view(ctx.inp_shape), - dgamma, - None, - None, - None, - None, - None, - None, - None, - ) - +class RMSNorm(_RMSNormOp): + r"""Root Mean Square Layer Normalization -class RMSNorm(torch.nn.Module): - r""" - Applies Root Mean Square Layer Normalization over a mini-batch of inputs as described in - the paper `Root Mean Square Layer Normalization `__ + Applies Root Mean Square Layer Normalization over a mini-batch of + inputs as described in the paper + `Root Mean Square Layer Normalization `__ .. math:: - y = \frac{x}{RMS_\varepsilon(x)} * \gamma + y = \frac{x}{\text{RMS}_\varepsilon(x)} * \gamma where .. math:: - RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^nx_i^2 + \varepsilon} + \text{RMS}_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^n x_i^2 + \varepsilon} - :math:`\gamma` is a learnable affine transform parameter of size :attr:`hidden_size` + :math:`\gamma` is a learnable affine transform parameter that + matches the inner-most dimensions of the input tensor. Parameters ---------- - hidden_size : int - size of each input sample. + normalized_shape: int or iterable of int + Inner dimensions of input tensor eps : float, default = 1e-5 - a value added to the denominator of layer normalization for numerical stability. - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - params_dtype : torch.dtype, default = `torch.get_default_dtype()` - it controls the type used to allocate the initial parameters. Useful when - the model is trained with lower precision and the original FP32 parameters - would not fit in GPU memory. + A value added to the denominator for numerical stability + device: torch.device, default = default CUDA device + Tensor device + dtype: torch.dtype, default = default dtype + Tensor datatype zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in RMSNorm is initialized to 0 and - the RMSNorm formula changes to - - .. math:: - y = \frac{x}{RMS_\varepsilon(x)} * (1 + \gamma) - device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will be allocated. It is the user's - responsibility to ensure all parameters are moved to the GPU before running the - forward pass. + If `True`, the :math:`\gamma` parameter is initialized to zero + and the calculation changes to + + .. 
math:: + y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + + sm_margin: int, default = 0 + Number of SMs to exclude when launching CUDA kernels. This + helps overlap with other kernels, e.g. communication kernels. + For more fine-grained control, provide a dict with the SM + margin at each compute stage ("forward", "backward", + "inference"). + + Legacy + ------ + sequence_parallel: bool + Set a bool attr named `sequence_parallel` in the parameters. + This is custom logic for Megatron-LM integration. + """ def __init__( self, - hidden_size: int, + normalized_shape: Union[Iterable[int], int], eps: float = 1e-5, - sequence_parallel: bool = False, - params_dtype: Optional[torch.dtype] = None, + sequence_parallel: Optional[bool] = None, # legacy + params_dtype: Optional[torch.dtype] = None, # deprecated zero_centered_gamma: bool = False, - device: Union[torch.device, str] = "cuda", + **kwargs, ) -> None: - super().__init__() - params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype - self.eps = eps - self.zero_centered_gamma = zero_centered_gamma - self.weight = Parameter( - torch.empty( - hidden_size, - device=device, - dtype=params_dtype, - ) - ) - self.sequence_parallel = sequence_parallel - self.activation_dtype: Optional[torch.dtype] = None - self.reset_parameters(defer_init=device == "meta") + # Handle deprecated options + if params_dtype is not None: + if "dtype" in kwargs: + raise RuntimeError( + "Both `dtype` and `params_dtype` (deprecated) kwargs are provided" + ) + kwargs["dtype"] = params_dtype + + # Initialize RMSNorm operation + super().__init__( + normalized_shape, + eps=eps, + zero_centered_gamma=zero_centered_gamma, + **kwargs, + ) - # These many SMs are subtracted from the total SM count when calling forward - # and backward RMSNorm C APIs. These envvars can be used to prevent the LN - # kernels from using all SMs in the device. This is useful for cases such as - # communication overlap with RMSNorm. - self.fwd_rmsnorm_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) - self.bwd_rmsnorm_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) - self.inf_rmsnorm_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) + # Flag for sequence parallelism (custom Megatron-LM integration) + self.sequence_parallel: Optional[bool] = sequence_parallel def reset_rms_norm_parameters(self) -> None: - """Init RMSNorm params""" + """Deprecated""" warnings.warn( "This method is deprecated and will be removed in an upcoming release. " "Update your code to use RMSNorm.reset_parameters() instead.", DeprecationWarning, stacklevel=2, ) - if not self.zero_centered_gamma: - init.ones_(self.weight) - else: - init.zeros_(self.weight) - - def reset_parameters(self, defer_init=False) -> None: - """Reset RMSNorm parameters""" - if defer_init: - return - - if self.weight.device == torch.device("meta"): - self.weight = torch.nn.Parameter(torch.empty_like(self.weight, device="cuda")) - init.constant_(self.weight, float(not self.zero_centered_gamma)) - setattr(self.weight, "sequence_parallel", self.sequence_parallel) - - @no_torch_dynamo() - def forward(self, inp: torch.Tensor) -> torch.Tensor: - # pylint: disable=missing-function-docstring - - # Set the activation type for AMP. 
- # Note: This will soon be deprecated with - # https://github.com/NVIDIA/TransformerEngine/pull/1033 - if torch.is_autocast_enabled(): - self.activation_dtype = torch.get_autocast_gpu_dtype() - elif self.activation_dtype != inp.dtype: - dtype = inp.dtype - for name, param in self.named_parameters(): - if param is not None: - assert dtype == param.dtype, ( - "Data types for parameters must match when outside of autocasted region. " - f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" - ) - self.activation_dtype = dtype - - if torch.is_grad_enabled(): - fwd_fn = _RMSNorm.apply - args = [] - else: - fwd_fn = _RMSNorm.forward - args = [None] - - args += ( - inp, - self.weight, - self.eps, - self.fwd_rmsnorm_sm_margin, - self.bwd_rmsnorm_sm_margin, - self.inf_rmsnorm_sm_margin, - self.zero_centered_gamma, - torch.is_grad_enabled(), - self.activation_dtype, - ) - - return fwd_fn(*args) + self.reset_parameters() + + def reset_parameters(self, defer_init: Optional[bool] = None) -> None: + """Init RMSNorm parameters""" + + # Check whether to defer init (deprecated) + if defer_init is not None: + warnings.warn( + "defer_init argument to reset_parameters function is deprecated. Set device to" + ' "meta" instead.', + DeprecationWarning, + stacklevel=2, + ) + if defer_init: + return + + # Reset parameters + super().reset_parameters() + + # Flag for sequence parallelism (custom Megatron-LM integration) + if getattr(self, "sequence_parallel", None) is not None: + self.weight.sequence_parallel = self.sequence_parallel + + @property + def fwd_rmsnorm_sm_margin(self) -> int: + """Shim for backward compatibility""" + warnings.warn("fwd_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + return self._sm_margins["forward"] + + @fwd_rmsnorm_sm_margin.setter + def fwd_rmsnorm_sm_margin(self, val: int) -> None: + """Shim for backward compatibility""" + warnings.warn("fwd_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + self._sm_margins["forward"] = val + + @property + def bwd_rmsnorm_sm_margin(self) -> int: + """Shim for backward compatibility""" + warnings.warn("bwd_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + return self._sm_margins["backward"] + + @bwd_rmsnorm_sm_margin.setter + def bwd_rmsnorm_sm_margin(self, val: int) -> None: + """Shim for backward compatibility""" + warnings.warn("bwd_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + self._sm_margins["backward"] = val + + @property + def inf_rmsnorm_sm_margin(self) -> int: + """Shim for backward compatibility""" + warnings.warn("inf_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + return self._sm_margins["inference"] + + @inf_rmsnorm_sm_margin.setter + def inf_rmsnorm_sm_margin(self, val: int) -> None: + """Shim for backward compatibility""" + warnings.warn("inf_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + self._sm_margins["inference"] = val diff --git a/transformer_engine/pytorch/ops/__init__.py b/transformer_engine/pytorch/ops/__init__.py index f437f877b4..f65433398e 100644 --- a/transformer_engine/pytorch/ops/__init__.py +++ b/transformer_engine/pytorch/ops/__init__.py @@ -8,17 +8,7 @@ """ -from transformer_engine.pytorch.ops.basic import ( - AddInPlace, - AllGather, - AllReduce, - BasicLinear, - Bias, - Identity, - MakeExtraOutput, - ReduceScatter, - Reshape, -) +from transformer_engine.pytorch.ops.basic import * from transformer_engine.pytorch.ops.linear import Linear from 
transformer_engine.pytorch.ops.op import FusibleOperation from transformer_engine.pytorch.ops.sequential import Sequential diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 12270d8340..89a529a78e 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -56,6 +56,8 @@ def convert_tensor( if memory_format != torch.preserve_format and not data.is_contiguous( memory_format=memory_format ): + # Note: torch.Tensor.to ignores memory_format kwarg (see + # https://github.com/pytorch/pytorch/issues/132020). data = data.contiguous(memory_format=memory_format) return Float8Tensor.make_like( tensor, @@ -65,7 +67,14 @@ def convert_tensor( ) # Convert standard PyTorch tensor - return tensor.to(device=device, dtype=dtype, memory_format=memory_format) + tensor = tensor.to(device=device, dtype=dtype) + if memory_format != torch.preserve_format and not tensor.is_contiguous( + memory_format=memory_format + ): + # Note: torch.Tensor.to ignores memory_format kwarg (see + # https://github.com/pytorch/pytorch/issues/132020). + tensor = tensor.contiguous(memory_format=memory_format) + return tensor def reshape( @@ -114,3 +123,14 @@ def reshape( # Reshape standard PyTorch tensor return tensor.view(shape) + + +def maybe_autocast_dtype( + *, + device_type: str = "cuda", + default_dtype: Optional[torch.dtype] = None, +) -> torch.dtype: + """Get autocast dtype if enabled""" + if torch.is_autocast_enabled(device_type): + return torch.get_autocast_dtype(device_type) + return canonicalize_dtype(default_dtype) diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 1003cc0337..3dd8f64229 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -10,6 +10,9 @@ from .basic_linear import BasicLinear from .bias import Bias from .identity import Identity +from .layer_norm import LayerNorm from .make_extra_output import MakeExtraOutput +from .quantize import Quantize from .reduce_scatter import ReduceScatter from .reshape import Reshape +from .rmsnorm import RMSNorm diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py new file mode 100644 index 0000000000..99c9c493db --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -0,0 +1,317 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusable operation for Layer Normalization.""" + +from __future__ import annotations +from collections.abc import Iterable +import math +import os +from typing import Optional + +import torch + +from transformer_engine_torch import layernorm_bwd, layernorm_fwd +from ...cpp_extensions import ( + layernorm_fwd_fp8, + layernorm_fwd_fp8_inf, + layernorm_fwd_inf, +) +from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype +from ...tensor import Float8Tensor, QuantizedTensor +from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data +from ..op import BasicOperation, OperationContext +from .._common import maybe_autocast_dtype, reshape + + +class LayerNorm(BasicOperation): + r"""Layer Normalization + + Applies Layer Normalization over a mini-batch of inputs as described in + the paper `Layer Normalization `__ + + .. 
math:: + y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta + + :math:`\gamma` and :math:`\beta` are learnable affine transform + parameters that match the inner-most dimensions of the input + tensor. + + Parameters + ---------- + normalized_shape: int or iterable of int + Inner dimensions of input tensor + eps : float, default = 1e-5 + A value added to the denominator of layer normalization for + numerical stability + device: torch.device, default = default CUDA device + Tensor device + dtype: torch.dtype, default = default dtype + Tensor datatype + zero_centered_gamma : bool, default = 'False' + If `True`, the :math:`\gamma` parameter is initialized to zero + and the calculation changes to + + .. math:: + y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta + + sm_margin: int or dict, default = 0 + Number of SMs to exclude when launching CUDA kernels. This + helps overlap with other kernels, e.g. communication kernels. + For more fine-grained control, provide a dict with the SM + margin at each compute stage ("forward", "backward", + "inference"). + + """ + + def __init__( + self, + normalized_shape: Iterable[int] | int, + *, + eps: float = 1e-5, + device: Optional[torch.device | str] = None, + dtype: Optional[torch.dtype] = None, + zero_centered_gamma: bool = False, + sm_margin: int | dict[str, int] = 0, + ) -> None: + super().__init__() + self.eps: float = eps + self.zero_centered_gamma: bool = zero_centered_gamma + + # Parameter shape + if not isinstance(normalized_shape, Iterable): + normalized_shape = (normalized_shape,) + else: + normalized_shape = tuple(normalized_shape) + self._shape: tuple[int, ...] = normalized_shape + + # Parameter device + defer_param_init = False + device = canonicalize_device(device) + if device.type == "meta": + defer_param_init = True + device = canonicalize_device(None) + if device.type != "cuda": + raise ValueError(f"Only CUDA devices are supported (got {device})") + self.device: torch.device = device + + # Initialize parameters if needed + dtype = canonicalize_dtype(dtype) + weight = torch.empty( + self._shape, + device="meta", + dtype=dtype, + ) + bias = torch.empty( + self._shape, + device="meta", + dtype=dtype, + ) + weight = torch.nn.Parameter(weight) + bias = torch.nn.Parameter(bias) + self.weight: torch.nn.Parameter + self.bias: torch.nn.Parameter + self.register_parameter("weight", weight) + self.register_parameter("bias", bias) + if not defer_param_init: + self.reset_parameters() + + # Number of SMs to exclude when launching CUDA kernels + self._sm_margins: dict[str, int] + if isinstance(sm_margin, dict): + + def getenv(name: str) -> int: + return int(os.getenv(name, "0")) + + self._sm_margins = { + "forward": sm_margin.get("forward", getenv("NVTE_FWD_LAYERNORM_SM_MARGIN")), + "backward": sm_margin.get("backward", getenv("NVTE_BWD_LAYERNORM_SM_MARGIN")), + "inference": sm_margin.get("inference", getenv("NVTE_INF_LAYERNORM_SM_MARGIN")), + } + else: + + def getenv(name: str) -> int: + return int(os.getenv(name, str(sm_margin))) + + self._sm_margins = { + "forward": getenv("NVTE_FWD_LAYERNORM_SM_MARGIN"), + "backward": getenv("NVTE_BWD_LAYERNORM_SM_MARGIN"), + "inference": getenv("NVTE_INF_LAYERNORM_SM_MARGIN"), + } + + def reset_parameters(self) -> None: + """Initialize parameter buffers and values""" + + # Make sure parameter is initialized + weight = self.weight + bias = self.bias + if weight.device.type != "cuda": + weight = torch.empty_like(weight, device=self.device) + 
else: + weight = weight.to(device=self.device) + if bias.device.type != "cuda": + bias = torch.empty_like(bias, device=self.device) + else: + bias = bias.to(device=self.device) + + # Initialize values + if self.zero_centered_gamma: + torch.nn.init.zeros_(weight) + else: + torch.nn.init.ones_(weight) + torch.nn.init.zeros_(bias) + + # Save updated parameter + if not isinstance(weight, torch.nn.Parameter): + weight = torch.nn.Parameter(weight) + if not isinstance(bias, torch.nn.Parameter): + bias = torch.nn.Parameter(bias) + self.weight = weight + self.bias = bias + + def pre_forward(self, *args, **kwargs) -> None: + super().pre_forward(*args, **kwargs) + if self.weight.device.type == "meta" or self.bias.device.type == "meta": + self.reset_parameters() + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + ) -> torch.Tensor: + + # Check tensor dims + input_dims = tuple(input_.size()) + if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape: + raise ValueError( + f"Input tensor (shape={input_dims}) " + f"and weight tensor (shape={self._shape}) are not compatible" + ) + + # Check input tensors + inner_dim = math.prod(self._shape) + device = self.device + dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype) + x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype) + w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) + b = reshape(self.bias, (inner_dim,), device=device, dtype=dtype) + if isinstance(x, QuantizedTensor): + x = x.dequantize() + if isinstance(w, QuantizedTensor): + w = w.dequantize() + if isinstance(b, QuantizedTensor): + b = b.dequantize() + + # Check if backward pass is needed + requires_grad = ctx.requires_grad + + # Check if FP8 is enabled + with_fp8_output = ( + FP8GlobalStateManager.is_fp8_enabled() + and next_op is not None + and next_op.num_fp8_scales("input") > 0 + ) + output_fp8_meta = None + if with_fp8_output: + output_fp8_meta = next_op.get_fp8_meta("input") + + # Compute layer norm + y = None + means = None + rstdevs = None + sm_margin = self._sm_margins["forward" if requires_grad else "inference"] + if with_fp8_output: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) + fp8_dtype = get_fp8_te_dtype(output_fp8_meta["recipe"], fprop_tensor=True) + args = ( + x, + w, + b, + self.eps, + output_fp8_meta[fp8_meta_key], + 0, # fp8_meta_index + fp8_dtype, + sm_margin, + self.zero_centered_gamma, + ) + if requires_grad: + data, means, rstdevs = layernorm_fwd_fp8(*args) + else: + data = layernorm_fwd_fp8_inf(*args) + y = Float8Tensor( + data=data, + fp8_meta=output_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + dtype=dtype, + ) + else: + args = ( + x, + w, + b, + self.eps, + sm_margin, + self.zero_centered_gamma, + ) + if requires_grad: + y, means, rstdevs = layernorm_fwd(*args) + else: + y = layernorm_fwd_inf(*args) + + # Save state for backward pass + if requires_grad: + ctx.save_for_backward(x, means, rstdevs) + ctx.dtype = dtype + ctx.has_prev_op = prev_op is not None + + # Reshape output tensor + out = reshape(y, input_dims) + return out + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + x, means, rstdevs = ctx.saved_tensors + + # Check input tensors + inner_dim = x.size(-1) + device = self.device + dtype = ctx.dtype + dy = 
reshape(grad_output, x.size(), device=device, dtype=dtype) + w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) + if isinstance(w, QuantizedTensor): + w = w.dequantize() + if isinstance(dy, QuantizedTensor): + dy = dy.dequantize() + + # Compute layer norm backward pass + dx, dw, db = layernorm_bwd( + dy, + x, + means, + rstdevs, + w, + self._sm_margins["backward"], + self.zero_centered_gamma, + ) + + # Clear saved tensors if possible + if ctx.has_prev_op: + clear_tensor_data(x) + clear_tensor_data(means) + clear_tensor_data(rstdevs) + + # Reshape results + grad_input = reshape(dx, grad_output.size()) + grad_weight = reshape(dw, self._shape) + grad_bias = reshape(db, self._shape) + return grad_input, (grad_weight, grad_bias) diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py new file mode 100644 index 0000000000..313b6e5583 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -0,0 +1,93 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operation for quantization.""" + +from __future__ import annotations +from typing import Optional + +import torch + +from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype +from ...tensor import Float8Tensor, QuantizedTensor +from ..op import BasicOperation, OperationContext + + +class Quantize(BasicOperation): + """Quantize tensor data + + Uses FP8 recipe from `fp8_autocast` context. When called outside + of an `fp8_autocast` context, this is an identity operation. + + Parameters + ---------- + forward: bool, default = `True` + Perform quantization in forward pass + backward: bool, default = `False` + Perform quantization in backward pass + + """ + + def __init__( + self, + forward: bool = True, + backward: bool = False, + ) -> None: + super().__init__() + self._quantize_forward = forward + self._quantize_backward = backward + + def num_fp8_scales(self, mode: str) -> int: + if mode == "input" and self._quantize_forward: + return 1 + if mode == "grad_output" and self._quantize_backward: + return 1 + return 0 + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + ) -> torch.Tensor: + + # Check if FP8 is enabled + fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() + quantize_forward = fp8_enabled and self._quantize_forward + quantize_backward = fp8_enabled and self._quantize_backward + + # Quantize if needed + out = input_ + if quantize_forward and not isinstance(out, QuantizedTensor): + fp8_meta = self.get_fp8_meta("input") + fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + out = Float8Tensor.to_float8( + out, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + ) + + ctx.quantize_backward = quantize_backward + return out + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + grad_input = grad_output + if ctx.quantize_backward and not isinstance(grad_input, QuantizedTensor): + fp8_meta = self.get_fp8_meta("grad_output") + fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False) + grad_input = Float8Tensor.to_float8( + grad_input, + fp8_meta=fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + ) + return grad_input, () diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py 
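
As a rough illustration of the Quantize op added above (a sketch only; the tensor shape and recipe are arbitrary assumptions), quantization happens only inside an `fp8_autocast` region and the op is otherwise an identity:

import torch
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch import ops as te_ops

quantize = te_ops.Quantize(forward=True, backward=False)
x = torch.randn(32, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)

y_identity = quantize(x)  # outside fp8_autocast: plain pass-through

fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
    fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y_fp8 = quantize(x)  # Float8Tensor output; backward grads are left unquantized
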
b/transformer_engine/pytorch/ops/basic/rmsnorm.py new file mode 100644 index 0000000000..4f0e2ddc22 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -0,0 +1,289 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operation for RMSNorm.""" + +from __future__ import annotations +from collections.abc import Iterable +import math +import os +from typing import Optional + +import torch + +from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd +from ...cpp_extensions import ( + rmsnorm_fwd_fp8, + rmsnorm_fwd_fp8_inf, + rmsnorm_fwd_inf, +) +from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype +from ...tensor import Float8Tensor, QuantizedTensor +from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data +from ..op import BasicOperation, OperationContext +from .._common import maybe_autocast_dtype, reshape + + +class RMSNorm(BasicOperation): + r"""Root Mean Square Layer Normalization + + Applies Root Mean Square Layer Normalization over a mini-batch of + inputs as described in the paper + `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__ + + .. math:: + y = \frac{x}{\sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2 + \varepsilon}} * \gamma + + :math:`\gamma` is a learnable affine transform parameter that + matches the inner-most dimensions of the input tensor. + + Parameters + ---------- + normalized_shape: int or iterable of int + Inner dimensions of input tensor + eps : float, default = 1e-5 + A value added to the denominator for numerical stability + device: torch.device, default = default CUDA device + Tensor device + dtype: torch.dtype, default = default dtype + Tensor datatype + zero_centered_gamma : bool, default = 'False' + If `True`, the :math:`\gamma` parameter is initialized to zero + and the calculation changes to + + .. math:: + y = \frac{x}{\sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2 + \varepsilon}} * (1 + \gamma) + + sm_margin: int or dict, default = 0 + Number of SMs to exclude when launching CUDA kernels. This + helps overlap with other kernels, e.g. communication kernels. + For more fine-grained control, provide a dict with the SM + margin at each compute stage ("forward", "backward", + "inference"). + + """ + + def __init__( + self, + normalized_shape: Iterable[int] | int, + *, + eps: float = 1e-5, + device: Optional[torch.device | str] = None, + dtype: Optional[torch.dtype] = None, + zero_centered_gamma: bool = False, + sm_margin: int | dict[str, int] = 0, + ) -> None: + super().__init__() + self.eps: float = eps + self.zero_centered_gamma: bool = zero_centered_gamma + + # Parameter shape + if not isinstance(normalized_shape, Iterable): + normalized_shape = (normalized_shape,) + else: + normalized_shape = tuple(normalized_shape) + self._shape: tuple[int, ...]
= normalized_shape + + # Parameter device + defer_param_init = False + device = canonicalize_device(device) + if device.type == "meta": + defer_param_init = True + device = canonicalize_device(None) + if device.type != "cuda": + raise ValueError(f"Only CUDA devices are supported (got {device})") + self.device: torch.device = device + + # Initialize parameters if needed + weight = torch.empty( + self._shape, + device="meta", + dtype=canonicalize_dtype(dtype), + ) + weight = torch.nn.Parameter(weight) + self.weight: torch.nn.Parameter + self.register_parameter("weight", weight) + if not defer_param_init: + self.reset_parameters() + + # Number of SMs to exclude when launching CUDA kernels + self._sm_margins: dict[str, int] + if isinstance(sm_margin, dict): + + def getenv(name: str) -> int: + return int(os.getenv(name, "0")) + + self._sm_margins = { + "forward": sm_margin.get("forward", getenv("NVTE_FWD_LAYERNORM_SM_MARGIN")), + "backward": sm_margin.get("backward", getenv("NVTE_BWD_LAYERNORM_SM_MARGIN")), + "inference": sm_margin.get("inference", getenv("NVTE_INF_LAYERNORM_SM_MARGIN")), + } + else: + + def getenv(name: str) -> int: + return int(os.getenv(name, str(sm_margin))) + + self._sm_margins = { + "forward": getenv("NVTE_FWD_LAYERNORM_SM_MARGIN"), + "backward": getenv("NVTE_BWD_LAYERNORM_SM_MARGIN"), + "inference": getenv("NVTE_INF_LAYERNORM_SM_MARGIN"), + } + + def reset_parameters(self) -> None: + """Initialize parameter buffers and values""" + + # Make sure parameter is initialized + weight = self.weight + if weight.device.type != "cuda": + weight = torch.empty_like(weight, device=self.device) + else: + weight = weight.to(device=self.device) + + # Initialize values + if self.zero_centered_gamma: + torch.nn.init.zeros_(weight) + else: + torch.nn.init.ones_(weight) + + # Save updated parameter + if not isinstance(weight, torch.nn.Parameter): + weight = torch.nn.Parameter(weight) + self.weight = weight + + def pre_forward(self, *args, **kwargs) -> None: + super().pre_forward(*args, **kwargs) + if self.weight.device.type == "meta": + self.reset_parameters() + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + ) -> torch.Tensor: + + # Check tensor dims + input_dims = tuple(input_.size()) + if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape: + raise ValueError( + f"Input tensor (shape={input_dims}) " + f"and weight tensor (shape={self._shape}) are not compatible" + ) + + # Check input tensors + inner_dim = math.prod(self._shape) + device = self.device + dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype) + x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype) + w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) + if isinstance(x, QuantizedTensor): + x = x.dequantize() + if isinstance(w, QuantizedTensor): + w = w.dequantize() + + # Check if backward pass is needed + requires_grad = ctx.requires_grad + + # Check if FP8 is enabled + with_fp8_output = ( + FP8GlobalStateManager.is_fp8_enabled() + and next_op is not None + and next_op.num_fp8_scales("input") > 0 + ) + output_fp8_meta = None + if with_fp8_output: + output_fp8_meta = next_op.get_fp8_meta("input") + + # Compute RMSNorm + y = None + rstdevs = None + sm_margin = self._sm_margins["forward" if requires_grad else "inference"] + if with_fp8_output: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) + fp8_dtype = 
get_fp8_te_dtype(output_fp8_meta["recipe"], fprop_tensor=True) + args = ( + x, + w, + self.eps, + output_fp8_meta[fp8_meta_key], + 0, # fp8_meta_index + fp8_dtype, + sm_margin, + self.zero_centered_gamma, + ) + if requires_grad: + data, rstdevs = rmsnorm_fwd_fp8(*args) + else: + data = rmsnorm_fwd_fp8_inf(*args) + y = Float8Tensor( + data=data, + fp8_meta=output_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + dtype=dtype, + ) + else: + args = ( + x, + w, + self.eps, + sm_margin, + self.zero_centered_gamma, + ) + if requires_grad: + y, rstdevs = rmsnorm_fwd(*args) + else: + y = rmsnorm_fwd_inf(*args) + + # Save state for backward pass + if requires_grad: + ctx.save_for_backward(x, rstdevs) + ctx.dtype = dtype + ctx.has_prev_op = prev_op is not None + + # Reshape output tensor + out = reshape(y, input_dims) + return out + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + x, rstdevs = ctx.saved_tensors + + # Check input tensors + inner_dim = x.size(-1) + device = self.device + dtype = ctx.dtype + dy = reshape(grad_output, x.size(), device=device, dtype=dtype) + w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) + if isinstance(w, QuantizedTensor): + w = w.dequantize() + if isinstance(dy, QuantizedTensor): + dy = dy.dequantize() + + # Compute RMSNorm backward pass + dx, dw = rmsnorm_bwd( + dy, + x, + rstdevs, + w, + self._sm_margins["backward"], + self.zero_centered_gamma, + ) + + # Clear saved tensors if possible + if ctx.has_prev_op: + clear_tensor_data(x) + clear_tensor_data(rstdevs) + + # Reshape results + grad_input = reshape(dx, grad_output.size()) + grad_weight = reshape(dw, self._shape) + return grad_input, (grad_weight,) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index be37ab8976..bbfb9416fc 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -57,12 +57,13 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): # pylint: disable=unused-argument @staticmethod def forward( - func_ctx: torch.autograd.function.FunctionCtx, + func_ctx: Optional[torch.autograd.function.FunctionCtx], input_: torch.Tensor, forward_ops: list[tuple[FusibleOperation, list[int]]], backward_ops: list[tuple[FusibleOperation, list[int]]], basic_ops: list[BasicOperation], basic_op_kwargs: list[dict[str, Any]], + is_grad_enabled: bool, num_params: int, num_extra_inputs: int, *params_and_extra_inputs: torch.nn.Parameter, @@ -120,10 +121,20 @@ def forward( # Apply forward ops x = input_ - requires_grad = x.requires_grad + requires_grad = is_grad_enabled and x.requires_grad extra_outputs = [None for _ in range(len(basic_ops))] for op, basic_op_idxs in forward_ops: + # Check if backward op is required + if is_grad_enabled: + if not requires_grad: + requires_grad = any(param.requires_grad for param in op.parameters()) + if not requires_grad: + requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs) + for idx in basic_op_idxs: + basic_op_ctxs[idx].requires_grad = requires_grad + x.requires_grad_(requires_grad=requires_grad) + # Forward op extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs] prev_ops = [basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs] @@ -138,18 +149,12 @@ def forward( basic_op_next_ops=next_ops, basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs], ) + 
x.requires_grad_(requires_grad=requires_grad) for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs): + for y in ys: + y.requires_grad_(requires_grad=requires_grad) extra_outputs[idx] = ys - # Check if backward op is required - if not requires_grad: - requires_grad = any(param.requires_grad for param in op.parameters()) - if not requires_grad: - requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs) - for idx in basic_op_idxs: - basic_op_ctxs[idx]._requires_grad = requires_grad - x.requires_grad_(requires_grad=requires_grad) - # Flatten list of extra outputs extra_outputs_flat = [] for idx, ys in enumerate(extra_outputs): @@ -163,25 +168,28 @@ def forward( ) extra_outputs_flat.extend(ys) - # Flatten list of saved tensors - to_save = [] - for ctx in basic_op_ctxs: - range_start = len(to_save) - if ctx.to_save is not None: - to_save.extend(ctx.to_save) - range_end = len(to_save) - ctx.to_save = None - ctx._saved_tensors_range = (range_start, range_end) - func_ctx.save_for_backward(*to_save) - - # Other context for backward pass - func_ctx.backward_ops = backward_ops - func_ctx.basic_ops = basic_ops - func_ctx.basic_op_ctxs = basic_op_ctxs - func_ctx.num_params = num_params - func_ctx.num_extra_inputs = num_extra_inputs - func_ctx.num_extra_outputs = len(extra_outputs_flat) - func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() + # Save context for backward pass + if is_grad_enabled: + + # Flatten list of saved tensors + to_save = [] + for ctx in basic_op_ctxs: + range_start = len(to_save) + if ctx.to_save is not None: + to_save.extend(ctx.to_save) + range_end = len(to_save) + ctx.to_save = None + ctx._saved_tensors_range = (range_start, range_end) + func_ctx.save_for_backward(*to_save) + + # Other context + func_ctx.backward_ops = backward_ops + func_ctx.basic_ops = basic_ops + func_ctx.basic_op_ctxs = basic_op_ctxs + func_ctx.num_params = num_params + func_ctx.num_extra_inputs = num_extra_inputs + func_ctx.num_extra_outputs = len(extra_outputs_flat) + func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() if extra_outputs_flat: return x, *extra_outputs_flat @@ -224,7 +232,7 @@ def backward( for op, basic_op_idxs in backward_ops: # Stop if no more gradients are required - if all(not basic_op_ctxs[idx]._requires_grad for idx in basic_op_idxs): + if all(not basic_op_ctxs[idx].requires_grad for idx in basic_op_idxs): dx = None break @@ -282,6 +290,7 @@ def backward( None, # backward_ops None, # basic_ops None, # basic_op_kwargs + None, # is_grad_enabled None, # num_params None, # num_extra_inputs *grad_params_flat, @@ -373,14 +382,23 @@ def __call__( params = [param for op in self._basic_ops for param in op.parameters()] # Fuser forward pass - return _OperationFuserAutogradFunction.apply( + is_grad_enabled = torch.is_grad_enabled() + if is_grad_enabled: + forward_func = _OperationFuserAutogradFunction.apply + args = [] + else: + forward_func = _OperationFuserAutogradFunction.forward + args = [None] + args += ( input, self._forward_ops, self._backward_ops, self._basic_ops, basic_op_kwargs, + is_grad_enabled, len(params), self._num_extra_inputs, *params, *extra_inputs, ) + return forward_func(*args) diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 9e4963d52a..0bb6f25db8 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -43,7 +43,7 @@ class OperationContext: _saved_tensors_range: Optional[tuple[int, int]] = None # Whether backward pass is required - 
_requires_grad: bool = False + requires_grad: bool = True def save_for_backward(self, *tensors: Optional[torch.Tensor]) -> None: """Register tensors to be saved for the backward function
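A minimal usage sketch of the ops added in this patch, assuming the fusible-ops `Sequential` container exported from `transformer_engine.pytorch.ops` (as exercised in tests/pytorch/test_fusible_ops.py), an FP8-capable CUDA GPU, and default `DelayedScaling` recipe settings; import paths, shapes, and SM-margin values are illustrative only:

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling
    from transformer_engine.pytorch.ops import Quantize, RMSNorm, Sequential

    hidden = 1024
    x = torch.randn(32, hidden, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    # Plain RMSNorm op: with grad enabled, the fuser dispatches through
    # _OperationFuserAutogradFunction.apply and op_backward calls rmsnorm_bwd.
    norm = Sequential(
        RMSNorm(
            hidden,
            device="cuda",
            dtype=torch.bfloat16,
            zero_centered_gamma=True,
            sm_margin={"forward": 8, "backward": 8, "inference": 0},
        ),
    )
    y = norm(x)
    y.float().sum().backward()  # gradients for x and the gamma parameter

    # FP8 output fusion: with Quantize as the next op, RMSNorm.op_forward sees
    # next_op.num_fp8_scales("input") > 0 and emits a Float8Tensor directly.
    fp8_pipeline = Sequential(
        RMSNorm(hidden, device="cuda", dtype=torch.bfloat16),
        Quantize(forward=True, backward=False),
    )
    with torch.no_grad(), te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
        y_fp8 = fp8_pipeline(x)  # grad disabled: the fuser calls .forward with a
                                 # None autograd ctx, so nothing is saved for backward

Note that `sm_margin` accepts either a single int applied to every stage or a per-stage dict ("forward", "backward", "inference"), and that quantized output is produced only inside an `fp8_autocast` context; outside it, `Quantize` acts as an identity op.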