[PyTorch] Add documentation for FP8 attention checkpointing (#1223)
* add extra_state change description for different TE versions
* add FAQ page
* update FAQ page
* fix extra_state tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* minor fixes

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5b89f1a; commit 2d87552; 4 changed files with 216 additions and 68 deletions.
@@ -0,0 +1,75 @@
..
    Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
    See LICENSE for license information.

Frequently Asked Questions (FAQ)
================================

FP8 checkpoint compatibility
----------------------------

Transformer Engine added support for FP8 attention in version 1.6. It stores the FP8 metadata, i.e. scaling factors and amax histories, under a `._extra_state` key in the checkpoint. As FP8 attention support expanded from one backend to multiple backends, the location of the `._extra_state` key has also shifted.

Here, we take the `MultiheadAttention` module as an example. In Transformer Engine 1.11, its FP8 attention metadata is stored as `core_attention._extra_state`, as shown below.
.. code-block:: python

    >>> import torch
    >>> from transformer_engine.pytorch import MultiheadAttention, fp8_model_init
    >>> with fp8_model_init(enabled=True):
    ...     mha = MultiheadAttention(
    ...         hidden_size=1024,
    ...         num_attention_heads=16,
    ...         bias=True,
    ...         params_dtype=torch.bfloat16,
    ...         input_layernorm=False,
    ...         fuse_qkv_params=True,
    ...         attention_type="self",
    ...         qkv_weight_interleaved=True,
    ...     ).to(dtype=torch.bfloat16, device="cuda")
    ...
    >>> state_dict = mha.state_dict()
    >>> print(state_dict.keys())
    odict_keys(['qkv.weight', 'qkv.bias', 'qkv._extra_state', 'core_attention._extra_state', 'proj.weight', 'proj.bias', 'proj._extra_state'])

Here is a full list of the checkpoint save/load behaviors for all Transformer Engine versions.
.. list-table::

   * - **Version: <= 1.5**

       - Saves no FP8 metadata since FP8 attention is not supported
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Loads no FP8 metadata
         :> 1.5: Error: unexpected key

   * - **Version: 1.6, 1.7**

       - Saves FP8 metadata to `core_attention.fused_attention._extra_state`
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes
         :1.6, 1.7: Loads FP8 metadata from the checkpoint
         :>= 1.8: Error: unexpected key

   * - **Version: >= 1.8, <= 1.11**

       - Saves FP8 metadata to `core_attention._extra_state`
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes
         :1.6, 1.7: Requires users to map the 1.6/1.7 key to the 1.8-1.11 key; otherwise, initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes. In this `MultiheadAttention` example, the mapping can be done by

           .. code-block:: python

               >>> state_dict["core_attention._extra_state"] = \
               ...     state_dict["core_attention.fused_attention._extra_state"]
               >>> del state_dict["core_attention.fused_attention._extra_state"]

         :>= 1.8: Loads FP8 metadata from the checkpoint

   * - **Version: >= 1.12**

       - Saves FP8 metadata to `core_attention._extra_state`
       - Loading behavior for checkpoints created by the following versions:

         :<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes
         :>= 1.6: Loads FP8 metadata from the checkpoint
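
The key remapping above generalizes to modules nested inside a larger model, where the `_extra_state` keys carry a module prefix. Below is a minimal sketch of a helper that upgrades a 1.6/1.7-era checkpoint before loading it into Transformer Engine 1.8 or later; the helper name and the prefix handling are illustrative assumptions, not part of the Transformer Engine API.

.. code-block:: python

    # Hypothetical helper (not part of Transformer Engine): remap the FP8 attention
    # metadata key used by TE 1.6/1.7 to the location expected by TE >= 1.8.
    def remap_fp8_attention_keys(state_dict):
        old_suffix = "core_attention.fused_attention._extra_state"  # TE 1.6/1.7 location
        new_suffix = "core_attention._extra_state"                  # TE >= 1.8 location
        remapped = {}
        for key, value in state_dict.items():
            if key.endswith(old_suffix):
                # Preserve any module prefix, e.g. "decoder.layers.0.self_attention."
                key = key[: -len(old_suffix)] + new_suffix
            remapped[key] = value
        return remapped

    # Usage sketch: upgrade an old checkpoint, then load it as usual.
    # state_dict = remap_fp8_attention_keys(torch.load("old_checkpoint.pt"))
    # mha.load_state_dict(state_dict)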
@@ -30,6 +30,7 @@ Transformer Engine documentation
   installation
   examples/quickstart.ipynb
   faq

.. toctree::
   :hidden: