From 20b0473cd1a5e999d3f5996f5c45809410f45455 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Fri, 15 Nov 2024 10:30:41 -0800 Subject: [PATCH] [PyTorch] Activation operations (#1164) * Add activation ops Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix lint warnings Signed-off-by: Tim Moon * Fix linter warning Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Update to use QuantizedTensor Signed-off-by: Tim Moon * Respect PyTorch autograd dtype Signed-off-by: Tim Moon * Rename CastFloat8 op to Quantize Signed-off-by: Tim Moon * Add support for fused dSwiGLU-cast-transpose Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_fusible_ops.py | 160 +++++++ .../pytorch/cpp_extensions/transpose.py | 39 ++ transformer_engine/pytorch/csrc/extensions.h | 6 + .../pytorch/csrc/extensions/pybind.cpp | 6 + .../pytorch/csrc/extensions/transpose.cpp | 69 ++++ .../pytorch/ops/basic/__init__.py | 1 + .../pytorch/ops/basic/activation.py | 390 ++++++++++++++++++ 7 files changed, 671 insertions(+) create mode 100644 transformer_engine/pytorch/ops/basic/activation.py diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index ec539e1f06..fd2832c1d4 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1362,6 +1362,166 @@ def test_make_extra_output( torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0) torch.testing.assert_close(dx_test, x_ref.grad, **tols) + @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu")) + @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (4, 1, 16))) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("fp8_input", (False, True)) + @pytest.mark.parametrize("fp8_output", (False, True)) + def test_activation( + self, + *, + activation: str, + out_shape: Iterable[int], + dtype: torch.dtype, + device: torch.device = "cuda", + fp8_input: bool, + fp8_output: bool, + ) -> None: + """Activation functions""" + + # Tensor dimensions + in_shape = list(out_shape) + if activation in ("geglu", "reglu", "swiglu"): + in_shape[-1] *= 2 + + # Skip invalid configurations + if fp8_input or fp8_output: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8_input, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref: torch.Tensor + if activation == "gelu": + y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh") + elif activation == "relu": + y_ref = torch.nn.functional.relu(x_ref) + elif activation == "geglu": + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.gelu(x1, approximate="tanh") * x2 + elif activation == "reglu": + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.relu(x1) * x2 + elif activation == "swiglu": + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = 
torch.nn.functional.silu(x1) * x2 + else: + raise ValueError(f"Unexpected activation function ({activation})") + y_ref.backward(dy_ref) + + # Implementation with fusible operation + make_op = dict( + gelu=te_ops.GELU, + relu=te_ops.ReLU, + geglu=te_ops.GEGLU, + reglu=te_ops.ReGLU, + swiglu=te_ops.SwiGLU, + )[activation] + forward = te_ops.Sequential( + make_op(), + te_ops.Quantize(forward=fp8_output, backward=False), + ) + with te.fp8_autocast(enabled=fp8_output): + y_test = forward(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if fp8_output: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("fp8_output", (False, True)) + @pytest.mark.parametrize("fp8_grad_input", (False, True)) + def test_swiglu( + self, + *, + out_shape: Iterable[int] = (16, 16), + dtype: torch.dtype, + device: torch.device = "cuda", + fp8_output: bool, + fp8_grad_input: bool, + ): + + # Tensor dimensions + in_shape = list(out_shape) + in_shape[-1] *= 2 + + # Skip invalid configurations + fp8 = fp8_output or fp8_grad_input + if fp8: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # FP8 recipe + fp8_recipe = None + if fp8_grad_input: + fp8_recipe = transformer_engine.common.recipe.DelayedScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.silu(x1) * x2 + y_ref.backward(dy_ref) + + # Implementation with fusible operation + forward = te_ops.Sequential( + te_ops.Quantize(forward=False, backward=fp8_grad_input), + te_ops.SwiGLU(), + te_ops.Quantize(forward=fp8_output, backward=False), + ) + with te.fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe): + y_test = forward(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if fp8: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + class TestFusedOps: """Tests for fused operations""" diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py index ddc3b67e9e..188c03b27c 100644 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ b/transformer_engine/pytorch/cpp_extensions/transpose.py @@ -16,6 +16,7 @@ "fp8_cast_transpose_fused", "fp8_cast_transpose_bgrad_fused", "fp8_cast_transpose_bgrad_dgelu_fused", + "fp8_dswiglu_cast_transpose_fused", "fp8_multi_cast_transpose_fused", "fp8_transpose_bgrad_fused", ] @@ -168,6 +169,44 @@ def fp8_cast_transpose_bgrad_dgelu_fused( ) +def fp8_dswiglu_cast_transpose_fused( + grad_output: torch.Tensor, + inp: torch.Tensor, + *, + grad_input: torch.Tensor, + 
grad_input_transpose: torch.Tensor, + otype: tex.DType, + fp8_meta: Optional[tex.FP8TensorMeta] = None, + fp8_meta_index: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None] = None, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, +) -> None: + """Fused SwiGLU backward + FP8 cast + FP8 transpose""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta, + fp8_meta_index=fp8_meta_index, + ) + + # Launch kernel + return tex.fused_dswiglu_cast_transpose( + grad_output, + inp, + grad_input, + grad_input_transpose, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + otype, + **fp8_scales_offsets, + ) + + def fp8_multi_cast_transpose_fused( input_list: List[torch.Tensor], fp8_meta_tensor: tex.FP8TensorMeta, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index b039bf2d1b..3b49ece4a3 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -210,6 +210,12 @@ std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, int scale_offset = 0, int amax_offset = 0, int scale_inv_offset = 0); +void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input, + at::Tensor grad_input_transpose, at::Tensor scale, + at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype, int scale_offset = 0, + int amax_offset = 0, int scale_inv_offset = 0); + void fused_multi_cast_transpose(std::vector input_list, std::vector scale_list, std::vector cast_output_list, diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 39679ed669..8856553c54 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -91,6 +91,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("grad_output"), py::arg("gelu_input"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); + m.def("fused_dswiglu_cast_transpose", &fused_dswiglu_cast_transpose, + "Fused SwiGLU backward + FP8 cast + FP8 transpose", + py::call_guard(), py::arg("grad_output"), py::arg("input"), + py::arg("grad_input"), py::arg("grad_input_transpose"), py::arg("scale"), py::arg("amax"), + py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0, + py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose, "Fused Multi-tensor Cast + Transpose", py::call_guard()); m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc, diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 56f6b56769..f373cdf83a 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -196,6 +196,75 @@ std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, return {grad_bias, dgelu, dgelu_transpose}; } +void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input, + at::Tensor grad_input_transpose, at::Tensor scale, + at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype, 
int scale_offset, + int amax_offset, int scale_inv_offset) { + using namespace transformer_engine; + + // Tensor dimensions + auto outer_dim = [](const at::Tensor& tensor) -> size_t { + return tensor.numel() / tensor.size(-1); + }; + const auto M = outer_dim(grad_output); + const auto N = static_cast(grad_output.size(-1)); + + // Check tensor dims + NVTE_CHECK(grad_output.dim() == 2, "Expected grad output tensor to have 2 dims, but found ", + grad_output.dim()); + NVTE_CHECK(input.dim() == 2, "Expected input tensor to have 2 dims, but found ", input.dim()); + NVTE_CHECK(outer_dim(input) == M, "Expected input tensor to have outer dimension of ", M, + ", but found ", outer_dim(input)); + NVTE_CHECK(input.size(-1) == 2 * N, "Expected input tensor to have inner dimension of ", 2 * N, + ", but found ", input.size(-1)); + NVTE_CHECK(grad_input.dim() == 2, "Expected grad input tensor to have 2 dims, but found ", + grad_input.dim()); + NVTE_CHECK(outer_dim(grad_input) == M, "Expected grad input tensor to have outer dimension of ", + M, ", but found ", outer_dim(grad_input)); + NVTE_CHECK(grad_input.size(-1) == 2 * N, "Expected grad input tensor to have inner dimension of ", + 2 * N, ", but found ", grad_input.size(-1)); + NVTE_CHECK(grad_input_transpose.dim() == 2, + "Expected grad input transpose tensor to have 2 dims, but found ", + grad_input_transpose.dim()); + NVTE_CHECK(grad_input_transpose.size(0) == 2 * N, + "Expected grad input tensor to have outer dimension of ", 2 * N, ", but found ", + grad_input_transpose.size(0)); + NVTE_CHECK(grad_input_transpose.size(1) == M, + "Expected grad input tensor to have outer dimension of ", M, ", but found ", + grad_input_transpose.size(1)); + + // Check tensor format + NVTE_CHECK(grad_output.is_contiguous(), "Expected grad output tensor to be contiguous"); + NVTE_CHECK(input.is_contiguous(), "Expected input tensor to be contiguous"); + NVTE_CHECK(grad_input.is_contiguous(), "Expected grad input tensor to be contiguous"); + NVTE_CHECK(grad_input_transpose.is_contiguous(), + "Expected grad input transpose tensor to be contiguous"); + NVTE_CHECK(grad_output.scalar_type() == input.scalar_type(), + "Expected grad output tensor and input tensor to have same dtype"); + NVTE_CHECK(grad_input.scalar_type() == at::ScalarType::Byte, + "Expected grad input tensor to be uint8 buffer"); + NVTE_CHECK(grad_input_transpose.scalar_type() == at::ScalarType::Byte, + "Expected grad input transpose tensor to be uint8 buffer"); + + // Get pointers for FP8 scale, amax, scale-inverse + void* scale_dptr = getDataPtr(scale, scale_offset); + void* amax_dptr = getDataPtr(amax, amax_offset); + void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + + // Construct Transformer Engine tensors + auto dy_cu = makeTransformerEngineTensor(grad_output); + auto x_cu = makeTransformerEngineTensor(input); + auto dx_cu = makeTransformerEngineTensor(grad_input.data_ptr(), {M, 2 * N}, otype, amax_dptr, + scale_dptr, scale_inv_dptr); + auto dx_t_cu = makeTransformerEngineTensor(grad_input_transpose.data_ptr(), {2 * N, M}, otype, + amax_dptr, scale_dptr, scale_inv_dptr); + + // Launch kernel + nvte_dswiglu_cast_transpose(dy_cu.data(), x_cu.data(), dx_cu.data(), dx_t_cu.data(), + at::cuda::getCurrentCUDAStream()); +} + void fused_multi_cast_transpose_base(std::vector input_list, std::vector scale_dptr_list, std::vector cast_output_list, diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 3dd8f64229..d6f4940c58 100644 --- 
a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -4,6 +4,7 @@ """Single tensor operations supported by the operation fuser.""" +from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU from .add_in_place import AddInPlace from .all_gather import AllGather from .all_reduce import AllReduce diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py new file mode 100644 index 0000000000..a2e5a24a85 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -0,0 +1,390 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operations for activation functions.""" + +from __future__ import annotations +import abc +from typing import Optional + +import torch + +import transformer_engine_torch +from ...constants import TE_DType +from ...cpp_extensions import ( + geglu as tex_geglu, + gelu as tex_gelu, + reglu as tex_reglu, + relu as tex_relu, + swiglu as tex_swiglu, + fp8_dswiglu_cast_transpose_fused, +) +from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype +from ...tensor import Float8Tensor, QuantizedTensor +from ...utils import clear_tensor_data, devices_match +from ..op import BasicOperation, OperationContext + + +class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): + r"""Apply activation function + + Activation functions are either element-wise unary functions or + variants of the gated linear unit (GLU). Recall that GLU is + computed by splitting the input tensor into chunks :math:`a` and + :math:`b` along the last dimension and computing + + .. math:: + \text{GLU}(a,b) = \sigma(a) * b + + .. warning:: + + Transformer Engine gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + """ + + @abc.abstractmethod + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + """Forward implementation + + Implementation from transformer_engine.pytorch.cpp_extensions. + + """ + + @abc.abstractmethod + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + """Backward implementation + + Implementation from transformer_engine_torch. 
+ + """ + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + ) -> torch.Tensor: + + # Compute dtype + dtype: torch.dtype + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = input_.dtype + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise RuntimeError(f"Unsupported dtype ({dtype})") + + # Check input tensor + x = input_ + if isinstance(x, QuantizedTensor): + x = x.dequantize() + if x.device.type != "cuda": + x = x.cuda() + if x.dtype != dtype: + x = x.to(dtype=dtype) + if not x.is_contiguous(): + x = x.contiguous() + + # Check if FP8 is enabled + fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() + with_fp8_output = False + output_fp8_meta = None + output_dtype = TE_DType[dtype] + output_fp8_scale_inv = None + if fp8_enabled and next_op is not None and next_op.num_fp8_scales("input") > 0: + with_fp8_output = True + fp8_meta = next_op.get_fp8_meta("input") + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) + output_fp8_meta = fp8_meta[fp8_meta_key] + output_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + output_fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=x.device) + + # Launch kernel + y = self._activation_forward_impl( + x, + output_fp8_meta, + 0, + output_dtype, + scale_inv=output_fp8_scale_inv, + ) + + # Check output tensor + if y.dim() != x.dim(): + y = y.reshape(list(x.shape[:-1]) + [-1]) + if with_fp8_output: + y = Float8Tensor( + data=y, + fp8_meta=output_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=output_dtype, + fp8_scale_inv=output_fp8_scale_inv, + dtype=dtype, + ) + + # Save state for backward pass + ctx.save_for_backward(x) + ctx.fp8_enabled = fp8_enabled + ctx.prev_op = prev_op + + return y + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + (x,) = ctx.saved_tensors + + # Check grad output tensor + dy = grad_output + if isinstance(dy, QuantizedTensor): + dy = dy.dequantize() + if not devices_match(dy.device, x.device) or dy.dtype != x.dtype: + dy = dy.to(device=x.device, dtype=x.dtype) + if not dy.is_contiguous(): + dy = dy.contiguous() + + # Launch kernel + dx = self._activation_backward_impl(dy, x, TE_DType[x.dtype]) + + # Check grad input tensor + if dx.size() != x.size(): + dx = dx.reshape(x.size()) + + # Clear input tensor if possible + if ctx.prev_op is not None: + clear_tensor_data(x) + + return dx, () + + +class GELU(_ActivationOperation): + r"""Gaussian Error Linear Unit + + This computes the "tanh" approximation to GELU: + + .. math:: + + \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right) + + See `Gaussian Error Linear Units (GELUs)`__. + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex_gelu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return transformer_engine_torch.dgelu(*args, **kwargs) + + +class ReLU(_ActivationOperation): + r"""Rectified linear unit + + .. 
math::
+
+        \text{ReLU}(x) = \max(x,0)
+
+    """
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex_relu(*args, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return transformer_engine_torch.drelu(*args, **kwargs)
+
+
+class GEGLU(_ActivationOperation):
+    r"""Gaussian error gated linear unit
+
+    The input tensor is split into chunks :math:`a` and :math:`b`
+    along the last dimension and the following is computed:
+
+    .. math::
+
+        \text{GEGLU}(a,b) = \text{GELU}(a) * b
+
+    where
+
+    .. math::
+
+        \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right)
+
+    .. warning::
+
+        Transformer Engine's gated activations and PyTorch's GLU
+        activation follow opposite conventions for :math:`a` and
+        :math:`b`. Transformer Engine applies the gating function to
+        the first half of the input tensor, while PyTorch applies it to
+        the second half.
+
+    See `GLU Variants Improve Transformer`__.
+
+    """
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex_geglu(*args, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return transformer_engine_torch.dgeglu(*args, **kwargs)
+
+
+class ReGLU(_ActivationOperation):
+    r"""Rectified gated linear unit
+
+    The input tensor is split into chunks :math:`a` and :math:`b`
+    along the last dimension and the following is computed:
+
+    .. math::
+
+        \text{ReGLU}(a,b) = \max(a,0) * b
+
+    .. warning::
+
+        Transformer Engine's gated activations and PyTorch's GLU
+        activation follow opposite conventions for :math:`a` and
+        :math:`b`. Transformer Engine applies the gating function to
+        the first half of the input tensor, while PyTorch applies it to
+        the second half.
+
+    See `GLU Variants Improve Transformer`__.
+
+    """
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex_reglu(*args, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return transformer_engine_torch.dreglu(*args, **kwargs)
+
+
+class SwiGLU(_ActivationOperation):
+    r"""Swish gated linear unit
+
+    The input tensor is split into chunks :math:`a` and :math:`b`
+    along the last dimension and the following is computed:
+
+    .. math::
+
+        \text{SwiGLU}(a,b) = \text{SiLU}(a) * b
+
+    where
+
+    .. math::
+
+        \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)}
+
+    .. warning::
+
+        Transformer Engine's gated activations and PyTorch's GLU
+        activation follow opposite conventions for :math:`a` and
+        :math:`b`. Transformer Engine applies the gating function to
+        the first half of the input tensor, while PyTorch applies it to
+        the second half.
+
+    The Sigmoid Linear Unit (SiLU) gating function is also known as
+    the swish function. See
+    `GLU Variants Improve Transformer`__
+    and `Gaussian Error Linear Units (GELUs)`__.
+
+    """
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex_swiglu(*args, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return transformer_engine_torch.dswiglu(*args, **kwargs)
+
+    def op_backward(
+        self,
+        ctx: OperationContext,
+        grad_output: torch.Tensor,
+    ) -> tuple[torch.Tensor, tuple[()]]:
+
+        # Saved tensors from forward pass
+        (x,) = ctx.saved_tensors
+
+        # Tensor attributes
+        dtype = x.dtype
+        device = x.device
+
+        # Check grad output tensor
+        dy = grad_output
+        if isinstance(dy, QuantizedTensor):
+            dy = dy.dequantize()
+        if not devices_match(dy.device, device) or dy.dtype != dtype:
+            dy = dy.to(device=device, dtype=dtype)
+        if not dy.is_contiguous():
+            dy = dy.contiguous()
+
+        # Check if FP8 is enabled
+        with_fp8_grad_input = False
+        grad_input_fp8_meta = None
+        grad_input_dtype = TE_DType[dtype]
+        grad_input_fp8_scale_inv = None
+        if (
+            ctx.fp8_enabled
+            and ctx.prev_op is not None
+            and ctx.prev_op.num_fp8_scales("grad_output") > 0
+        ):
+            with_fp8_grad_input = True
+            fp8_meta = ctx.prev_op.get_fp8_meta("grad_output")
+            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False)
+            grad_input_fp8_meta = fp8_meta[fp8_meta_key]
+            grad_input_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
+            grad_input_fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=device)
+
+        # Launch kernel
+        if with_fp8_grad_input:
+            # Fused with FP8 cast-transpose
+            input_dims = x.size()
+            flat_input_dims = [x.numel() // input_dims[-1], input_dims[-1]]
+            flat_output_dims = [flat_input_dims[0], flat_input_dims[1] // 2]
+            dx = torch.empty(input_dims, dtype=torch.uint8, device=device)
+            dx_t = torch.empty(
+                (flat_input_dims[1], flat_input_dims[0]),
+                dtype=torch.uint8,
+                device=device,
+            )
+            fp8_dswiglu_cast_transpose_fused(
+                dy.reshape(flat_output_dims),
+                x.reshape(flat_input_dims),
+                grad_input=dx.reshape(flat_input_dims),
+                grad_input_transpose=dx_t,
+                otype=grad_input_dtype,
+                fp8_meta=grad_input_fp8_meta,
+                fp8_meta_index=0,
+                scale_inv=grad_input_fp8_scale_inv,
+            )
+            dx = Float8Tensor(
+                data=dx,
+                fp8_meta=grad_input_fp8_meta,
+                fp8_meta_forward=True,
+                fp8_meta_index=0,
+                fp8_dtype=grad_input_dtype,
+                fp8_scale_inv=grad_input_fp8_scale_inv,
+                dtype=dtype,
+            )
+            dx._transpose = dx_t
+            dx._transpose_invalid = False
+        else:
+            # Standard impl
+            dx = self._activation_backward_impl(dy, x, TE_DType[dtype])
+            if dx.size() != x.size():
+                dx = dx.reshape(x.size())
+
+        # Note: This fails if the op is preceded by an identity op like Quantize(forward=False)
+        # # Clear input tensor if possible
+        # if ctx.prev_op is not None:
+        #     clear_tensor_data(x)
+
+        return dx, ()
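
For reference, a minimal usage sketch distilled from the new tests in tests/pytorch/test_fusible_ops.py, showing how the SwiGLU op composes with Quantize inside te_ops.Sequential. This is not part of the patch; it assumes a CUDA device with FP8 support and the module paths used by the tests (transformer_engine.pytorch and transformer_engine.pytorch.ops).

# Minimal usage sketch distilled from tests/pytorch/test_fusible_ops.py.
# Assumes a CUDA device with FP8 support; module paths follow the tests.
import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

hidden = 16
# SwiGLU halves the last dimension, so the input carries 2 * hidden features.
x = torch.randn(16, 2 * hidden, dtype=torch.bfloat16, device="cuda", requires_grad=True)
dy = torch.randn(16, hidden, dtype=torch.bfloat16, device="cuda")

forward = te_ops.Sequential(
    te_ops.Quantize(forward=False, backward=True),  # request FP8 grad input in backward
    te_ops.SwiGLU(),
    te_ops.Quantize(forward=True, backward=False),  # cast the activation output to FP8
)
with te.fp8_autocast(enabled=True):
    y = forward(x)
y.backward(dy)
print(y.shape, x.grad.shape)  # torch.Size([16, 16]) torch.Size([16, 32])

The adjacent Quantize ops mirror the fp8_output and fp8_grad_input cases in test_swiglu: the trailing Quantize makes the activation emit an FP8 output directly, while the leading Quantize(backward=True) routes SwiGLU.op_backward through the new fused dSwiGLU + FP8 cast-transpose kernel.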