Improve softmax ONNX export tests (#370)
* Add dynamically shaped input mask in test_export_softmax
* Fix test_softmax_mask_fn: use the environment variable `NVTE_ONNX_KVCACHE_MAX_SEQ_LEN` to control whether the test uses the default mask generation function or dynamic TRILU mask slicing.
* Change the core_attention ONNX export test: use "no_mask" as the attention mask type when testing `te.attention.DotProductAttention` without masking.
* Use ORT CUDA backend by default.
Signed-off-by: Neta Zmora <[email protected]>
nzmora-nvidia authored Aug 11, 2023
1 parent ecd4f80 commit a0f4435
61 changes: 33 additions & 28 deletions tests/pytorch/test_onnx_export.py
@@ -256,7 +256,7 @@ def load_custom_ops(session_opts: ort.SessionOptions):
print("registered custom FP8 Q/DQ ops!")

"""Create an ONNX Runtime session for validation."""
kwargs = {}
kwargs = {"providers": ['CUDAExecutionProvider', 'CPUExecutionProvider']}
if is_fp8:
sess_options = ort.SessionOptions()
load_custom_ops(sess_options)
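For context, a minimal sketch of how a `providers` kwarg like the one added above typically reaches ONNX Runtime session creation; the model path and variable names are illustrative, not from this diff, and ORT falls back to the CPU provider when the CUDA provider cannot be loaded:

import onnxruntime as ort

sess_options = ort.SessionOptions()
# "model.onnx" is a hypothetical path; the providers list mirrors the kwarg above.
session = ort.InferenceSession(
    "model.onnx",
    sess_options,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
# Reports which providers ORT actually loaded (CPU only if CUDA is unavailable).
print(session.get_providers())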
@@ -807,17 +807,17 @@ def forward(self, inp, mask):
precision = torch.bfloat16 if fake_bf16_io else precision

# Set dimensions (these are arbitrary).
in_features, hidden_size = 64, 256
batch_size, n_heads, seq_len_q, seq_len_k = 64, 96, 32, 32
mask = None
input_names = ["input", "mask"]
inp_shape = [hidden_size, in_features, in_features, in_features]
inp_shape = [batch_size, n_heads, seq_len_q, seq_len_k]
if softmax_fn == softmax_defs.ScaledUpperTriangMaskedSoftmax:
inp_shape = [hidden_size, in_features, in_features]
inp_shape = [batch_size, seq_len_q, seq_len_k]
kernel_str = "ScaledUpperTriangMaskedSoftmax"
model = Test_Softmax(softmax_fn, fake_bf16_io)
elif softmax_fn == softmax_defs.ScaledMaskedSoftmax:
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(hidden_size, 1, in_features, in_features, device="cuda", dtype=precision)
probs = 0.5 * torch.ones(1, 1, seq_len_q, seq_len_k, device="cuda", dtype=precision)
mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
kernel_str = "ScaledMaskedSoftmax"
model = Test_Softmax(softmax_fn, fake_bf16_io, mask_inp=True)
@@ -832,8 +832,10 @@ def forward(self, inp, mask):
high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
fname = f"{kernel_str}{high_prec_str}.onnx"
inp = (input_tensor, mask)

do_export(model, inp, fname, input_names=input_names)
dynamic_axes = {}
if mask is not None:
dynamic_axes = {"mask": {2:"seq_len_q", 3:"seq_len_k"}}
do_export(model, inp, fname, input_names=input_names, dynamic_axes=dynamic_axes)
te_outputs = te_infer(model, inp, is_fp8=False)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
if fake_bf16_io or precision != torch.bfloat16:
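For reference, a minimal sketch of how a `dynamic_axes` mapping like the one added above is passed to `torch.onnx.export`, so the exported mask input keeps symbolic sequence-length dimensions. The `MaskedSoftmax` module, shapes, and file name below are illustrative stand-ins, not the test's actual `Test_Softmax` wrapper:

import torch
import torch.nn as nn

class MaskedSoftmax(nn.Module):
    # Stand-in for the test's Test_Softmax wrapper.
    def forward(self, inp, mask):
        return torch.softmax(inp.masked_fill(mask, float("-inf")), dim=-1)

model = MaskedSoftmax()
dummy_input = torch.randn(64, 96, 32, 32)
dummy_mask = torch.zeros(1, 1, 32, 32, dtype=torch.bool)
torch.onnx.export(
    model,
    (dummy_input, dummy_mask),
    "scaled_masked_softmax.onnx",          # hypothetical output file
    input_names=["input", "mask"],
    # Mark the mask's last two dims symbolic so the exported model accepts other lengths.
    dynamic_axes={"mask": {2: "seq_len_q", 3: "seq_len_k"}},
)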
@@ -845,16 +847,22 @@ def forward(self, inp, mask):
# Softmax kernel only supports FP16 or BF16!
@skip_FP8
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16, "fake-torch.bfloat16"])
def test_softmax_mask_fn(seed_default_rng, set_max_seq_len, precision):
def test_softmax_mask_fn(seed_default_rng, precision):
fake_bf16_io = precision == "fake-torch.bfloat16"
# reset precision to torch.bfloat16 after capturing fake BF16 mode
precision = torch.bfloat16 if fake_bf16_io else precision

class Test_Softmax(nn.Module):
def __init__(self, use_onnx_mask_fn: bool, fake_bf16_io: bool):
def __init__(self, use_default_te_mask_fn: bool, fake_bf16_io: bool):
super().__init__()
self.scale=1 # arbitrary value
self.fake_bf16_io=fake_bf16_io
self.scale = 1 # arbitrary value
self.fake_bf16_io = fake_bf16_io

if use_default_te_mask_fn:
os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = "0"
else:
os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = f"{seq_len_q}"

# Use NVTE_MASKED_SOFTMAX_FUSION to force TE to use forward_torch_softmax
# even when is_in_onnx_export_mode()==False.
os.environ["NVTE_MASKED_SOFTMAX_FUSION"] = "0"
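As a rough illustration of what "dynamic TRILU mask slicing" refers to here (a sketch of the idea only, not TE's internal implementation): with `NVTE_ONNX_KVCACHE_MAX_SEQ_LEN` set to a positive value, a causal mask can be built once at that maximum length and sliced to the current sequence lengths, which keeps the operation expressible with plain ONNX ops:

import os
import torch

# Sketch only (not TE's code): precompute one causal mask at the maximum
# KV-cache length taken from the environment variable, then slice per call.
max_seq_len = int(os.getenv("NVTE_ONNX_KVCACHE_MAX_SEQ_LEN", "32"))
full_causal = torch.ones(max_seq_len, max_seq_len, dtype=torch.bool).tril()

def sliced_causal_mask(seq_len_q: int, seq_len_k: int) -> torch.Tensor:
    # True where attention is allowed; slicing keeps the mask size dynamic.
    return full_causal[:seq_len_q, :seq_len_k]

print(sliced_causal_mask(4, 4))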
@@ -873,22 +881,23 @@ def forward(self, inp, mask):
return ret

# Set dimensions (these are arbitrary).
in_features = 64
hidden_size = 256
mask = None
inp_shape = [hidden_size, in_features, in_features, in_features]
batch_size, n_heads, seq_len_q, seq_len_k = 64, 96, 32, 32
assert seq_len_q == seq_len_k # This is a causal (TRILU) mask
inp_shape = [batch_size, n_heads, seq_len_q, seq_len_k]
input_tensor = torch.randn(
*inp_shape, device="cuda", dtype=torch.float if fake_bf16_io else precision)
inp = (input_tensor, mask)
high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)

# Compare the outputs of TE when using the default softmax mask
# to the TE outputs produced when using the ONNX-compatible causal mask.
model = Test_Softmax(use_onnx_mask_fn=False, fake_bf16_io=fake_bf16_io)
# This verifies that _get_onnx_export_causal_mask generates a correct mask.
model = Test_Softmax(use_default_te_mask_fn=True, fake_bf16_io=fake_bf16_io)
te_outputs_default_mask = te_infer(model, inp, is_fp8=True)
with te.onnx_export(True):
# ONNX export mode forces use of the ONNX-compatible causal mask.
model_onnx_mask = Test_Softmax(use_onnx_mask_fn=True, fake_bf16_io=fake_bf16_io)
model_onnx_mask = Test_Softmax(use_default_te_mask_fn=False, fake_bf16_io=fake_bf16_io)
te_outputs_onnx_mask = te_infer(model_onnx_mask, inp, is_fp8=True)
compare_outputs(te_outputs_default_mask, te_outputs_onnx_mask,
atol=0, rtol=0, max_errors_printed=10, allow_cnt_errors=0, fname="softmax masking")
@@ -1129,14 +1138,14 @@ def test_export_layernorm_mlp(
@skip_FP8
@pytest.mark.parametrize(
"precision, use_mask, attn_mask_type", [
(torch.float32, False, None), # calls forward_torch_softmax
(torch.float32, True, None), # calls forward_torch_softmax
(torch.float16, False, "causal"), # calls ScaledUpperTriangMaskedSoftmax
(torch.float16, True, "padding"), # calls ScaledMaskedSoftmax
(torch.float16, False, "padding"), # calls ScaledSoftmax
(torch.bfloat16, False, "causal"), # calls ScaledUpperTriangMaskedSoftmax
(torch.bfloat16, True, "padding"), # calls ScaledMaskedSoftmax
(torch.bfloat16, False, "padding"), # calls ScaledSoftmax
(torch.float32, True, "padding"), # calls forward_torch_softmax (apply user mask)
(torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
(torch.float16, True, "padding"), # calls forward_torch_softmax (apply user mask)
(torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
(torch.bfloat16, True, "padding"), # calls forward_torch_softmax (apply user mask)
(torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
])
def test_export_core_attention(
seed_default_rng,
@@ -1164,10 +1173,6 @@ def test_export_core_attention(
high_prec_str = dtype2str(precision)
fname = f"te.core_attention{mask_str}{high_prec_str}.onnx"

if attn_mask_type is None:
attn_mask_type = 'causal'
input_names = ["query", "key", "value"]
inp = (query_layer, key_layer, value_layer)
model = te.attention.DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=kv_channels,
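A hedged usage sketch of the "no_mask" configuration exercised by the updated parametrization above. The shapes, dtype, and the assumption that `attn_mask_type` is a constructor argument reflect the version of `te.attention.DotProductAttention` this test targets; they are illustrative, not the test's exact code:

import torch
import transformer_engine.pytorch as te

# Illustrative values only.
num_attention_heads, kv_channels, seq_len, batch_size = 16, 64, 128, 4

# With attn_mask_type="no_mask", no attention-mask input is traced into the graph.
model = te.attention.DotProductAttention(
    num_attention_heads=num_attention_heads,
    kv_channels=kv_channels,
    attn_mask_type="no_mask",
).cuda().eval()

# Query/key/value in the default (seq, batch, heads, head_dim) layout.
q = torch.randn(seq_len, batch_size, num_attention_heads, kv_channels,
                device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
out = model(q, k, v)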
