Support packed input for FA (#302)
* initial changes [wip]

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Add padding mask support for FA

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Address review comments

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* rm causal mask from tests and add padding

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* fix some conflicts

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* conflicts

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Add unpadding mask

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix padding mask

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix docs

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* [wip] fix API

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Add packing and unpacking

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* More fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* docs fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* fix atomic_add bf16 torch.compile

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Generate non all True masks

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Lint fix

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix core attention export and FusedAttn filter

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix all ONNX tests

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Memory optimization

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* More fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Optimizations and caching fixes in torch.dynamo

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Bug fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Review comments

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Padding optimizations

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fixes and reviews

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
ksivaman authored Oct 4, 2023
1 parent d3157e2 commit 47ca514
Showing 7 changed files with 442 additions and 163 deletions.
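For context on the feature itself: "packed" (varlen, unpadded) input for flash attention means concatenating only the valid tokens of each sequence into one flat tensor and describing the batch with cumulative sequence lengths (cu_seqlens), so padding tokens never reach the attention kernel. The sketch below illustrates the general pack/unpack idea only; the helper names and layouts are hypothetical, not this repository's API.

import torch

def pack(x, mask):
    # x:    [batch, seq_len, hidden] right-padded activations
    # mask: [batch, seq_len] bool, True for valid (non-padding) tokens
    seqlens = mask.sum(dim=1, dtype=torch.int32)                 # tokens per sequence
    cu_seqlens = torch.zeros(seqlens.numel() + 1, dtype=torch.int32, device=x.device)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)                # [0, len0, len0+len1, ...]
    indices = mask.flatten().nonzero(as_tuple=False).squeeze(1)  # flat positions of valid tokens
    packed = x.reshape(-1, x.shape[-1])[indices]                 # [total_tokens, hidden]
    return packed, cu_seqlens, indices

def unpack(packed, indices, batch, seq_len):
    # Scatter packed tokens back into a zero-padded [batch, seq_len, hidden] tensor.
    out = torch.zeros(batch * seq_len, packed.shape[-1], dtype=packed.dtype, device=packed.device)
    out[indices] = packed
    return out.reshape(batch, seq_len, -1)

A varlen flash-attention kernel then attends within each [cu_seqlens[i], cu_seqlens[i+1]) slice, which is what makes an explicit padding mask unnecessary on the packed path.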
tests/pytorch/test_numerics.py: 5 changes (2 additions, 3 deletions)
@@ -612,14 +612,13 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
     te_inp_hidden_states.retain_grad()
-    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)

     block = _test_e2e_checkpointing_get_model(config, dtype)

     for _ in range(steps // 2):
         te_out = block(
             te_inp_hidden_states,
-            te_inp_attn_mask,
+            None,
         )
         loss = te_out.sum()
         loss.backward()
@@ -650,7 +649,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
     for _ in range(steps // 2):
         te_out = block(
             te_inp_hidden_states,
-            te_inp_attn_mask,
+            None,
         )
         loss = te_out.sum()
         loss.backward()
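The checkpointing test above now passes None as the attention mask: with a causal mask type the triangular mask is applied inside the layer, so no mask tensor has to be supplied. A minimal calling-pattern sketch, assuming the documented TransformerLayer constructor arguments and forward signature (the sizes are illustrative, not taken from this test):

import torch
import transformer_engine.pytorch as te

hidden_size, ffn_hidden_size, num_heads = 1024, 4096, 16
layer = te.TransformerLayer(
    hidden_size,
    ffn_hidden_size,
    num_heads,
    self_attn_mask_type="causal",  # mask is generated internally
).cuda()

x = torch.randn(128, 2, hidden_size, device="cuda")  # [seq_len, batch, hidden]
y = layer(x, None)                                   # no explicit attention mask needed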
tests/pytorch/test_onnx_export.py: 27 changes (13 additions, 14 deletions)
@@ -316,9 +316,9 @@ def get_attn_mask_str(use_mask, attn_mask_type):
     # See FusedScaleMaskSoftmax::forward_fused_softmax for logic behind names.
     if attn_mask_type is None:
         return "_mask" if use_mask else "_no-mask"
-    attn_mask_str = "_padding-no-mask"
+    attn_mask_str = "_arbitrary-no-mask"
     attn_mask_str = "_causal-mask" if attn_mask_type == "causal" else attn_mask_str
-    attn_mask_str = "_padding-mask" if use_mask and attn_mask_type == "padding" else attn_mask_str
+    attn_mask_str = "_arbitrary-mask" if use_mask and attn_mask_type == "arbitrary" else attn_mask_str
     return attn_mask_str


@@ -986,14 +986,14 @@ def test_export_layernorm_mlp(
 @skip_FP8
 @pytest.mark.parametrize(
     "precision, use_mask, attn_mask_type", [
-        (torch.float32, True, "padding"), # calls forward_torch_softmax (apply user mask)
-        (torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
-        (torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
-        (torch.float16, True, "padding"), # calls forward_torch_softmax (apply user mask)
-        (torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
-        (torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
-        (torch.bfloat16, True, "padding"), # calls forward_torch_softmax (apply user mask)
-        (torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
+        (torch.float32, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
+        (torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
+        (torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
+        (torch.float16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
+        (torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
+        (torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
+        (torch.bfloat16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
+        (torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
     ])
 def test_export_core_attention(
     seed_default_rng,
@@ -1014,7 +1014,7 @@ def test_export_core_attention(
     attention_mask = None
     if use_mask:
         # Generate a random mask with 50% probability for 0 or 1.
-        probs = 0.5 * torch.ones(qkv_size[1], qkv_size[2], qkv_size[0], qkv_size[0], device="cuda", dtype=precision)
+        probs = 0.5 * torch.ones(batch_size, 1, 1, seq_len, device="cuda", dtype=precision)
         attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
     inp = (query_layer, key_layer, value_layer, attention_mask)

@@ -1043,9 +1043,8 @@ def test_export_core_attention(

 test_configs_multihead_attention = [
     #"use_mask, attn_mask_type"
-    (False, "causal"), # calls ScaledUpperTriangMaskedSoftmax
-    (True, "padding"), # calls ScaledMaskedSoftmax
-    (False, "padding"), # calls ScaledSoftmax
+    (False, "no_mask"), # calls ScaledSoftmax
+    (True, "arbitrary"), # calls ScaledMaskedSoftmax
 ]
 test_configs_attention_type = [
     #"input_layernorm, attention_type, fuse_qkv_params"
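A note on the new mask shape in test_export_core_attention: [batch_size, 1, 1, seq_len] carries one boolean per key position and broadcasts over heads and query positions, which is the usual layout for a padding mask. The test samples it randomly; in practice it is typically derived from per-sequence lengths, roughly as sketched below (the True-means-masked-out convention is an assumption here, not something this diff states, and make_padding_mask is a hypothetical helper):

import torch

def make_padding_mask(seqlens, max_seq_len):
    # seqlens: [batch] int tensor of valid lengths.
    # Returns a [batch, 1, 1, max_seq_len] bool mask; True marks padding
    # positions under the assumed convention (flip the comparison otherwise).
    positions = torch.arange(max_seq_len, device=seqlens.device)   # [max_seq_len]
    mask = positions.unsqueeze(0) >= seqlens.unsqueeze(1)          # [batch, max_seq_len]
    return mask.view(-1, 1, 1, max_seq_len)

# Example: three sequences of lengths 5, 2 and 7, padded to length 8.
mask = make_padding_mask(torch.tensor([5, 2, 7]), max_seq_len=8)
print(mask.shape)  # torch.Size([3, 1, 1, 8])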
tests/pytorch/test_sanity.py: 76 changes (26 additions, 50 deletions)
@@ -157,18 +157,7 @@ def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad):
         config.seq_len, bs, config.hidden_size, dtype=torch.float32, requires_grad=True
     ).cuda()
     te_inp_hidden_states.retain_grad()
-    te_inp_attn_mask = (
-        torch.rand(
-            (
-                1,
-                1,
-                config.seq_len,
-                config.seq_len,
-            )
-        )
-        .cuda()
-        .bool()
-    )
+    te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()

     if skip_wgrad:
         _disable_wgrads(block)
@@ -193,18 +182,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_
     te_inp_hidden_states = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
-    te_inp_attn_mask = (
-        torch.rand(
-            (
-                1,
-                1,
-                config.seq_len,
-                config.seq_len,
-            )
-        )
-        .cuda()
-        .bool()
-    )
+    te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()

     if skip_wgrad:
         _disable_wgrads(block)
@@ -233,18 +211,24 @@ def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     te_inp_hidden_states = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
-    te_inp_attn_mask = (
-        torch.rand(
-            (
-                1,
-                1,
-                config.seq_len,
-                config.seq_len,
-            )
-        )
-        .cuda()
-        .bool()
-    )

     if skip_wgrad:
         _disable_wgrads(block)
+
+    use_fp8 = fp8_recipe is not None
+    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+        te_out = block(te_inp_hidden_states)
+    loss = te_out.sum()
+    loss.backward()
+    torch.cuda.synchronize()
+
+
+def _test_sanity_e2e_bert(block, bs, dtype, config, fp8_recipe, skip_wgrad):
+    te_inp_hidden_states = torch.randn(
+        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
+    ).cuda()
+
+    te_inp_attn_mask = torch.rand(torch.Size([bs, 1, 1, config.seq_len])).cuda() > 0.5
+
+    if skip_wgrad:
+        _disable_wgrads(block)
@@ -261,18 +245,8 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     te_inp_hidden_states = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
-    te_inp_attn_mask = (
-        torch.rand(
-            (
-                1,
-                1,
-                config.seq_len,
-                config.seq_len,
-            )
-        )
-        .cuda()
-        .bool()
-    )
+    te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool()
+    enc_dec_attn_mask = torch.rand(torch.Size([bs, 1, 1, config.seq_len])).cuda() > 0.5

     if skip_wgrad:
         _disable_wgrads(block)
@@ -282,7 +256,8 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
         te_out = block(
             te_inp_hidden_states,
             attention_mask=te_inp_attn_mask,
-            encoder_output=te_inp_hidden_states
+            encoder_output=te_inp_hidden_states,
+            enc_dec_attn_mask=enc_dec_attn_mask,
         )
     loss = te_out.sum()
     loss.backward()
@@ -541,13 +516,14 @@ def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gam
             apply_residual_connection_post_layernorm=True,
             output_layernorm=True,
             zero_centered_gamma=zero_centered_gamma,
+            self_attn_mask_type="padding",
             normalization=normalization,
         )
         .to(dtype=dtype)
         .cuda()
     )

-    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
+    _test_sanity_e2e_bert(block, bs, dtype, config, fp8_recipe, skip_wgrad)


 @pytest.mark.parametrize("dtype", param_types)
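One detail behind the sanity-test changes (and the "Generate non all True masks" item in the commit message): torch.rand draws from [0, 1), so casting its output with .bool() yields True almost everywhere and the old masks were effectively degenerate. torch.randint(2, ...) or a > 0.5 threshold, as used above, gives a genuine mix of True and False. A quick illustration:

import torch

shape = (1, 1, 128, 128)

# Old pattern: every draw in (0, 1) casts to True, so the mask is (almost) all True.
all_true = torch.rand(shape).bool()
print(all_true.float().mean())  # ~1.0

# New patterns: roughly half the entries end up True.
mixed_a = torch.randint(2, shape).bool()
mixed_b = torch.rand(shape) > 0.5
print(mixed_a.float().mean(), mixed_b.float().mean())  # ~0.5 each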