diff --git a/benchmarks/attention/benchmark_attention.py b/benchmarks/attention/benchmark_attention.py
index e5df485eda..bfd7bf8471 100644
--- a/benchmarks/attention/benchmark_attention.py
+++ b/benchmarks/attention/benchmark_attention.py
@@ -11,9 +11,7 @@
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
ModelConfig,
- _is_flash_attention_supported,
- _is_fused_attention_supported,
- _is_unfused_attention_supported,
+ _get_attention_backends,
_run_dot_product_attention,
)
@@ -29,8 +27,6 @@
workspace_opt = True
# QKV memory layout
qkv_layout = "bshd_bshd_bshd"
-# sliding window attention
-swa = False
# padding between sequences for qkv_format=thd
pad_between_seqs = False
# training mode
@@ -64,7 +60,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
ckpt_attn,
qkv_layout,
workspace_opt,
- swa,
pad_between_seqs,
is_training,
)
@@ -76,7 +71,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
ckpt_attn,
qkv_layout,
workspace_opt,
- swa,
pad_between_seqs,
is_training,
)
@@ -97,7 +91,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
ckpt_attn,
qkv_layout,
workspace_opt,
- swa,
pad_between_seqs,
is_training,
)
@@ -115,7 +108,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
ckpt_attn,
qkv_layout,
workspace_opt,
- swa,
pad_between_seqs,
is_training,
)
@@ -205,13 +197,15 @@ def main():
)
for model in model_configs.keys():
config = model_configs[model]
- fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
+ available_backends, fused_attn_backends = _get_attention_backends(
config,
- dtype,
+ qkv_dtype=dtype,
qkv_layout=qkv_layout,
+ window_size=config.window_size,
+ pad_between_seqs=pad_between_seqs,
)
- fused_attn_supported = fused_attn_supported and not swa
- flash_attn_supported = _is_flash_attention_supported(config)
+ flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
+
print(
f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
f'{" and flash-attention" if flash_attn_supported else ""}...'
diff --git a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py
index cd8ab85ba2..85ce01079c 100644
--- a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py
+++ b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py
@@ -6,7 +6,6 @@
import torch
from typing import Tuple
from tests.pytorch.fused_attn.test_fused_attn import ModelConfig
-from transformer_engine.pytorch.distributed import _set_cuda_rng_state
from transformer_engine.pytorch.attention import DotProductAttention
# Initialize RNG state
@@ -22,7 +21,7 @@
def reset_rng_states() -> None:
"""Revert back to initial RNG state"""
torch.set_rng_state(_cpu_rng_state)
- _set_cuda_rng_state(_cuda_rng_state)
+ torch.cuda.set_rng_state(_cuda_rng_state)
def _run_dot_product_attention(
@@ -40,7 +39,7 @@ def _run_dot_product_attention(
[config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
)
inp = torch.randn(
- [config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim],
+ [config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk],
dtype=dtype,
device="cuda",
)
@@ -51,7 +50,7 @@ def _run_dot_product_attention(
k.requires_grad = True
v.requires_grad = True
out_grad = torch.randn(
- [config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim],
+ [config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim_v],
dtype=dtype,
device="cuda",
)
@@ -80,7 +79,7 @@ def _run_dot_product_attention(
block = DotProductAttention(
config.num_heads,
- config.head_dim,
+ config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
qkv_format="bshd",
attention_dropout=config.dropout_p,
@@ -89,6 +88,8 @@ def _run_dot_product_attention(
get_rng_state_tracker=None,
tp_group=None,
layer_number=1,
+ attn_mask_type="no_mask",
+ window_size=(-1, -1),
).to(dtype=dtype, device="cuda")
# Run a forward and backward pass
@@ -103,6 +104,7 @@ def _run_dot_product_attention(
attn_mask_type=config.attn_mask_type, # 'arbitrary'
core_attention_bias_type=config.attn_bias_type, # 'no_bias'
core_attention_bias=bias, # None
+ window_size=(-1, -1),
)
out.backward(out_grad)
@@ -116,6 +118,7 @@ def _run_dot_product_attention(
attn_mask_type=config.attn_mask_type, # no_mask
core_attention_bias_type=config.attn_bias_type, # 'post_scale_bias'
core_attention_bias=bias, # bias
+ window_size=(-1, -1),
)
out.backward(out_grad)
@@ -133,6 +136,7 @@ def _run_dot_product_attention(
config = model_configs["test_bias"]
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd")
+print()
print("Run with arbitrary mask:")
config = model_configs["test_mask"]
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd")
@@ -140,4 +144,6 @@ def _run_dot_product_attention(
torch.testing.assert_close(unfused_attn_fwd, fused_attn_fwd, atol=2.5e-2, rtol=2.5e-2)
for i in range(3):
torch.testing.assert_close(unfused_attn_bwd[i], fused_attn_bwd[i], atol=2.5e-2, rtol=2.5e-2)
+
+print()
print("Test passed!")
diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb
index 515f420790..27017b4773 100644
--- a/docs/examples/attention/attention.ipynb
+++ b/docs/examples/attention/attention.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
- "id": "8ae3bc43",
+ "id": "040f466a",
"metadata": {},
"source": [
"# Attention Is All You Need!\n",
@@ -23,7 +23,7 @@
},
{
"cell_type": "markdown",
- "id": "47421c01",
+ "id": "89a7d849",
"metadata": {},
"source": [
"## 1. Attention Backends\n",
@@ -71,7 +71,7 @@
},
{
"cell_type": "markdown",
- "id": "e52f60f0",
+ "id": "c90a2573",
"metadata": {},
"source": [
"### 1.1 Flash vs. Non-Flash\n",
@@ -85,30 +85,30 @@
"- **Recomputation:** The non-flash algorithm stores the softmax matrix (quadratic to sequence length) to global memory for the backward pass, while the flash algorithm only saves the softmax normalization factors (linear to sequence length). This reduces the amount of memory required as well as the bandwidth utilization between global memory and shared memory. Even though there is extra computation incurred in order to recalculate the attention in the backward pass, the bandwidth savings still provide significant improvement in efficiency.\n",
"\n",
"
\n",
- "Note \n",
+ "Note: \n",
" \n",
- "Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), in PyTorch, JAX and PaddlePaddle, are both based on the flash algorithm.\n",
+ "Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), available in PyTorch, JAX and PaddlePaddle, are both based on the flash algorithm.\n",
"
\n"
]
},
{
"cell_type": "markdown",
- "id": "bb909ac4",
+ "id": "b5ce567d",
"metadata": {},
"source": [
"### 1.2 flash-attention\n",
"\n",
"The flash-attention backend, available only in PyTorch, is a module wrapped around the public `flash-attn` package [[3]](https://github.com/Dao-AILab/flash-attention). \n",
"\n",
- "The flash-attention backend supports `flash-attn`'s features as they are released, and to facilitate the use of `flash-attn`, flash-attention also offers a few functionalities such as converting the `attention_mask` to cumulative sequence lengths `cu_seqlens` for `padding` mask. Please see `transformer_engine.pytorch.attention.FlashAttention` for more details.\n",
+ "The flash-attention backend supports `flash-attn`'s features as well as a few extra functionalities to facilitate the use of `flash-attn`, such as converting the `attention_mask` to cumulative sequence lengths `cu_seqlens` for `padding` mask use cases. Please see `transformer_engine.pytorch.attention.FlashAttention` for details.\n",
"\n",
- "The `flash-attn` dependency is regularly updated in Transformer Engine. As of v1.7, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n",
+ "The `flash-attn` dependency is regularly updated in Transformer Engine. As of v1.10, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n",
"\n",
- "To understand `flash-attn`'s performance, please refer to their [benchmarks](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#performance).\n",
+ "To understand `flash-attn`'s performance, please refer to their benchmarks [here](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#performance).\n",
"\n",
"### 1.3 cuDNN Attention\n",
"\n",
- "The cuDNN attention backend, available in PyTorch, JAX and PaddlePaddle, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths. Out of the three, sub-backends 1 and 2 are based on the flash algorithm, as `flash-attn` is.\n",
+ "The cuDNN attention backend, available in PyTorch, JAX and PaddlePaddle, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths.\n",
"\n",
"\n",
" \n",
@@ -153,14 +153,14 @@
"
\n",
"
\n",
"\n",
- "The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 1.7, cuDNN 9.0 and `flash-attn` 2.4.2,\n",
+ "The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 1.10, cuDNN 9.3 and `flash-attn` 2.4.2,\n",
"\n",
"- flash-attention only supports the PyTorch framework while cuDNN attention supports PyTorch, JAX and PaddlePaddle.\n",
"- flash-attention supports BF16, FP16 precisions while cuDNN attention also supports FP8 (through its sub-backend 2).\n",
- "- flash-attention supports `bshd`, `thd` input formats, without any transposes, and `sbhd` format, with transposes, while cuDNN attention supports all three without transposes (see Section 3.1 for more details).\n",
+ "- flash-attention supports `bshd`, `thd` input formats, without any transposes, and `sbhd` format, with transposes, while cuDNN attention supports all three formats without transposes (see Section 3.1 for more details).\n",
"- flash-attention does not support `post_scale_bias`, and cuDNN attention does.\n",
- "- flash-attention supports sliding window attention, and cuDNN attention does not.\n",
- "- flash-attention uses bottom right diagonal for `causal` mask in cross attention, and cuDNN attention uses top left (see `flash-attn`'s [change log](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)).\n",
+ "- flash-attention supports KV-caching and paged attention, and cuDNN attention does not.\n",
+ "- flash-attention uses bottom right diagonal for `causal` mask in cross attention (see [change log](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)), and cuDNN attention supports both top left and bottom right.\n",
"- flash-attention outperforms cuDNN attention on Ampere architectures, and cuDNN attention has 20-50% advantages on Hopper architectures, based on our benchmarks for a number of commonly-used model configurations.\n",
"\n",
"To compare cuDNN attention and flash-attention, users can modify the `model_configs` dictionary in [benchmarks/attention/benchmark_attention.py](https://github.com/NVIDIA/TransformerEngine/blob/main/benchmarks/attention/benchmark_attention.py) to collect performance numbers. The script runs each entry in `model_configs` for `num_iters` times, each time with one forward pass and one backward pass. Both backends are tried, and if one backend does not have support for the specific user input, the runtimes and speedups in the final table would be 0."
@@ -169,7 +169,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "9a380859",
+ "id": "c5b8e3d7",
"metadata": {},
"outputs": [],
"source": [
@@ -184,25 +184,25 @@
},
{
"cell_type": "code",
- "execution_count": 2,
- "id": "0584bb01",
+ "execution_count": 1,
+ "id": "50852cb5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Device 0: NVIDIA H100 PCIe GPU, sm90 compute capability, 79.1GB memory\n",
+ "Device 0: NVIDIA H100 80GB HBM3 GPU, sm90 compute capability, 79.1GB memory\n",
"Running test_0 with cuDNN attention and flash-attention...\n",
"Running test_1 with cuDNN attention and flash-attention...\n",
"Running test_2 with cuDNN attention...\n",
"Running test_3 with cuDNN attention and flash-attention...\n",
"\n",
" cuDNN fwd+bwd (ms) flash-attn fwd+bwd (ms) cuDNN vs flash speedup\n",
- "test_0 0.0638 0.0858 1.3454\n",
- "test_1 0.5415 0.7496 1.3842\n",
- "test_2 1.2302 0.0000 0.0000\n",
- "test_3 12.0122 19.0716 1.5877\n"
+ "test_0 0.0340 0.0468 1.3786\n",
+ "test_1 0.3664 0.5850 1.5968\n",
+ "test_2 0.9332 0.0000 0.0000\n",
+ "test_3 7.4875 11.8879 1.5877\n"
]
}
],
@@ -212,7 +212,7 @@
},
{
"cell_type": "markdown",
- "id": "45e53fc9",
+ "id": "9a615119",
"metadata": {},
"source": [
"## 2. Backend Selection\n",
@@ -253,35 +253,35 @@
},
{
"cell_type": "markdown",
- "id": "6dfeade3",
+ "id": "e6c0f3f0",
"metadata": {},
"source": [
"### 2.1 Debug Information\n",
"\n",
- "To find out which backend is being used during runtime, users can turn on these debugging flags. Logging is done using the `logging` package.\n",
+ "To find out which backend is being used during runtime, we have the following two debugging flags. Logging is done by using the `logging` package.\n",
"```\n",
"NVTE_DEBUG = 0/1 # disables/enables debugging\n",
"NVTE_DEBUG_LEVEL = 0/1/2 # enables logging.WARNING/INFO/DEBUG-level messages\n",
"```\n",
"\n",
- "Note\n",
+ "Note:\n",
" \n",
- "These flags are supported in PyTorch only as of Transformer Engine 1.7. JAX and PaddlePaddle support is expected to be added in the future.\n",
+ "These flags are supported in PyTorch only as of Transformer Engine 1.10. JAX and PaddlePaddle support is expected to be added in the future.\n",
"
"
]
},
{
"cell_type": "markdown",
- "id": "7e3b7981",
+ "id": "16660323",
"metadata": {},
"source": [
- "The [example_attention.py](https://raw.githubusercontent.com/NVIDIA/TransformerEngine/main/docs/examples/attention/example_attention.py) script runs a very basic model with two attention backends, cuDNN attention and flash-attention. Here `NVTE_DEBUG_LEVEL=1` allows us to find out which backend/sub-backend was actually used during runtime."
+ "The example script [example_attention.py](https://raw.githubusercontent.com/NVIDIA/TransformerEngine/main/docs/examples/attention/example_attention.py) runs a very basic model with two attention backends, cuDNN attention and flash-attention. Here `NVTE_DEBUG_LEVEL=1` allows us to find out which backend/sub-backend is used in runtime."
]
},
{
"cell_type": "code",
- "execution_count": 22,
- "id": "961c51d4",
+ "execution_count": 24,
+ "id": "906b8cf1",
"metadata": {},
"outputs": [
{
@@ -293,7 +293,7 @@
"[INFO | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)\n",
"\n",
"Run flash-attention...\n",
- "[INFO | DotProductAttention]: Running with FlashAttention backend \n",
+ "[INFO | DotProductAttention]: Running with FlashAttention backend\n",
"\n",
"Test passed.\n"
]
@@ -305,16 +305,16 @@
},
{
"cell_type": "markdown",
- "id": "11bfbbd7",
+ "id": "8ca99461",
"metadata": {},
"source": [
- "To collect more information, users can turn on `NVTE_DEBUG_LEVEL=2`. In this example, it allows us to find out more about the run config. Users are encouraged to provide if users intend to file a bug with Transformer Engine. For example, "
+ "`NVTE_DEBUG_LEVEL=2` allows us to find out more about the backend selection logic. Users are encouraged to double check the `config` and provide it to the Transformer Engine team if they would like to file a bug. "
]
},
{
"cell_type": "code",
- "execution_count": 25,
- "id": "162a2be1",
+ "execution_count": 23,
+ "id": "d3637094",
"metadata": {},
"outputs": [
{
@@ -323,16 +323,18 @@
"text": [
"\n",
"Run cuDNN attention...\n",
+ "[DEBUG | DotProductAttention]: Running with config={'transformer_engine_version': '1.10.0.dev0+ee85a91', 'compute_capability': 'sm90', 'flash_attn_version': , 'cudnn_version': '9.3.0', 'qkv_type': , 'qkv_dtype': torch.bfloat16, 'qkv_layout': 'bshd_bshd_bshd', 'batch_size': 2, 'num_heads': 16, 'num_gqa_groups': 16, 'max_seqlen_q': 512, 'max_seqlen_kv': 512, 'head_dim_qk': 64, 'head_dim_v': 64, 'attn_mask_type': 'no_mask', 'window_size': (-1, -1), 'alibi_slopes_shape': None, 'core_attention_bias_type': 'no_bias', 'core_attention_bias_shape': None, 'core_attention_bias_requires_grad': False, 'pad_between_seqs': False, 'attention_dropout': 0.0, 'context_parallel': False, 'deterministic': False, 'is_training': True, 'fp8': False, 'fp8_meta': {'fp8_checkpoint': False, 'fp8_group': None, 'recipe': margin=0, format=HYBRID, amax_history_len=1024, wgrad_override=False, fp8_dpa=False, fp8_mha=False}}\n",
"[DEBUG | DotProductAttention]: Disabling FlashAttention due to NVTE_FLASH_ATTN=0\n",
+ "[DEBUG | DotProductAttention]: Available backends = {FlashAttention=False, FusedAttention=True (sub-backend 1), UnfusedDotProductAttention=True}\n",
+ "[DEBUG | DotProductAttention]: Selected backend = FusedAttention (sub-backend 1)\n",
"[INFO | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)\n",
- "[DEBUG | DotProductAttention]: Running with {'compute_capability': 'sm90', 'q_dtype': torch.bfloat16, 'k_dtype': torch.bfloat16, 'v_dtype': torch.bfloat16, 'q_shape': [2, 512, 16, 64], 'k_shape': [2, 512, 16, 64], 'v_shape': [2, 512, 16, 64], 'qkv_format': 'bshd', 'qkv_layout': 'bshd_bshd_bshd', 'mask_type': 'no_mask', 'bias_type': 'no_bias', 'bias_shape': None, 'dropout': 0.0, 'context_parallel': False, 'is_training': True, 'transformer_engine_version': , 'flash_attn_version': , 'cudnn_version': '9.2.0'}\n",
- "[DEBUG | FusedAttnFunc ]: Running forward in torch.bfloat16\n",
- "[DEBUG | FusedAttnFunc ]: Running backward in torch.bfloat16\n",
"\n",
"Run flash-attention...\n",
+ "[DEBUG | DotProductAttention]: Running with config={'transformer_engine_version': '1.10.0.dev0+ee85a91', 'compute_capability': 'sm90', 'flash_attn_version': , 'cudnn_version': '9.3.0', 'qkv_type': , 'qkv_dtype': torch.bfloat16, 'qkv_layout': 'bshd_bshd_bshd', 'batch_size': 2, 'num_heads': 16, 'num_gqa_groups': 16, 'max_seqlen_q': 512, 'max_seqlen_kv': 512, 'head_dim_qk': 64, 'head_dim_v': 64, 'attn_mask_type': 'no_mask', 'window_size': (-1, -1), 'alibi_slopes_shape': None, 'core_attention_bias_type': 'no_bias', 'core_attention_bias_shape': None, 'core_attention_bias_requires_grad': False, 'pad_between_seqs': False, 'attention_dropout': 0.0, 'context_parallel': False, 'deterministic': False, 'is_training': True, 'fp8': False, 'fp8_meta': {'fp8_checkpoint': False, 'fp8_group': None, 'recipe': margin=0, format=HYBRID, amax_history_len=1024, wgrad_override=False, fp8_dpa=False, fp8_mha=False}}\n",
"[DEBUG | DotProductAttention]: Disabling FusedAttention due to NVTE_FUSED_ATTN=0\n",
- "[INFO | DotProductAttention]: Running with FlashAttention backend \n",
- "[DEBUG | DotProductAttention]: Running with {'compute_capability': 'sm90', 'q_dtype': torch.bfloat16, 'k_dtype': torch.bfloat16, 'v_dtype': torch.bfloat16, 'q_shape': [2, 512, 16, 64], 'k_shape': [2, 512, 16, 64], 'v_shape': [2, 512, 16, 64], 'qkv_format': 'bshd', 'qkv_layout': 'bshd_bshd_bshd', 'mask_type': 'no_mask', 'bias_type': 'no_bias', 'bias_shape': None, 'dropout': 0.0, 'context_parallel': False, 'is_training': True, 'transformer_engine_version': , 'flash_attn_version': , 'cudnn_version': '9.2.0'}\n",
+ "[DEBUG | DotProductAttention]: Available backends = {FlashAttention=True, FusedAttention=False, UnfusedDotProductAttention=True}\n",
+ "[DEBUG | DotProductAttention]: Selected backend = FlashAttention\n",
+ "[INFO | DotProductAttention]: Running with FlashAttention backend\n",
"\n",
"Test passed.\n"
]
@@ -344,7 +346,7 @@
},
{
"cell_type": "markdown",
- "id": "779a51e6",
+ "id": "611d8fdb",
"metadata": {},
"source": [
"### 2.2 User Control\n",
@@ -392,28 +394,29 @@
},
{
"cell_type": "markdown",
- "id": "ccd5650d",
+ "id": "e60a2a3e",
"metadata": {},
"source": [
"## 3. Backend Support\n",
"\n",
- "Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v1.7, Transformer Engine's attention backends have the following support matrix.\n",
+ "Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v1.10, Transformer Engine's attention backends have the following support matrix.\n",
"\n",
- "| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Context Parallelism | Determinism Possible |\n",
- "| :---------------- | :-------- | :----------- | :----------------------- | :------ | :------------------ | :------------ |\n",
- "| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | No | Yes | Yes (only for `bshd`,`sbhd`) | Yes |\n",
- "| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | Yes (only for `bshd`,`thd`) | Yes |\n",
- "| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | No | Yes |\n",
+ "| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Multi-Latent Attention | Context Parallelism | Determinism Possible |\n",
+ "| :---------------- | :-------- | :----------- | :----------------------- | :------ | :--------------------- | :------------------ | :------------ |\n",
+ "| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | No | Yes | Yes | Yes (`bshd`,`sbhd`, `thd`) | Yes |\n",
+ "| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | No | Yes (`bshd`,`thd`) | Yes |\n",
+ "| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n",
"\n",
"Some unit tests are provided to serve as a starting point for integrating such features into users' models. For example,\n",
"- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
"- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
+ "- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
"- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py)"
]
},
{
"cell_type": "markdown",
- "id": "8439b389",
+ "id": "fbdcb327",
"metadata": {},
"source": [
"### 3.1 QKV Layout\n",
@@ -439,7 +442,7 @@
"**qkv_layout=thd_thd_thd:**\n",
"`q`, `k`, `v` have variable sequence lengths in a batch. They are all contiguous and have no interleaving.\n",
"\n",
- "As of v1.7, Transformer Engine has the following support matrix.\n",
+ "As of v1.10, Transformer Engine has the following support matrix.\n",
"\n",
"\n",
" \n",
@@ -480,16 +483,16 @@
},
{
"cell_type": "markdown",
- "id": "0290f8e9",
+ "id": "855d9616",
"metadata": {},
"source": [
"### 3.2 Attention Mask\n",
"\n",
- "Transformer Engine supports 5 mask types, and all the masks are defined as `True` masking out the corresponding element and `False` including the corresponding element in attention calculation.\n",
+ "Transformer Engine supports 7 mask types, and all the masks are defined as `True` masking out the corresponding element and `False` including the corresponding element in attention calculation.\n",
"\n",
- "- `no_mask`, `padding`, `causal`, `padding_causal` (equivalent to `causal_padding`), `arbitrary`\n",
+ "- `no_mask`, `padding`, `causal`, `causal_bottom_right`, `padding_causal`, `padding_causal_bottom_right`, `arbitrary`\n",
"\n",
- "Different backends offer different support for attention mask. As of Transformer Engine 1.7,\n",
+ "Different backends offer different support for attention mask. As of Transformer Engine 1.10,\n",
"\n",
"\n",
" \n",
@@ -498,34 +501,25 @@
" Requires `attention_mask` | \n",
"
\n",
" \n",
- " flash-attention | \n",
- " `no_mask`, `causal`, `padding`, `padding_causal` | \n",
- " `no_mask`, `causal`: No | \n",
- "
\n",
- " \n",
- " `padding`, `padding_causal`: Yes if `cu_seqlens` not provided | \n",
- "
\n",
- " \n",
- " cuDNN attention | \n",
- " `no_mask`, `causal`, `padding`, `padding_causal` | \n",
- " `no_mask`, `causal`: No | \n",
+ " flash-attention | \n",
+ " `no_mask`, `causal` (self-attention),`padding`, `padding_causal` (self-attention),`causal_bottom_right`, `padding_causal_bottom_right` | \n",
+ " `no_mask`, `causal` `causal_bottom_right`: No`padding`, `padding_causal`, `padding_causal_bottom_right`: Yes if `cu_seqlens` not provided`arbitrary`: Yes | \n",
"
\n",
" \n",
- " \n",
- " `padding`, `padding_causal`: Yes if `cu_seqlens` not provided\n",
- " | \n",
+ " cuDNN attention | \n",
+ " `no_mask`, `causal`,`padding`, `padding_causal`,`causal_bottom_right`, `padding_causal_bottom_right` | \n",
+ " | \n",
"
\n",
" \n",
- " Framework-native attention | \n",
- " `no_mask`, `causal`, `arbitrary` | \n",
- " `no_mask`, `causal`: No | \n",
+ " Framework-native attention | \n",
+ " All (PyTorch)`no_mask`, `causal`, `padding` (Jax, PaddlePaddle) | \n",
"
\n",
" \n",
- " `arbitrary`: Yes | \n",
+ " | \n",
"
\n",
"
\n",
"\n",
- "**padding and padding_causal:** For these two mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 1.7, there are two options to do so in PyTorch and one in JAX and PaddlePaddle.\n",
+ "**Padding masks:** For `padding`, `padding_causal`, `padding_causal_bottom_right` mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 1.10, there are two options to do so in PyTorch and one in JAX and PaddlePaddle.\n",
"\n",
"* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n",
" - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n",
@@ -536,13 +530,13 @@
"\n",
"**qkv_format=thd:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n",
"\n",
- "**Arbitrary mask:** cuDNN does not support `Arbitrary` mask type as of v9.0. However, users can convert the mask to a regular `post_scale_bias` bias and achieve the same functionality. An example script for this conversion is [arbitrary_mask_to_post_scale_bias.py](https://raw.githubusercontent.com/NVIDIA/TransformerEngine/main/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py).\n"
+ "**Arbitrary mask:** cuDNN does not support `Arbitrary` mask type as of v9.3. However, users can convert the mask to a regular `post_scale_bias` bias and achieve the same functionality. An example script for this conversion is [arbitrary_mask_to_post_scale_bias.py](https://raw.githubusercontent.com/NVIDIA/TransformerEngine/main/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py).\n"
]
},
{
"cell_type": "code",
- "execution_count": 6,
- "id": "b1b7cdd4",
+ "execution_count": 33,
+ "id": "a1f25a9b",
"metadata": {},
"outputs": [
{
@@ -550,27 +544,29 @@
"output_type": "stream",
"text": [
"Run with post_scale_bias:\n",
- "[DotProductAttention]: using cuDNN attention (sub-backend 1)\n",
+ "[INFO | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)\n",
+ "\n",
"Run with arbitrary mask:\n",
- "[DotProductAttention]: using unfused DPA\n",
+ "[INFO | DotProductAttention]: Running with UnfusedDotProductAttention backend\n",
+ "\n",
"Test passed!\n"
]
}
],
"source": [
- "!NVTE_DEBUG=1 python arbitrary_mask_to_post_scale_bias.py"
+ "!NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python arbitrary_mask_to_post_scale_bias.py"
]
},
{
"cell_type": "markdown",
- "id": "e045c284",
+ "id": "dda4a589",
"metadata": {},
"source": [
"Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py).\n",
"\n",
"### 3.3 Attention Bias\n",
"\n",
- "Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 1.7, their support matrix is as follows.\n",
+ "Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 1.10, their support matrix is as follows.\n",
"\n",
"\n",
" \n",
@@ -617,25 +613,20 @@
},
{
"cell_type": "markdown",
- "id": "8b8a4e40",
+ "id": "a0702339",
"metadata": {},
"source": [
"### 3.4 FP8 Attention\n",
"\n",
"A unique feature of Transformer Engine is its FP8 support, not only for the `Linear` layers but also for dot product attention. Transformer Engine's FP8 attention support is through its cuDNN attention sub-backend 2. Recall Figure 1: the two `MatMul` operations are performed in FP8 for computational efficiency, and the `SoftMax` operation is performed in FP32 for numerical accuracy.\n",
"\n",
- "Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v1.7. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n",
+ "Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v1.10. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n",
"\n",
"- `DelayedScaling.fp8_dpa=True (default=False)`: This enables the use of cuDNN attention sub-backend 2, when it does support the provided user inputs. The `FusedAttention` module for cuDNN attention takes FP16 or BF16 tensors as inputs, performs dot product attention in FP8, and returns attention logits in FP16 or BF16 (same as the input type). Casting operations are required to cast tensors to FP8 at the beginning, and back to FP16/BF16 at the end of the module.\n",
"\n",
"- `DelayedScaling.fp8_mha=True (default=False)`: This option, on top of `fp8_dpa=True`, removes the casting operations at the beginning and end of the `FusedAttention` module. This feature is experimental. \n",
"\n",
- "Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`. This should result in the following print when the debug flags are turned on, `NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2`.\n",
- "```\n",
- "[DEBUG | DotProductAttention]: Running with fp8_recipe.fp8_mha=False, fp8_recipe.fp8_dpa=True and NVTE_FP8_DPA_BWD=0\n",
- "[DEBUG | FusedAttnFunc ]: Running forward in FP8\n",
- "[DEBUG | FusedAttnFunc ]: Running backward in torch.bfloat16\n",
- "```"
+ "Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`."
]
}
],
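Section 3.4 above describes FP8 attention being driven entirely by the recipe. Below is a short, hedged sketch of what that looks like in practice, assuming a GPU/cuDNN combination that supports cuDNN attention sub-backend 2; the model size and tensor shapes are illustrative, and the `DelayedScaling` fields used are the `fp8_dpa`/`fp8_mha` options named in the notebook.

```python
# Sketch: enable FP8 dot product attention through the DelayedScaling recipe fields
# described in Section 3.4. Shapes and head counts below are illustrative.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, fp8_dpa=True, fp8_mha=False)

dpa = te.DotProductAttention(16, 64, qkv_format="bshd", attn_mask_type="no_mask",
                             attention_dropout=0.0).to(device="cuda")
q = torch.randn(2, 512, 16, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
k, v = torch.randn_like(q), torch.randn_like(q)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = dpa(q, k, v)   # MatMuls run in FP8, softmax in FP32, output returned in BF16
out.backward(torch.randn_like(out))  # set NVTE_FP8_DPA_BWD=0 to keep the backward in BF16
```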
diff --git a/docs/examples/attention/example_attention.py b/docs/examples/attention/example_attention.py
index 2ed7303417..15022005a1 100644
--- a/docs/examples/attention/example_attention.py
+++ b/docs/examples/attention/example_attention.py
@@ -11,9 +11,7 @@
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
ModelConfig,
- _is_flash_attention_supported,
- _is_fused_attention_supported,
- _is_unfused_attention_supported,
+ _get_attention_backends,
_run_dot_product_attention,
)
@@ -60,7 +58,6 @@ def example_attention(model, fused_attn_supported, flash_attn_supported):
ckpt_attn,
qkv_layout,
workspace_opt,
- swa,
pad_between_seqs,
is_training,
)
@@ -75,7 +72,6 @@ def example_attention(model, fused_attn_supported, flash_attn_supported):
ckpt_attn,
qkv_layout,
workspace_opt,
- swa,
pad_between_seqs,
is_training,
)
@@ -94,13 +90,14 @@ def main():
models = ["test_0"]
for model in models:
config = model_configs[model]
- fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
+ available_backends, fused_attn_backends = _get_attention_backends(
config,
- dtype,
+ qkv_dtype=dtype,
qkv_layout=qkv_layout,
+ window_size=config.window_size,
+ pad_between_seqs=pad_between_seqs,
)
- fused_attn_supported = fused_attn_supported and not swa
- flash_attn_supported = _is_flash_attention_supported(config)
+ flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
example_attention(model, fused_attn_supported, flash_attn_supported)
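The "Disabling FlashAttention due to NVTE_FLASH_ATTN=0" and "Disabling FusedAttention due to NVTE_FUSED_ATTN=0" messages in the notebook come from steering backend selection with environment variables. A minimal sketch of that control path, assuming the variables are set before the first attention call in the process (backend selection is cached; the test changes below force a refresh via `_attention_backends["backend_selection_requires_update"] = True`):

```python
# Sketch: force a particular backend and turn on backend-selection logging before
# running an attention module such as the one in example_attention.py.
import os

os.environ["NVTE_FLASH_ATTN"] = "0"   # disable flash-attention
os.environ["NVTE_FUSED_ATTN"] = "1"   # allow cuDNN (fused) attention
os.environ["NVTE_DEBUG"] = "1"        # enable debug logging
os.environ["NVTE_DEBUG_LEVEL"] = "1"  # INFO-level: prints the backend actually used
```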
diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py
index afc2081752..a1fd48513a 100644
--- a/tests/pytorch/fused_attn/test_fused_attn.py
+++ b/tests/pytorch/fused_attn/test_fused_attn.py
@@ -8,6 +8,7 @@
import os
from importlib.metadata import version
from typing import Any, Dict, List, Tuple, Union, Optional
+from contextlib import contextmanager
import pytest
import torch
@@ -108,6 +109,16 @@ def __init__(
self.window_size = window_size
+@contextmanager
+def logging_context(highest_level=logging.WARNING):
+ previous_level = logging.root.manager.disable
+ logging.disable(highest_level)
+ try:
+ yield
+ finally:
+ logging.disable(previous_level)
+
+
def _get_attention_backends(
config: ModelConfig,
qkv_dtype: torch.dtype,
@@ -180,12 +191,13 @@ def test():
return available_backends, fused_attention_backend
backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
- for i in range(3):
- os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
- _attention_backends["backend_selection_requires_update"] = True
- available_backends, fused_attention_backend = test()
- if fused_attention_backend == FusedAttnBackend[backends[i]]:
- fused_attn_backends.append(fused_attention_backend)
+ with logging_context():
+ for i in range(3):
+ os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
+ _attention_backends["backend_selection_requires_update"] = True
+ available_backends, fused_attention_backend = test()
+ if fused_attention_backend == FusedAttnBackend[backends[i]]:
+ fused_attn_backends.append(fused_attention_backend)
return available_backends, fused_attn_backends
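For completeness, a usage sketch of the `logging_context` helper added above. The import path assumes the test module is importable from the repository root, and the log messages are illustrative.

```python
# Sketch: temporarily silence messages at or below WARNING (the helper's default)
# while probing backends, then restore the previous logging.disable() level.
import logging
from tests.pytorch.fused_attn.test_fused_attn import logging_context

logging.basicConfig(level=logging.DEBUG)

with logging_context():
    logging.warning("suppressed while cycling through NVTE_FUSED_ATTN_BACKEND values")
logging.warning("visible again: the previous disable level has been restored")
```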