[PyTorch] Update docs/example and benchmarks/ scripts #1075

Merged 6 commits on Aug 13, 2024
benchmarks/attention/benchmark_attention.py (20 changes: 7 additions & 13 deletions)
@@ -11,9 +11,7 @@
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
ModelConfig,
- _is_flash_attention_supported,
- _is_fused_attention_supported,
- _is_unfused_attention_supported,
+ _get_attention_backends,
_run_dot_product_attention,
)

@@ -29,8 +27,6 @@
workspace_opt = True
# QKV memory layout
qkv_layout = "bshd_bshd_bshd"
- # sliding window attention
- swa = False
# padding between sequences for qkv_format=thd
pad_between_seqs = False
# training mode
@@ -64,7 +60,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
ckpt_attn,
qkv_layout,
workspace_opt,
- swa,
pad_between_seqs,
is_training,
)
@@ -76,7 +71,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
ckpt_attn,
qkv_layout,
workspace_opt,
- swa,
pad_between_seqs,
is_training,
)
@@ -97,7 +91,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
ckpt_attn,
qkv_layout,
workspace_opt,
- swa,
pad_between_seqs,
is_training,
)
@@ -115,7 +108,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
ckpt_attn,
qkv_layout,
workspace_opt,
- swa,
pad_between_seqs,
is_training,
)
@@ -205,13 +197,15 @@ def main():
)
for model in model_configs.keys():
config = model_configs[model]
- fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
+ available_backends, fused_attn_backends = _get_attention_backends(
config,
- dtype,
+ qkv_dtype=dtype,
+ qkv_layout=qkv_layout,
+ window_size=config.window_size,
+ pad_between_seqs=pad_between_seqs,
)
- fused_attn_supported = fused_attn_supported and not swa
- flash_attn_supported = _is_flash_attention_supported(config)
+ flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

print(
f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
f'{" and flash-attention" if flash_attn_supported else ""}...'
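
For reference, a minimal sketch of the backend-selection pattern the benchmark switches to in this diff. It relies only on the call shown above; config, dtype, qkv_layout, and pad_between_seqs are assumed to be set up as in benchmark_attention.py.

# Minimal sketch (not part of the PR): query available attention backends
# the way the updated benchmark does. config, dtype, qkv_layout, and
# pad_between_seqs are assumed to be defined as in benchmark_attention.py.
from tests.pytorch.fused_attn.test_fused_attn import _get_attention_backends

available_backends, fused_attn_backends = _get_attention_backends(
    config,
    qkv_dtype=dtype,
    qkv_layout=qkv_layout,
    window_size=config.window_size,
    pad_between_seqs=pad_between_seqs,
)
# The first return value unpacks into per-backend support flags.
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

if fused_attn_supported:
    print("cuDNN attention is supported for this model config")
if flash_attn_supported:
    print("flash-attention is supported for this model config")
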
docs/examples/attention/arbitrary_mask_to_post_scale_bias.py (16 changes: 11 additions & 5 deletions)
@@ -6,7 +6,6 @@
import torch
from typing import Tuple
from tests.pytorch.fused_attn.test_fused_attn import ModelConfig
- from transformer_engine.pytorch.distributed import _set_cuda_rng_state
from transformer_engine.pytorch.attention import DotProductAttention

# Initialize RNG state
@@ -22,7 +21,7 @@
def reset_rng_states() -> None:
"""Revert back to initial RNG state"""
torch.set_rng_state(_cpu_rng_state)
- _set_cuda_rng_state(_cuda_rng_state)
+ torch.cuda.set_rng_state(_cuda_rng_state)


def _run_dot_product_attention(
@@ -40,7 +39,7 @@ def _run_dot_product_attention(
[config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
)
inp = torch.randn(
- [config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim],
+ [config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk],
dtype=dtype,
device="cuda",
)
@@ -51,7 +50,7 @@ def _run_dot_product_attention(
k.requires_grad = True
v.requires_grad = True
out_grad = torch.randn(
- [config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim],
+ [config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim_v],
dtype=dtype,
device="cuda",
)
@@ -80,7 +79,7 @@ def _run_dot_product_attention(

block = DotProductAttention(
config.num_heads,
- config.head_dim,
+ config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
qkv_format="bshd",
attention_dropout=config.dropout_p,
@@ -89,6 +88,8 @@ def _run_dot_product_attention(
get_rng_state_tracker=None,
tp_group=None,
layer_number=1,
+ attn_mask_type="no_mask",
+ window_size=(-1, -1),
).to(dtype=dtype, device="cuda")

# Run a forward and backward pass
Expand All @@ -103,6 +104,7 @@ def _run_dot_product_attention(
attn_mask_type=config.attn_mask_type, # 'arbitrary'
core_attention_bias_type=config.attn_bias_type, # 'no_bias'
core_attention_bias=bias, # None
+ window_size=(-1, -1),
)
out.backward(out_grad)

@@ -116,6 +118,7 @@ def _run_dot_product_attention(
attn_mask_type=config.attn_mask_type, # no_mask
core_attention_bias_type=config.attn_bias_type, # 'post_scale_bias'
core_attention_bias=bias, # bias
+ window_size=(-1, -1),
)
out.backward(out_grad)

@@ -133,11 +136,14 @@ def _run_dot_product_attention(
config = model_configs["test_bias"]
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd")

+ print()
print("Run with arbitrary mask:")
config = model_configs["test_mask"]
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd")

torch.testing.assert_close(unfused_attn_fwd, fused_attn_fwd, atol=2.5e-2, rtol=2.5e-2)
for i in range(3):
torch.testing.assert_close(unfused_attn_bwd[i], fused_attn_bwd[i], atol=2.5e-2, rtol=2.5e-2)

+ print()
print("Test passed!")