Miscellaneous fixes for core attention (#344)
* miscellaneous fixes

Signed-off-by: Charlene Yang <[email protected]>

* add back pytorch csrc extensions.h

Signed-off-by: Charlene Yang <[email protected]>

* add unit tests for dpa checkpointing

Signed-off-by: Charlene Yang <[email protected]>

* remove seqlen%32/64 checks for now

Signed-off-by: Charlene Yang <[email protected]>

* fix tests for core attn bias

Signed-off-by: Charlene Yang <[email protected]>

* add tests for changes regarding rng_state in aux_ctx_tensor

Signed-off-by: Charlene Yang <[email protected]>

* reuse rng tracker from numerics in fused attn; skip checkpointing if FAv2 in numerics

Signed-off-by: Charlene Yang <[email protected]>

* uncomment comments used for testing

Signed-off-by: Charlene Yang <[email protected]>

* fix pre/post scale bias

Signed-off-by: Charlene Yang <[email protected]>

* Update transformer_engine/pytorch/attention.py

Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: cyanguwa <[email protected]>

* remove skipifs for FAv2 check after PR366

Signed-off-by: Charlene Yang <[email protected]>

* remove checkpointing tests for transformer layer; dpa tests still provide coverage

Signed-off-by: Charlene Yang <[email protected]>

* adjust random number range for tests

Signed-off-by: Charlene Yang <[email protected]>

* Add upper bound to FA version

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Check backend only when using FusedAttention

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* remove imports/variables related to FAv2 checks

Signed-off-by: Charlene Yang <[email protected]>

* further fix random number ranges for tests

Signed-off-by: Charlene Yang <[email protected]>

* fix variable referenced before assignment error

Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: cyanguwa <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
3 people authored Aug 11, 2023
1 parent a0f4435 commit cbfb8c6
Showing 8 changed files with 208 additions and 111 deletions.
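
For context before the per-file diffs: the updated tests exercise the two DotProductAttention call arguments touched by this commit (checkpoint_core_attention and core_attention_bias/_type) and select a backend via the NVTE_* environment variables, as shown in tests/pytorch/test_fused_attn.py below. The following is a minimal sketch condensed from those tests; the constructor arguments are abbreviated and the shape/config values are placeholders, not the suite's settings.

```python
import os
import torch
from transformer_engine.pytorch import DotProductAttention

# Pick the backend the same way the tests do: via NVTE_* environment variables.
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"   # exercise the FusedAttention path

seq_len, bs, heads, head_dim = 512, 2, 16, 64          # placeholder sizes
qkv = torch.randn(seq_len, bs, 3, heads, head_dim,
                  dtype=torch.bfloat16, device="cuda", requires_grad=True)
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

# post_scale_bias tensors in the tests have shape [1, heads, seq_len, seq_len].
bias = torch.randn(1, heads, seq_len, seq_len, dtype=torch.bfloat16, device="cuda")

attn = DotProductAttention(heads, head_dim, attention_dropout=0.0)
out = attn(q, k, v,
           checkpoint_core_attention=True,              # new DPA checkpointing tests
           core_attention_bias_type="post_scale_bias",
           core_attention_bias=bias)
out.backward(torch.randn_like(out))
```
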
2 changes: 1 addition & 1 deletion setup.py
@@ -290,7 +290,7 @@ def add_unique(l: List[str], vals: Union[str, List[str]]) -> None:

# Framework-specific requirements
if "pytorch" in frameworks():
add_unique(install_reqs, ["torch", "flash-attn>=1.0.6"])
add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=2.0.4"])
add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
if "jax" in frameworks():
if not found_pybind11():
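
The new upper bound pins flash-attn to the 1.0.6–2.0.4 window. The test file below already checks the installed version at runtime with pkg_resources.packaging; a small sketch of that pattern, extended with the same bound as setup.py (an illustration, not code from the repository):

```python
# Mirrors the version check added in tests/pytorch/test_fused_attn.py,
# plus the install bound from setup.py; assumes flash-attn is installed.
from importlib.metadata import version
from pkg_resources import packaging

_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")

assert (packaging.version.Version("1.0.6")
        <= _flash_attn_version
        <= packaging.version.Version("2.0.4")), "unsupported flash-attn version"
```
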
117 changes: 70 additions & 47 deletions tests/pytorch/test_fused_attn.py
@@ -17,6 +17,7 @@

from pkg_resources import packaging
from importlib.metadata import version
from test_numerics import get_dummy_cuda_rng_tracker, reset_rng_states
fp8_available, reason_for_no_fp8 = is_fp8_available()
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
@@ -58,47 +59,55 @@ def __init__(
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_dot_product_attention(dtype, bs, model):
@pytest.mark.parametrize("ckpt_attn", [True, False])
@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
def test_dot_product_attention(dtype, bs, model, ckpt_attn, bias_type):
"""Test DotProductAttention module with three backends,
FlashAttention, FusedAttention and UnfusedDotProductAttention"""

config = model_configs[model]

flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype, bs, config, "FlashAttention")
if bias_type == "no_bias":
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype, bs, config, "FlashAttention", ckpt_attn, bias_type)
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype, bs, config, "FusedAttention")
dtype, bs, config, "FusedAttention", ckpt_attn, bias_type)
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
dtype, bs, config, "UnfusedDotProductAttention")
dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type)

atol, rtol = (2.5e-2, 2.5e-2) if dtype == torch.bfloat16 else (2.5e-3, 2.5e-3)
assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol)
atol, rtol = (2.5e-2, 2.5e-2) if dtype == torch.bfloat16 else (5e-3, 5e-3)
if bias_type == "no_bias":
assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)

def _run_dot_product_attention(dtype, bs, config, backend):
def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type):

torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"

inp = 0.1 * torch.randn(
inp = torch.randn(
config.seq_len, bs, 3, config.num_attention_heads, config.head_dim,
dtype = dtype).cuda()
inp.requires_grad=True
seqlens = torch.empty(bs, dtype = torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
op_grad = 0.001 * torch.randint(0, 200, (
config.seq_len, bs, config.num_attention_heads * config.head_dim
), dtype = dtype).cuda()
op_grad = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
if bias_type != "no_bias":
bias = torch.randn(1, config.num_attention_heads, config.seq_len, config.seq_len,
dtype = dtype).cuda()
else:
bias = None

block = (
DotProductAttention(
@@ -108,7 +117,7 @@ def _run_dot_product_attention(dtype, bs, config, backend):
attn_mask_type = config.attn_mask_type,
sequence_parallel = False,
tp_size = 1,
get_rng_state_tracker = None,
get_rng_state_tracker = get_dummy_cuda_rng_tracker,
tp_group = None,
layer_number = 1,
attention_type = "self"
@@ -118,7 +127,10 @@ def _run_dot_product_attention(dtype, bs, config, backend):
q = inp[:, :,0,:,:]
k = inp[:, :,1,:,:]
v = inp[:, :,2,:,:]
op = block(q, k, v)
op = block(q, k, v,
checkpoint_core_attention = ckpt_attn,
core_attention_bias_type = bias_type,
core_attention_bias = bias)
op.backward(op_grad)

return op, inp.grad
@@ -128,47 +140,50 @@ def _run_dot_product_attention(dtype, bs, config, backend):
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_transformer_layer(dtype, bs, model):
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type):
"""Test TransformerLayer module when its DotProductAttention is enabled with
FlashAttention, FusedAttention, or UnfusedDotProductAttention backend"""

config = model_configs[model]

flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
dtype, bs, config, "FlashAttention")
if bias_type == "no_bias":
flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
dtype, bs, config, "FlashAttention", ckpt_attn, bias_type)
fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
dtype, bs, config, "FusedAttention")
dtype, bs, config, "FusedAttention", ckpt_attn, bias_type)
unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
dtype, bs, config, "UnfusedDotProductAttention")
dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type)

atol, rtol = (5e-1, 5e-1) if dtype == torch.bfloat16 else (5e-1, 5e-1)
assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol)
atol, rtol = (5e-1, 5e-2)
if bias_type == "no_bias":
assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)

def _run_transformer_layer(dtype, bs, config, backend):
def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type):

torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"

inp = 0.1 * torch.randn(
inp = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
inp.requires_grad=True
seqlens = torch.empty(bs, dtype = torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
op_grad = 0.001 * torch.randint(0, 200, (
config.seq_len, bs, config.num_attention_heads * config.head_dim
), dtype = dtype).cuda()
op_grad = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()

sigma = 0.02
init_method = init_method_normal(sigma)
@@ -178,6 +193,11 @@ def _run_transformer_layer(dtype, bs, config, backend):
drop_path_rate = 0.0
drop_path_rates = [
rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
if bias_type != "no_bias":
bias = torch.randn(1, config.num_attention_heads, config.seq_len, config.seq_len,
dtype = dtype).cuda()
else:
bias = None

block = (
TransformerLayer(
@@ -215,8 +235,13 @@ def _run_transformer_layer(dtype, bs, config, backend):
.cuda()
)

op = block(inp)
op.backward(op_grad)
num_iters = 10
for i in range(num_iters):
op = block(inp,
checkpoint_core_attention = ckpt_attn,
core_attention_bias_type = bias_type,
core_attention_bias = bias)
op.backward(op_grad)

return op, inp.grad

@@ -246,29 +271,28 @@ def find_factors(x):
unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer_gqa(
dtype, bs, config, "UnfusedDotProductAttention", num_q_per_gqa_group)

atol, rtol = 5e-1, 5e-1
atol, rtol = 5e-1, 5e-2
assert torch.allclose(flash_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(flash_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)

def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_group):

torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"

inp = 0.1 * torch.randn(
inp = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
inp.requires_grad=True
seqlens = torch.empty(bs, dtype = torch.int32).cuda()
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
op_grad = 0.001 * torch.randint(0, 200, (
config.seq_len, bs, config.num_attention_heads * config.head_dim
), dtype = dtype).cuda()
op_grad = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()

sigma = 0.02
init_method = init_method_normal(sigma)
@@ -342,14 +366,13 @@ def test_dpa_fp8(dtype, bs, model):
unfused_attn_fwd, unfused_attn_bwd = _run_dpa_fp8_ref(
dtype, bs, config, "UnfusedDotProductAttention")

atol, rtol = (5e-2, 1e-1)
atol, rtol = (2.5e-2, 2.5e-2)
assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)

def _run_dpa_fp8(dtype, bs, config, backend):

torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"

@@ -361,9 +384,9 @@ def _run_dpa_fp8(dtype, bs, config, backend):
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
op_grad = 0.001 * torch.randint(0, 200, (
bs * config.seq_len, config.num_attention_heads * config.head_dim
), dtype = dtype).cuda()
op_grad = 0.01 * torch.randn(
bs * config.seq_len, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
torch.save(op_grad, 'op_grad.pt')

fp8_recipe = recipe.DelayedScaling(
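
The bias tests above pass core_attention_bias_type = "post_scale_bias" with a [1, num_heads, seq_len, seq_len] bias tensor. As a reading aid, here is what "post-scale" is generally taken to mean, i.e. the bias is added after the 1/sqrt(d) scaling of QK^T; this is a hand-written sketch, not the unfused reference implementation used by the tests:

```python
import math
import torch

def unfused_attention_post_scale_bias(q, k, v, bias, p_dropout=0.0):
    # q, k, v: [seq, batch, heads, head_dim]; bias: [1, heads, seq_q, seq_k]
    d = q.shape[-1]
    scores = torch.einsum("sbhd,tbhd->bhst", q, k) / math.sqrt(d)
    scores = scores + bias                      # <- bias added after scaling
    probs = torch.softmax(scores, dim=-1)
    probs = torch.nn.functional.dropout(probs, p=p_dropout)
    return torch.einsum("bhst,tbhd->sbhd", probs, v)
```
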
1 change: 0 additions & 1 deletion tests/pytorch/test_numerics.py
@@ -25,7 +25,6 @@
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint


seed = 1234
rng_str = "rng_state"
torch.manual_seed(seed)
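
get_dummy_cuda_rng_tracker and reset_rng_states are imported from this file by the fused-attention tests above, but their definitions are outside this diff. A rough sketch of what such helpers typically do, with all names and details here assumed rather than taken from test_numerics.py:

```python
# Hypothetical sketch only: forks the CUDA RNG state so attention dropout
# draws reproducible numbers across backends.
from contextlib import contextmanager
import torch

_SAVED_STATE = None

def reset_rng_states(seed=1234):
    """Reseed CPU/CUDA RNGs so every backend sees identical randomness."""
    global _SAVED_STATE
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    _SAVED_STATE = torch.cuda.get_rng_state()

class _DummyCudaRNGTracker:
    @contextmanager
    def fork(self, name="model-parallel-rng"):
        # Temporarily swap in the saved state, then restore the caller's state.
        outer = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(_SAVED_STATE)
        try:
            yield
        finally:
            torch.cuda.set_rng_state(outer)

_tracker = _DummyCudaRNGTracker()

def get_dummy_cuda_rng_tracker():
    return _tracker
```
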
22 changes: 21 additions & 1 deletion transformer_engine/common/fused_attn/fused_attn.cpp
@@ -32,9 +32,15 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
&& (max_seqlen_q <= 512)
&& (head_dim == 64)
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
&& (qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)) {
#if (CUDNN_VERSION >= 8900)
backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
#else
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+."
" Please upgrade your cuDNN version if possible." << std::endl;
#endif
} else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) {
bool flag_m512 = false;
bool flag_arb = false;
@@ -76,6 +82,20 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen))) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
}
#if (CUDNN_VERSION < 8901)
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.1+."
" Please upgrade your cuDNN version if possible." << std::endl;
}
#endif
#if (CUDNN_VERSION < 8900)
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.0+."
" Please upgrade your cuDNN version if possible." << std::endl;
}
#endif
} else {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
}
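
Restated for readability: the gating this hunk adds to nvte_get_fused_attn_backend() downgrades a candidate backend to NVTE_No_Backend when the compile-time cuDNN version is too old, and the FP8 path now requires the PADDING mask instead of NO_MASK. A Python paraphrase of that logic (a reading aid only; the authoritative code is the C++ above):

```python
# Paraphrase of the cuDNN-version gates added in this hunk.
def pick_fp16_bf16_backend(candidate_backend, cudnn_version):
    NO_BACKEND = "NVTE_No_Backend"
    if candidate_backend == "NVTE_F16_max512_seqlen" and cudnn_version < 8901:
        print("Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.1+.")
        return NO_BACKEND
    if candidate_backend == "NVTE_F16_arbitrary_seqlen" and cudnn_version < 8900:
        print("Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.0+.")
        return NO_BACKEND
    return candidate_backend

def pick_fp8_backend(cudnn_version):
    # FP8 fused attention also requires the PADDING mask (changed in this commit)
    # and cuDNN 8.9.0+.
    if cudnn_version >= 8900:
        return "NVTE_FP8"
    print("Warning: FP8 fused attention is supported by cuDNN 8.9.0+.")
    return "NVTE_No_Backend"
```
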
24 changes: 12 additions & 12 deletions transformer_engine/common/include/transformer_engine/fused_attn.h
@@ -136,10 +136,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
*
* Support Matrix:
\verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL | Yes | > 512 | 64, 128 |
| 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 |
| 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] QKV The QKV tensor in packed format,
@@ -181,10 +181,10 @@ void nvte_fused_attn_fwd_qkvpacked(
*
* Support Matrix:
\verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL | Yes | > 512 | 64, 128 |
| 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL_MASK | Yes | > 512 | 64, 128 |
| 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING_MASK | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] QKV The QKV tensor in packed format,
@@ -235,8 +235,8 @@ void nvte_fused_attn_bwd_qkvpacked(
*
* Support Matrix:
\verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | KV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
@@ -283,8 +283,8 @@ void nvte_fused_attn_fwd_kvpacked(
*
* Support Matrix:
\verbatim
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | KV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
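
The updated support matrices above can be read as a quick eligibility predicate for backend 0 (the F16/BF16 max-512 path); a sketch of that reading, with the qkv layout depending on whether the QKV-packed or KV-packed entry point is used (illustration only, not a library API):

```python
# Distilled from the doc tables above; not an actual Transformer Engine function.
def backend0_supported(precision, qkv_layout, bias, mask, max_seqlen, head_dim):
    return (precision in ("FP16", "BF16")
            and qkv_layout in ("QKV_INTERLEAVED", "KV_INTERLEAVED")
            and bias in ("NO_BIAS", "POST_SCALE_BIAS")
            and mask in ("PADDING", "CAUSAL", "NO_MASK")
            and max_seqlen <= 512
            and head_dim == 64)
```
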
(Diffs for the remaining 3 changed files are not shown here.)
