Update documentation for 2.0 release (#1479)
* Updated docs for TE 2.0

Signed-off-by: Przemek Tredak <[email protected]>

* Do not expose comm_gemm_overlap and cast_transpose_noop

Signed-off-by: Przemek Tredak <[email protected]>

* Made the figures larger

Signed-off-by: Przemek Tredak <[email protected]>

* Apply suggestions from code review

Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Przemyslaw Tredak <[email protected]>

* Update quickstart_utils.py

Signed-off-by: Przemek Tredak <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Change from review

Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Przemyslaw Tredak <[email protected]>

---------

Signed-off-by: Przemek Tredak <[email protected]>
Signed-off-by: Przemyslaw Tredak <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
4 people committed Feb 12, 2025
1 parent 2d058d6 commit e5cc6c2
Showing 17 changed files with 162 additions and 55 deletions.
47 changes: 27 additions & 20 deletions README.rst
@@ -33,11 +33,12 @@ What is Transformer Engine?
.. overview-begin-marker-do-not-remove
Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, including
using 8-bit floating point (FP8) precision on Hopper GPUs, to provide better performance with lower
memory utilization in both training and inference. TE provides a collection of highly optimized
building blocks for popular Transformer architectures and an automatic mixed precision-like API that
can be used seamlessly with your framework-specific code. TE also includes a framework agnostic
C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.
using 8-bit floating point (FP8) precision on Hopper, Ada, and Blackwell GPUs, to provide better
performance with lower memory utilization in both training and inference. TE provides a collection
of highly optimized building blocks for popular Transformer architectures and an automatic mixed
precision-like API that can be used seamlessly with your framework-specific code. TE also includes a
framework agnostic C++ API that can be integrated with other deep learning libraries to enable FP8
support for Transformers.

As the number of parameters in Transformer models continues to grow, training and inference for
architectures such as BERT, GPT and T5 become very memory and compute-intensive. Most deep learning
@@ -51,16 +52,16 @@ not available natively in frameworks today.

TE addresses the problem of FP8 support by providing APIs that integrate with popular Large Language
Model (LLM) libraries. It provides a Python API consisting of modules to easily build a Transformer
layer as well as a framework-agnostic library in C++ including structs and kernels needed for FP8 support.
Modules provided by TE internally maintain scaling factors and other values needed for FP8 training, greatly
simplifying mixed precision training for users.
layer as well as a framework-agnostic library in C++ including structs and kernels needed for FP8
support. Modules provided by TE internally maintain scaling factors and other values needed for FP8
training, greatly simplifying mixed precision training for users.

Highlights
==========

* Easy-to-use modules for building Transformer layers with FP8 support
* Optimizations (e.g. fused kernels) for Transformer models
* Support for FP8 on NVIDIA Hopper and NVIDIA Ada GPUs
* Support for FP8 on NVIDIA Hopper, Ada, and Blackwell GPUs
* Support for optimizations across all precisions (FP16, BF16) on NVIDIA Ampere GPU architecture generations and later

Examples
@@ -149,22 +150,22 @@ Installation
Pre-requisites
^^^^^^^^^^^^^^^^^^^^
* Linux x86_64
* CUDA 12.0+ for Hopper and CUDA 12.1+ for Ada
* NVIDIA Driver supporting CUDA 12.0 or later
* cuDNN 8.1 or later
* For fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later.
* CUDA 12.1+ (CUDA 12.8+ for Blackwell)
* NVIDIA Driver supporting CUDA 12.1 or later
* cuDNN 9.3 or later

Docker
^^^^^^^^^^^^^^^^^^^^

The quickest way to get started with Transformer Engine is by using Docker images on
`NVIDIA GPU Cloud (NGC) Catalog <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch>`_. For example to use the NGC PyTorch container interactively,
`NVIDIA GPU Cloud (NGC) Catalog <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch>`_.
For example to use the NGC PyTorch container interactively,

.. code-block:: bash
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.10-py3
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:25.01-py3
Where 23.10 is the container version. For example, 23.10 for the October 2023 release.
Where 25.01 (corresponding to January 2025 release) is the container version.

pip
^^^^^^^^^^^^^^^^^^^^
@@ -174,23 +175,29 @@ To install the latest stable version of Transformer Engine,
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch).
This will automatically detect if any supported deep learning frameworks are installed and build
Transformer Engine support for them. To explicitly specify frameworks, set the environment variable
NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch).

Alternatively, the package can be directly installed from `Transformer Engine's PyPI <https://pypi.org/project/transformer-engine/>`_, e.g.
Alternatively, the package can be directly installed from
`Transformer Engine's PyPI <https://pypi.org/project/transformer-engine/>`_, e.g.

.. code-block:: bash
pip install transformer_engine[pytorch]
To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]). Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX and PyTorch extensions.
To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be
explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]).
Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX
and PyTorch extensions.

From source
^^^^^^^^^^^
`See the installation guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/installation.html#installation-from-source>`_.

Compiling with FlashAttention-2
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Transformer Engine release v0.11.0 adds support for FlashAttention-2 in PyTorch for improved performance.
Transformer Engine release v0.11.0 added support for FlashAttention-2 in PyTorch for improved performance.

It is a known issue that FlashAttention-2 compilation is resource-intensive and requires a large amount of RAM (see `bug <https://github.com/Dao-AILab/flash-attention/issues/358>`_), which may lead to out of memory errors during the installation of Transformer Engine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue.

5 changes: 3 additions & 2 deletions docs/api/c/layer_norm.rst → docs/api/c/fused_rope.rst
@@ -3,7 +3,8 @@
See LICENSE for license information.

layer_norm.h
fused_rope.h
============

.. doxygenfile:: layer_norm.h
.. doxygenfile:: fused_rope.h

12 changes: 8 additions & 4 deletions docs/api/c/index.rst
@@ -12,12 +12,16 @@ directly from C/C++, without Python.
.. toctree::
:caption: Headers

transformer_engine.h <transformer_engine>
activation.h <activation>
cast.h <cast>
gemm.h <gemm>
fused_attn.h <fused_attn>
layer_norm.h <layer_norm>
rmsnorm.h <rmsnorm>
fused_rope.h <fused_rope>
gemm.h <gemm>
normalization.h <normalization>
padding.h <padding>
permutation.h <permutation>
recipe.h <recipe>
softmax.h <softmax>
transformer_engine.h <transformer_engine>
swizzle.h <swizzle>
transpose.h <transpose>
9 changes: 9 additions & 0 deletions docs/api/c/normalization.rst
@@ -0,0 +1,9 @@
..
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.

normalization.h
===============

.. doxygenfile:: normalization.h
7 changes: 4 additions & 3 deletions docs/api/c/rmsnorm.rst → docs/api/c/padding.rst
@@ -3,7 +3,8 @@
See LICENSE for license information.

rmsnorm.h
============
padding.h
=========

.. doxygenfile:: padding.h

.. doxygenfile:: rmsnorm.h
10 changes: 10 additions & 0 deletions docs/api/c/permutation.rst
@@ -0,0 +1,10 @@
..
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.

permutation.h
=============

.. doxygenfile:: permutation.h

10 changes: 10 additions & 0 deletions docs/api/c/recipe.rst
@@ -0,0 +1,10 @@
..
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.

recipe.h
========

.. doxygenfile:: recipe.h

10 changes: 10 additions & 0 deletions docs/api/c/swizzle.rst
@@ -0,0 +1,10 @@
..
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.

swizzle.h
=========

.. doxygenfile:: swizzle.h

2 changes: 2 additions & 0 deletions docs/api/common.rst
@@ -9,3 +9,5 @@ Common API
.. autoapiclass:: transformer_engine.common.recipe.Format

.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None)

.. autoapiclass:: transformer_engine.common.recipe.MXFP8BlockScaling(fp8_format=Format.E4M3)
Binary file added docs/examples/E8M0.png
Binary file added docs/examples/MXFP8_FP8_comparison_1.png
Binary file added docs/examples/MXFP8_FP8_comparison_2.png
59 changes: 53 additions & 6 deletions docs/examples/fp8_primer.ipynb
@@ -18,7 +18,7 @@
"* E4M3 - it consists of 1 sign bit, 4 exponent bits and 3 bits of mantissa. It can store values up to +/-448 and `nan`.\n",
"* E5M2 - it consists of 1 sign bit, 5 exponent bits and 2 bits of mantissa. It can store values up to +/-57344, +/- `inf` and `nan`. The tradeoff of the increased dynamic range is lower precision of the stored values.\n",
"\n",
"<figure align=\"center\">\n",
"<figure align=\"center\" id=\"fig_1\">\n",
"<img src=\"fp8_formats.png\" width=\"60%\">\n",
"<figcaption> Figure 1: Structure of the floating point datatypes. All of the values shown (in FP16, BF16, FP8 E4M3 and FP8 E5M2) are the closest representations of value 0.3952.</figcaption>\n",
"</figure>\n",
@@ -56,18 +56,63 @@
"As one can see in Figure 3, delayed scaling strategy requires both storing the history of amaxes, but also choosing a recipe for converting that history into the scaling factor used in the next iteration."
]
},
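The E4M3 and E5M2 ranges quoted above can be checked directly against PyTorch's native FP8 dtypes. A minimal sketch, assuming PyTorch 2.1 or later (these dtypes are provided by PyTorch, not by Transformer Engine):

.. code-block:: python

    import torch

    # Dynamic range of the two FP8 formats described in the primer.
    print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0   - more mantissa bits, smaller range
    print(torch.finfo(torch.float8_e5m2).max)    # 57344.0 - fewer mantissa bits, larger range

    # Round-tripping a value through each format shows the rounding error.
    x = torch.tensor(0.3952)
    print(x.to(torch.float8_e4m3fn).float())  # nearest representable E4M3 value
    print(x.to(torch.float8_e5m2).float())    # nearest representable E5M2 value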
{
"cell_type": "markdown",
"id": "f03b58ed-71e8-422a-95be-35c1cc60c4e2",
"metadata": {},
"source": [
"## MXFP8 and block scaling\n",
"\n",
"NVIDIA Blackwell architecture introduced support for a new variant of the FP8 format: [MXFP8](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). \n",
"\n",
"### MXFP8 vs FP8\n",
"\n",
"The main difference between \"regular\" FP8 and MXFP8 lies in the granularity of the scaling. In FP8, each tensor has a single FP32 scaling factor, so all values in the tensor need to \"fit\" within the dynamic range of the FP8 datatype. This requires using the less precise E5M2 format to represent some tensors in the network (like gradients).\n",
"\n",
"MXFP8 addresses this by assigning a different scaling factor to each block of 32 [consecutive](#handling-transposes) values. This allows all values to be represented with the E4M3 datatype.\n",
"\n",
"<figure align=\"center\" id=\"fig_4\">\n",
"<img src=\"MXFP8_FP8_comparison_1.png\" width=\"100%\">\n",
"<figcaption> Figure 4: MXFP8 uses multiple scaling factors for a single tensor. The picture shows only 4 values per block for simplicity, but real MXFP8 has 32 values per block.</figcaption>\n",
"</figure>\n",
"\n",
"<figure align=\"center\" id=\"fig_5\">\n",
"<img src=\"MXFP8_FP8_comparison_2.png\" width=\"100%\">\n",
"<figcaption> Figure 5: Due to multiple scaling factors, tensor's dynamic range requirements are reduced and so E4M3 format can be used as far fewer elements get saturated to 0.</figcaption>\n",
"</figure>\n",
"\n",
"The second difference is the datatype used to store the scaling factors. FP8 uses FP32 (E8M23) while MXFP8 uses an 8-bit representation of a power of 2 (E8M0).\n",
"\n",
"<figure align=\"center\" id=\"fig_6\">\n",
"<img src=\"E8M0.png\" width=\"100%\">\n",
"<figcaption> Figure 6: Structure of the E8M0 datatype used for storing scaling factors in MXFP8.</figcaption>\n",
"</figure>\n",
"\n",
"### Handling transposes\n",
"\n",
"The forward and backward passes of linear layers involve multiple matrix multiplications with different reduction dimensions. Blackwell Tensor Cores require MXFP8 data to be \"consecutive\" over the reduction dimension, so MXFP8 training uses non-transposed and transposed MXFP8 tensors at different points. However, while transposing FP8 data is numerically trivial, transposing MXFP8 data requires requantization.\n",
"\n",
"To avoid loss of precision connected with this double quantization, Transformer Engine creates both regular and transposed copies of the tensor from the original high precision input.\n",
"\n",
"<figure align=\"center\" id=\"fig_7\">\n",
"<img src=\"linear_mxfp8.png\" width=\"80%\">\n",
"<figcaption> Figure 7: Linear layer in MXFP8. Calculating both forward and backward pass requires tensors quantized in both directions.</figcaption>\n",
"</figure>"
]
},
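To make the block-scaling idea above concrete, here is an illustrative sketch of MXFP8-style quantization: one power-of-two scale (in the spirit of E8M0) per 32 consecutive values, with the values themselves stored as E4M3. This is not the Transformer Engine kernel; the block size, the use of torch.float8_e4m3fn, and the choice to target the E4M3 maximum of 448 are assumptions made only for illustration:

.. code-block:: python

    import torch

    def quantize_blockwise(x: torch.Tensor, block: int = 32):
        # One power-of-two scale per block of `block` consecutive values.
        assert x.numel() % block == 0
        blocks = x.view(-1, block)
        amax = blocks.abs().amax(dim=1, keepdim=True)
        # Largest power of 2 that keeps the block amax within the E4M3 range (~448).
        scale = torch.exp2(torch.floor(torch.log2(448.0 / amax)))
        return (blocks * scale).to(torch.float8_e4m3fn), scale

    def dequantize_blockwise(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        return (q.float() / scale).view(-1)

    x = torch.randn(1024) * 100
    q, scale = quantize_blockwise(x)
    print((dequantize_blockwise(q, scale) - x).abs().max())  # small per-block rounding error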
{
"cell_type": "markdown",
"id": "cf5e0b0d",
"metadata": {},
"source": [
"## Using FP8 with Transformer Engine\n",
"\n",
"Transformer Engine library provides tools enabling easy to use training with FP8 datatype using delayed scaling strategy.\n",
"Transformer Engine library provides tools enabling easy to use training with FP8 datatype using FP8 delayed scaling and MXFP8 strategies.\n",
"\n",
"### FP8 recipe\n",
"\n",
"[DelayedScaling](../api/common.rst#transformer_engine.common.recipe.DelayedScaling) recipe from `transformer_engine.common.recipe` module stores all of the required options for FP8 training - length of the amax history to use for scaling factor computation, FP8 data format etc."
"The [DelayedScaling](../api/common.rst#transformer_engine.common.recipe.DelayedScaling) recipe from the `transformer_engine.common.recipe` module stores all of the required options for training with FP8 delayed scaling: length of the amax history to use for scaling factor computation, FP8 data format, etc.\n",
"Similarly, [MXFP8BlockScaling](../api/common.rst#transformer_engine.common.recipe.MXFP8BlockScaling) from the same module may be used to enable MXFP8 training."
]
},
{
@@ -77,10 +122,12 @@
"metadata": {},
"outputs": [],
"source": [
"from transformer_engine.common.recipe import Format, DelayedScaling\n",
"from transformer_engine.common.recipe import Format, DelayedScaling, MXFP8BlockScaling\n",
"\n",
"fp8_format = Format.HYBRID # E4M3 during forward pass, E5M2 during backward pass\n",
"fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")"
"fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")\n",
"mxfp8_format = Format.E4M3 # E4M3 used everywhere\n",
"mxfp8_recipe = MXFP8BlockScaling(fp8_format=mxfp8_format)"
]
},
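Either recipe defined in the cell above can then be passed to the FP8 autocast context manager used elsewhere in this primer. A minimal sketch, assuming a GPU that supports the chosen recipe (MXFP8 requires Blackwell) and using te.Linear and te.fp8_autocast:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    layer = te.Linear(768, 768).cuda()
    inp = torch.randn(32, 768, device="cuda")

    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):    # FP8 delayed scaling
        out = layer(inp)

    with te.fp8_autocast(enabled=True, fp8_recipe=mxfp8_recipe):  # MXFP8 (Blackwell only)
        out = layer(inp)

    out.sum().backward()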
{
@@ -341,7 +388,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.12.3"
}
},
"nbformat": 4,
Binary file added docs/examples/linear_mxfp8.png
18 changes: 7 additions & 11 deletions docs/examples/quickstart_utils.py
@@ -3,10 +3,9 @@
# See LICENSE for license information.

import math
from typing import Callable, Optional
from typing import Optional
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import DelayedScaling, dist_group_type


def speedometer(
@@ -204,16 +203,13 @@ def share_parameters_with_transformerlayer_te_model(te_model, basic_model):


def cast_to_representable(inp, scale=1.0, fp8_format="e4m3"):
import transformer_engine.pytorch.cpp_extensions as texcpp
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType

fp8_type = tex.DType.kFloat8E4M3 if fp8_format == "e4m3" else tex.DType.kFloat8E5M2
input_type = TE_DType[inp.dtype]
meta = tex.FP8TensorMeta()
meta.scale = torch.ones(1, dtype=torch.float32, device="cuda") * scale
meta.scale_inv = torch.ones(1, dtype=torch.float32, device="cuda") / scale
meta.amax_history = torch.zeros(1, 1, dtype=torch.float32, device="cuda")
ret = texcpp.cast_to_fp8(inp, meta, tex.FP8FwdTensors.GEMM1_INPUT, fp8_type)
ret = texcpp.cast_from_fp8(ret, meta, tex.FP8FwdTensors.GEMM1_INPUT, fp8_type, input_type)
scale = torch.ones(1, dtype=torch.float32, device="cuda") * scale
amax_history = torch.zeros(1, 1, dtype=torch.float32, device="cuda")
quantizer = Float8Quantizer(scale=scale, amax=amax_history, fp8_dtype=fp8_type)
ret = quantizer(inp)
ret = ret.dequantize()
return ret
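A hedged usage sketch of the updated helper above, assuming a CUDA input tensor (the quantizer's scale and amax buffers are created on the "cuda" device):

.. code-block:: python

    import torch

    x = torch.randn(128, 128, dtype=torch.float32, device="cuda")
    x_rounded = cast_to_representable(x, scale=1.0, fp8_format="e4m3")
    print((x - x_rounded).abs().max())  # error introduced by the E4M3 round-trip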
9 changes: 4 additions & 5 deletions docs/installation.rst
@@ -12,10 +12,9 @@ Prerequisites
.. _driver link: https://www.nvidia.com/drivers

1. Linux x86_64
2. `CUDA 12.0 <https://developer.nvidia.com/cuda-downloads>`__
3. |driver link|_ supporting CUDA 12.0 or later.
4. `cuDNN 8.1 <https://developer.nvidia.com/cudnn>`__ or later.
5. For FP8/FP16/BF16 fused attention, `CUDA 12.1 <https://developer.nvidia.com/cuda-downloads>`__ or later, |driver link|_ supporting CUDA 12.1 or later, and `cuDNN 8.9.1 <https://developer.nvidia.com/cudnn>`__ or later.
2. `CUDA 12.1+ (12.8+ for Blackwell support) <https://developer.nvidia.com/cuda-downloads>`__
3. |driver link|_ supporting CUDA 12.1 or later.
4. `cuDNN 9.3 <https://developer.nvidia.com/cudnn>`__ or later.

If the CUDA Toolkit headers are not available at runtime in a standard
installation path, e.g. within `CUDA_HOME`, set
@@ -76,7 +75,7 @@ Execute the following command to install the latest development build of Transfo
This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable `NVTE_FRAMEWORK` to a comma-separated list (e.g. `NVTE_FRAMEWORK=jax,pytorch`). To only build the framework-agnostic C++ API, set `NVTE_FRAMEWORK=none`.

In order to install a specific PR, execute after changing NNN to the PR number:
In order to install a specific PR, execute (after changing NNN to the PR number):

.. code-block:: bash
19 changes: 15 additions & 4 deletions transformer_engine/common/recipe/__init__.py
@@ -164,13 +164,24 @@ def __repr__(self) -> str:
@dataclass()
class MXFP8BlockScaling(Recipe):
"""
Use the current scaling factor strategy.
Use the MXFP8 scaling factor strategy.
In this strategy, tensors are scaled in blockwise fashion. Each group
of 32 consecutive values is scaled together using their own scaling
factor. The type of the scaling factor is E8M0 (8 bits of exponent,
0 bits of mantissa), equivalent to scaling by a power of 2.
Since the scaling happens in a particular direction (either rowwise
or columnwise), in this recipe the quantized tensor and its transpose
are not numerically equivalent. Due to this, when Transformer Engine
needs both the MXFP8 tensor and its transpose (e.g. to calculate both
forward and backward pass), during the quantization both versions are
computed from the high precision input to avoid double quantization
errors.
Parameters
----------
margin : int, default = 0
Margin for the scaling factor computation.
fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID
fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3
Controls the FP8 data format used during forward and backward
pass.
"""
