[PyTorch] Add documentation for FP8 attention checkpointing (#1223)
* add extra_state change description for different TE versions

Signed-off-by: Charlene Yang <[email protected]>

* add FAQ page

Signed-off-by: Charlene Yang <[email protected]>

* update FAQ page

Signed-off-by: Charlene Yang <[email protected]>

* fix extra_state tests

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes

Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
cyanguwa and pre-commit-ci[bot] authored Oct 9, 2024
1 parent 5b89f1a commit 2d87552
Showing 4 changed files with 216 additions and 68 deletions.
75 changes: 75 additions & 0 deletions docs/faq.rst
@@ -0,0 +1,75 @@
..
    Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
    See LICENSE for license information.

Frequently Asked Questions (FAQ)
================================

FP8 checkpoint compatibility
----------------------------

Transformer Engine added support for FP8 attention in version 1.6. It stores the FP8 metadata, i.e. scaling factors and amax histories, under a `._extra_state` key in the checkpoint. As FP8 attention support has expanded from one backend to multiple backends, the location of the `._extra_state` key has also shifted.

Take the `MultiheadAttention` module as an example. In Transformer Engine 1.11, its FP8 attention metadata is stored as `core_attention._extra_state`, as shown below.

.. code-block:: python

    >>> import torch
    >>> from transformer_engine.pytorch import MultiheadAttention, fp8_model_init
    >>> with fp8_model_init(enabled=True):
    ...     mha = MultiheadAttention(
    ...         hidden_size=1024,
    ...         num_attention_heads=16,
    ...         bias=True,
    ...         params_dtype=torch.bfloat16,
    ...         input_layernorm=False,
    ...         fuse_qkv_params=True,
    ...         attention_type="self",
    ...         qkv_weight_interleaved=True,
    ...     ).to(dtype=torch.bfloat16, device="cuda")
    ...
    >>> state_dict = mha.state_dict()
    >>> print(state_dict.keys())
    odict_keys(['qkv.weight', 'qkv.bias', 'qkv._extra_state', 'core_attention._extra_state', 'proj.weight', 'proj.bias', 'proj._extra_state'])
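
For illustration, the serialized FP8 metadata under this key can be inspected directly. The sketch below mirrors what the sanity test in this commit does; it assumes the `state_dict` from the example above, and the exact keys of the deserialized dictionary (e.g. `extra_fp8_variables`) may differ between versions.

.. code-block:: python

    >>> # The `_extra_state` entry is a serialized buffer: rewind and deserialize it
    >>> # to view the stored scaling factors and amax histories.
    >>> buffer = state_dict["core_attention._extra_state"]
    >>> buffer.seek(0)
    >>> fp8_metadata = torch.load(buffer, map_location="cuda")
    >>> print(fp8_metadata.keys())
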

Below is a full list of the checkpoint save/load behaviors across Transformer Engine versions.

.. list-table::

   * - **Version: <= 1.5**

       - Saves no FP8 metadata since FP8 attention is not supported
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Loads no FP8 metadata
         :> 1.5: Error: unexpected key
   * - **Version: 1.6, 1.7**

       - Saves FP8 metadata to `core_attention.fused_attention._extra_state`
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes
         :1.6, 1.7: Loads FP8 metadata from the checkpoint
         :>= 1.8: Error: unexpected key
   * - **Version: >= 1.8, <= 1.11**

       - Saves FP8 metadata to `core_attention._extra_state`
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes
         :1.6, 1.7: Requires users to map the 1.6/1.7 key to the 1.8-1.11 key; otherwise, initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes. In this `MultiheadAttention` example, the mapping can be done by

           .. code-block:: python

               >>> state_dict["core_attention._extra_state"] = \
                       state_dict["core_attention.fused_attention._extra_state"]
               >>> del state_dict["core_attention.fused_attention._extra_state"]
         :>= 1.8: Loads FP8 metadata from the checkpoint
   * - **Version: >= 1.12**

       - Saves FP8 metadata to `core_attention._extra_state`
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes
         :>= 1.6: Loads FP8 metadata from the checkpoint
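
For models with more than one attention layer, e.g. a full `TransformerLayer` stack, the same remapping can be applied to every matching key before calling `load_state_dict`. The following is a minimal sketch; the helper name, checkpoint path, and `model` object are illustrative, not part of the Transformer Engine API.

.. code-block:: python

    >>> def remap_fp8_attention_extra_state(state_dict):
    ...     """Move TE 1.6/1.7 FP8 attention metadata to the key layout used by TE >= 1.8."""
    ...     old_suffix = "core_attention.fused_attention._extra_state"
    ...     new_suffix = "core_attention._extra_state"
    ...     for key in [k for k in state_dict if k.endswith(old_suffix)]:
    ...         state_dict[key[: -len(old_suffix)] + new_suffix] = state_dict.pop(key)
    ...     return state_dict
    ...
    >>> # Usage (path and model are illustrative): remap an old checkpoint before
    >>> # loading it into a model built with a newer Transformer Engine.
    >>> state_dict = torch.load("checkpoint_te_1_6.pt")
    >>> model.load_state_dict(remap_fp8_attention_extra_state(state_dict))
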
1 change: 1 addition & 0 deletions docs/index.rst
@@ -30,6 +30,7 @@ Transformer Engine documentation

installation
examples/quickstart.ipynb
faq

.. toctree::
:hidden:
165 changes: 104 additions & 61 deletions tests/pytorch/test_sanity.py
@@ -9,6 +9,7 @@
import torch
import pytest
import io
import os

from transformer_engine.pytorch.fp8 import (
fp8_autocast,
@@ -42,6 +43,7 @@
)
from transformer_engine.pytorch.module.base import get_workspace
from test_onnx_export import create_meta
from test_numerics import reset_rng_states, dtype_tols

# Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -1004,84 +1006,125 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_sanity_attention_extra_state(model, dtype):
config = model_configs[model]
outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
outputs_checkpoint_v1_6 = _run_attention_extra_state(
dtype, config, mimic_v1_6=True, checkpoint=True
)

# Check that results match
tols = dtype_tols(dtype)
if dtype in (torch.float16, torch.bfloat16):
tols.update(dict(rtol=2e-2, atol=2e-3))
for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
torch.testing.assert_close(
test,
ref,
**tols,
)
for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)):
torch.testing.assert_close(
test,
ref,
**tols,
)


def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
steps = 10
path = "checkpoint.pt"
fp8_enabled = True
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=True,
fp8_dpa=fp8_enabled,
fp8_mha=False,
)

reset_rng_states()
hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)

with fp8_model_init(enabled=True):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
fuse_qkv_params=True,
params_dtype=dtype,
device="cuda",
)
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
output = block(hidden_states, is_first_microbatch=True)
loss = output.sum()
loss.backward()

# call state_dict()
sd = block.state_dict()

# check core_attention._extra_state
attn_extra_state = sd["self_attention.core_attention._extra_state"]
attn_extra_state.seek(0)
attn_extra_state = torch.load(attn_extra_state, map_location="cuda")

# add random core_attention.fused_attention._extra_state
# it should not be loaded or cause any 'unexpected key' errors
random_state = {"a": 1, "b": 2}
fused_attn_extra_state = io.BytesIO()
torch.save(random_state, fused_attn_extra_state)
sd["self_attention.core_attention.fused_attention._extra_state"] = fused_attn_extra_state

# save checkpoint
path = "./checkpoint.pt"
torch.save(sd, path)

# reinit the model
del block
with fp8_model_init(enabled=True):
block_new = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
fuse_qkv_params=True,
params_dtype=dtype,
device="cuda",
)
FP8GlobalStateManager.reset()
def get_model(dtype, config):
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

with fp8_model_init(enabled=fp8_enabled):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
fuse_qkv_params=True,
params_dtype=dtype,
device="cuda",
)
return block

block = get_model(dtype, config)
for i in range(steps // 2):
with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
output = block(hidden_states, None)
loss = output.sum()
loss.backward()

if checkpoint:
sd = block.state_dict()
if mimic_v1_6:
sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
"self_attention.core_attention._extra_state"
]
del sd["self_attention.core_attention._extra_state"]
torch.save(sd, path)

param_grads = []
for p in block.parameters():
if p.requires_grad:
param_grads.append(p.grad.clone())

_cpu_rng_state_new = torch.get_rng_state()
_cuda_rng_state_new = torch.cuda.get_rng_state()

del block
block = get_model(dtype, config)
block.load_state_dict(torch.load(path))
torch.set_rng_state(_cpu_rng_state_new)
torch.cuda.set_rng_state(_cuda_rng_state_new)

for p in block.parameters():
if p.requires_grad:
p.grad = param_grads.pop(0)

assert not param_grads, "Oops!"

for i in range((steps + 1) // 2):
with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
output = block(hidden_states, None)
loss = output.sum()
loss.backward()

torch.cuda.synchronize()

if os.path.exists(path):
os.remove(path)

outputs = [output, hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)

# load from checkpoint
block_new.load_state_dict(torch.load(path))

# check state_dict
sd_new = block_new.state_dict()
attn_extra_state_new = sd_new["self_attention.core_attention._extra_state"]
attn_extra_state_new.seek(0)
attn_extra_state_new = torch.load(attn_extra_state_new, map_location="cuda")
for k, v in attn_extra_state_new.items():
if k != "extra_fp8_variables":
assert torch.equal(v, attn_extra_state[k]), f"{k} is not equal"
else:
for ek, ev in attn_extra_state_new["extra_fp8_variables"].items():
assert ev == attn_extra_state["extra_fp8_variables"][ek], f"{ek} is not equal"
return outputs
43 changes: 36 additions & 7 deletions transformer_engine/pytorch/attention.py
@@ -6790,10 +6790,10 @@ def __init__(
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
"""
Temporarily remove fused_attention._extra_state as a missing key
or an unexpected key when loading TransformerEngine checkpoints.
or an unexpected key when loading Transformer Engine checkpoints.
Please store FP8 metadata as DotProductAttention's _extra_state,
rather than FusedAttention's _extra_state. This hook will be
phased out in TransformerEngine 2.0.
phased out in Transformer Engine 2.0.
"""
for key in incompatible_keys.missing_keys:
if "fused_attention._extra_state" in key:
@@ -7023,6 +7023,13 @@ class DotProductAttention(TransformerEngineBaseModule):
and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
to disable`flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
.. note::
Transformer Engine stores the FP8 metadata under a `._extra_state` key when checkpointing.
As the FP8 attention support expands from one backend to multiple backends, the location
of that key has also shifted (see `FP8 checkpoint compatibility <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_).
Parameters
----------
num_attention_heads : int
@@ -7051,7 +7058,7 @@ class DotProductAttention(TransformerEngineBaseModule):
e.g. a different mask for training and inference.
1. For "`no_mask`", no attention mask is applied.
2. For "`causal`", "`causal_bottom_right`", or the causal mask in
"`padding_causal`" and "`padding_causal_bottom_right`", TransformerEngine
"`padding_causal`" and "`padding_causal_bottom_right`", Transformer Engine
calculates and applies an upper triangular mask to the softmax input.
No user input is needed. Causal masks without the "`bottom_right`" appendix align
the diagonal line to the top left corner of the softmax matrix. With
@@ -7264,15 +7271,37 @@ def __init__(
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
"""
Temporarily remove core_attention._extra_state as a missing key
when loading older TransformerEngine checkpoints. Will phase out
this hook in TransformerEngine 2.0.
when loading older Transformer Engine checkpoints. Will phase out
this hook in Transformer Engine 2.0.
"""
for key in incompatible_keys.missing_keys:
if "core_attention._extra_state" in key:
incompatible_keys.missing_keys.remove(key)

self.register_load_state_dict_post_hook(remove_extra_states_check)

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
"""
This function helps to load Transformer Engine 1.6 and 1.7 checkpoints, where FP8 attention
metadata is stored under the `core_attention.fused_attention._extra_state` key and not the
`core_attention._extra_state` key. Please see `FP8 checkpoint compatibility
<https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_ for more details.
"""
fused_attn_key = False
dot_product_attn_key = False
for k in state_dict.keys():
if "core_attention.fused_attention._extra_state" in k:
fused_attn_key = True
if "core_attention._extra_state" in k:
dot_product_attn_key = True
if fused_attn_key and not dot_product_attn_key:
prefix = prefix + "fused_attention."
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)

def _checkpointed_attention_forward(
self,
attention_func: Callable,
@@ -7382,14 +7411,14 @@ def forward(
Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`,
and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend,
and FusedAttention backend if applicable, to use. TransformerEngine prioritizes
and FusedAttention backend if applicable, to use. Transformer Engine prioritizes
FlashAttention over FusedAttention and over UnfusedDotProductAttention.
If FusedAttention is being used, users can also choose to switch to flash-attn's
implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
(default: 0), because of the performance differences between various versions of
flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT`
can be used to enable (:attr:`1`) or disable (:attr:`0`) the workspace related
optimizations in FusedAttention. When unset, TransformerEngine determines the code path
optimizations in FusedAttention. When unset, Transformer Engine determines the code path
based on its internal logic. These optimizations trade memory for performance
and should be used with care.
