[PyTorch] move mask types to fprop (#402)
* API change and some test fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* more test fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* ONNX fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* fix

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fixed fused attention tests

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* rm duplicate test

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
ksivaman authored Aug 26, 2023
1 parent 94c57e4 commit 6aa1fcc
Showing 7 changed files with 287 additions and 226 deletions.
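
The gist of the change, reconstructed from the diffs below: the attention mask type is no longer fixed when a module is constructed but is passed with each forward call. A minimal before/after sketch — the sizes are illustrative placeholders, and the keyword names follow the forward-input names that appear in the updated tests:

```python
import torch
import transformer_engine.pytorch as te

# Hypothetical sizes, chosen only for illustration.
layer = te.TransformerLayer(1024, 4096, 16, params_dtype=torch.bfloat16).cuda()
x = torch.randn(128, 2, 1024, dtype=torch.bfloat16, device="cuda")

# Before this commit (sketch): the mask type was a constructor argument,
#   te.TransformerLayer(..., self_attn_mask_type="causal")
# After this commit: the same instance can run under different mask types,
# because the type travels with the forward pass.
out_causal = layer(x, self_attn_mask_type="causal")
out_unmasked = layer(x, self_attn_mask_type="no_mask")
```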
252 changes: 124 additions & 128 deletions tests/pytorch/test_fused_attn.py

Large diffs are not rendered by default.

24 changes: 14 additions & 10 deletions tests/pytorch/test_numerics.py
@@ -376,8 +376,8 @@ def __init__(self, hidden_size: int, num_attention_heads: int):
batch_first=False,
)

-def forward(self, x, attn_mask=None):
-output = self.mhsa(x, x, x, attn_mask=attn_mask, need_weights=False)
+def forward(self, x, attention_mask=None):
+output = self.mhsa(x, x, x, attn_mask=attention_mask, need_weights=False)
if isinstance(output, tuple):
output = output[0]
return output
@@ -461,7 +461,7 @@ def _test_e2e_selective_recompute(block, bs, dtype, config, recompute=False):

te_out = block(
te_inp_hidden_states,
-te_inp_attn_mask,
+attention_mask=te_inp_attn_mask,
checkpoint_core_attention=recompute,
)
loss = te_out.sum()
@@ -526,13 +526,13 @@ def _test_e2e_full_recompute(block, bs, dtype, config, recompute=False):
get_dummy_cuda_rng_tracker,
None, # tp_group
te_inp_hidden_states,
-te_inp_attn_mask,
+attention_mask=te_inp_attn_mask,
checkpoint_core_attention=False,
)
else:
te_out = block(
te_inp_hidden_states,
-te_inp_attn_mask,
+attention_mask=te_inp_attn_mask,
checkpoint_core_attention=False,
)
loss = te_out.sum()
@@ -766,7 +766,7 @@ def test_gpt_accuracy(dtype, bs, model):
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)


-def _test_mha_accuracy(block, bs, dtype, config, mask_type):
+def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
reset_rng_states()

inp_hidden_states = torch.randn(
@@ -775,7 +775,12 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type):
inp_hidden_states.retain_grad()
inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None

-out = block(inp_hidden_states, inp_attn_mask)
+forward_kwargs = {}
+if te:
+    forward_kwargs["attn_mask_type"] = mask_type
+forward_kwargs["attention_mask"] = inp_attn_mask
+
+out = block(inp_hidden_states, **forward_kwargs)
loss = out.sum()
loss.backward()

@@ -801,7 +806,6 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
fuse_qkv_params=True,
qkv_weight_interleaved=False,
input_layernorm=False,
-attn_mask_type=mask_type,
)
.to(dtype=dtype)
.cuda()
@@ -825,8 +829,8 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
torch_mha.mhsa.out_proj.weight = Parameter(te_mha.proj.weight.clone())
torch_mha.mhsa.out_proj.bias = Parameter(te_mha.proj.bias.clone())

-te_outputs = _test_mha_accuracy(te_mha, bs, dtype, config, mask_type)
-torch_outputs = _test_mha_accuracy(torch_mha, bs, dtype, config, mask_type)
+te_outputs = _test_mha_accuracy(te_mha, bs, dtype, config, mask_type, te=True)
+torch_outputs = _test_mha_accuracy(torch_mha, bs, dtype, config, mask_type, te=False)

# Check output.
if dtype == torch.float32:
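As the updated `_test_mha_accuracy` shows, `attn_mask_type` now goes to `te.MultiheadAttention`'s forward call, while the torch wrapper only sees the mask tensor. A rough standalone sketch of the TE side (the constructor values are placeholders, not the test's exact configuration):

```python
import torch
import transformer_engine.pytorch as te

# Placeholder sizes; the real test derives these from its model config.
mha = te.MultiheadAttention(1024, 16, fuse_qkv_params=True).cuda()
x = torch.randn(128, 2, 1024, device="cuda")

# One module instance, two mask types, now that attn_mask_type is a
# forward-pass argument instead of a constructor argument.
out_causal = mha(x, attn_mask_type="causal")
out_dense = mha(x, attn_mask_type="no_mask")
out_causal.sum().backward()
```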
29 changes: 11 additions & 18 deletions tests/pytorch/test_onnx_export.py
@@ -783,7 +783,6 @@ def __init__(self, softmax_fn, fake_bf16_io, mask_inp=False):
self.fake_bf16_io = fake_bf16_io
if self.softmax_fn == te.softmax.FusedScaleMaskSoftmax:
self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax(
attn_mask_type="causal",
mask_func=te.utils.attention_mask_func,
softmax_in_fp32=True,
)
@@ -793,7 +792,7 @@ def forward(self, inp, mask):
inp = inp.type(torch.bfloat16)

if self.fused_scaled_softmax:
-ret = self.fused_scaled_softmax(inp, mask, self.scale)
+ret = self.fused_scaled_softmax(inp, mask, "causal", self.scale)
else:
if self.mask_inp:
ret = self.softmax_fn.apply(inp, mask, self.scale)
@@ -867,15 +866,14 @@ def __init__(self, use_default_te_mask_fn: bool, fake_bf16_io: bool):
# even when is_in_onnx_export_mode()==False.
os.environ["NVTE_MASKED_SOFTMAX_FUSION"] = "0"
self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax(
attn_mask_type="causal",
mask_func=te.utils.attention_mask_func,
softmax_in_fp32=True,
)

def forward(self, inp, mask):
if self.fake_bf16_io:
inp = inp.type(torch.bfloat16)
-ret = self.fused_scaled_softmax(inp, mask, self.scale)
+ret = self.fused_scaled_softmax(inp, mask, "causal", scale=self.scale)
if self.fake_bf16_io:
ret = ret.type(torch.float)
return ret
@@ -1161,13 +1159,13 @@ def test_export_core_attention(
query_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
key_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
value_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
input_names = ["query", "key", "value", "attention_mask"]
input_names = ["query", "key", "value", "attention_mask", "attn_mask_type"]
attention_mask = None
if use_mask:
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(qkv_size[1], qkv_size[2], qkv_size[0], qkv_size[0], device="cuda", dtype=precision)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
-inp = (query_layer, key_layer, value_layer, attention_mask)
+inp = (query_layer, key_layer, value_layer, attention_mask, attn_mask_type)

mask_str = get_attn_mask_str(use_mask, attn_mask_type)
high_prec_str = dtype2str(precision)
@@ -1177,7 +1175,6 @@
num_attention_heads=num_attention_heads,
kv_channels=kv_channels,
attention_dropout=0.5,
-attn_mask_type=attn_mask_type,
).to(device='cuda')
do_export(model,
inp,
@@ -1193,9 +1190,8 @@

test_configs_multihead_attention = [
#"use_mask, attn_mask_type"
(False, "causal"), # calls ScaledUpperTriangMaskedSoftmax
(False, "no_mask"), # calls ScaledUpperTriangMaskedSoftmax
(True, "padding"), # calls ScaledMaskedSoftmax
(False, "padding"), # calls ScaledSoftmax
]
test_configs_attention_type = [
#"input_layernorm, attention_type, fuse_qkv_params"
@@ -1269,7 +1265,6 @@ def test_export_multihead_attention(

model = te.MultiheadAttention(
*attention_args,
-attn_mask_type=attn_mask_type,
params_dtype=precision,
return_layernorm_output=return_layernorm_output,
input_layernorm=input_layernorm,
@@ -1278,8 +1273,8 @@
return_bias=True,
).to(device='cuda')

-inp_context = (hidden_states_context, attention_mask, encoder_output)
-input_names = ["hidden_states", "attention_mask", "encoder_output"]
+inp_context = (hidden_states_context, attention_mask, encoder_output, attn_mask_type)
+input_names = ["hidden_states", "attention_mask", "encoder_output", "attn_mask_type"]
output_names=["attention_output", "attention_bias"]
do_export(model, inp_context, fname, use_fp8, input_names=input_names, output_names=output_names,
dynamic_axes={"hidden_states": {0: "seq", 1:"bs"},
@@ -1347,13 +1342,13 @@ def test_export_transformer_layer(
num_attention_heads = 4

input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
input_names = ["input", "attention_mask"]
input_names = ["input", "attention_mask", "self_attn_mask_type"]
attention_mask = None
if use_mask and attn_mask_type != "causal":
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
-inp = (input_tensor, attention_mask)
+inp = (input_tensor, attention_mask, attn_mask_type)

fp8_str = "_fp8" if use_fp8 else ""
fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else ""
@@ -1365,7 +1360,6 @@
hidden_size,
ffn_hidden_size,
num_attention_heads,
-self_attn_mask_type=attn_mask_type,
output_layernorm=output_layernorm,
params_dtype=precision,
fuse_qkv_params=fuse_qkv_params,
@@ -1547,17 +1541,16 @@ def test_export_gpt_generation(
hidden_size,
ffn_hidden_size,
num_attention_heads,
-self_attn_mask_type=attn_mask_type,
output_layernorm=output_layernorm,
params_dtype=precision,
fuse_qkv_params=fuse_qkv_params,
zero_centered_gamma=zero_centered_gamma).to(device='cuda')

# "Context phase": use full input sequence length
input_names = ["input"]
input_names = ["input", "attention_mask", "self_attn_mask_type"]
output_names = ["output"]
input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
-inp = (input_tensor,)
+inp = (input_tensor, None, attn_mask_type)
do_export(model, inp, fname, use_fp8,
input_names=input_names, output_names=output_names,
dynamic_axes={"input": {0: "seq", 1:"bs"},
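The softmax wrapper in these export tests follows the same pattern: `FusedScaleMaskSoftmax` is now built without a mask type and receives it per call as the third positional argument. A small sketch based on the updated test wrappers (the tensor shape and scale value are placeholders):

```python
import torch
import transformer_engine.pytorch as te

softmax = te.softmax.FusedScaleMaskSoftmax(
    mask_func=te.utils.attention_mask_func,
    softmax_in_fp32=True,
)

# [batch, heads, sq, sk] attention scores; half precision enables the fused path.
scores = torch.randn(2, 16, 128, 128, dtype=torch.float16, device="cuda")

# The mask type ("causal" here) is passed per call rather than at construction.
probs = softmax(scores, None, "causal", scale=1.0)
```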
10 changes: 6 additions & 4 deletions tests/pytorch/test_sanity.py
@@ -176,7 +176,7 @@ def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad):
use_fp8 = fp8_recipe is not None
with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
-te_out = block(te_inp_hidden_states, te_inp_attn_mask)
+te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
loss = te_out.sum()

loss.backward()
@@ -217,7 +217,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_

use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
-te_out = block(te_inp_hidden_states, te_inp_attn_mask)
+te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
@@ -253,7 +253,7 @@ def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):

use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
-te_out = block(te_inp_hidden_states, te_inp_attn_mask)
+te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
@@ -282,7 +282,9 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
te_out = block(
-te_inp_hidden_states, te_inp_attn_mask, encoder_output=te_inp_hidden_states
+te_inp_hidden_states,
+attention_mask=te_inp_attn_mask,
+encoder_output=te_inp_hidden_states
)
loss = te_out.sum()
loss.backward()
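The sanity tests only switch the mask to a keyword argument, which matters because the forward signature now also carries the mask type, so a positional mask would be easy to misplace. A condensed, hypothetical sketch of the call shape (sizes, dtype, and the disabled FP8 recipe are placeholders, not the test's parametrization):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import fp8_autocast

block = te.TransformerLayer(1024, 4096, 16, params_dtype=torch.bfloat16).cuda()
x = torch.randn(32, 2, 1024, dtype=torch.bfloat16, device="cuda")

with fp8_autocast(enabled=False):  # the real tests parametrize fp8_recipe
    out = block(
        x,
        attention_mask=None,            # keyword, as in the updated tests
        self_attn_mask_type="causal",   # forward-time mask type
        checkpoint_core_attention=False,
    )
out.sum().backward()
```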