[C/PyTorch/Jax] Add support for more bias shapes (#677)
* added support for arbitrary bias shapes for fused_attn

Signed-off-by: Alp Dener <[email protected]>

* Fix linting

Signed-off-by: Alp Dener <[email protected]>

* Add b1ss/bhss/11ss bias shapes when not requiring dBias

Signed-off-by: Charlene Yang <[email protected]>

* fix lint

Signed-off-by: Charlene Yang <[email protected]>

* add bias_b/h to plan cache

Signed-off-by: Charlene Yang <[email protected]>

* fixed compile errors after PR653 merge

Signed-off-by: Alp Dener <[email protected]>

* updated JAX unittests for new bias shapes

Signed-off-by: Alp Dener <[email protected]>

* fixed mismatched mask type checking

Signed-off-by: Alp Dener <[email protected]>

* corrected skip condition

Signed-off-by: Alp Dener <[email protected]>

* fix selection logic for A100s

Signed-off-by: Charlene Yang <[email protected]>

* corrected skip checks for bias shapes

Signed-off-by: Alp Dener <[email protected]>

* resolved test issues, but neg_inf with float16 is still problematic with JAX

Signed-off-by: Alp Dener <[email protected]>

* new bias shapes passing TE JAX CI for seqlen <= 512, seq_q == seq_kv and h_q == h_kv conditions

Signed-off-by: Alp Dener <[email protected]>

* TE/JAX fused attn tests for new bias shapes passing with neg_inf=-2**27 for Bfloat16 and -2**15 for Float16

Signed-off-by: Alp Dener <[email protected]>

* code style fixes and test parameter ID cleanup

Signed-off-by: Alp Dener <[email protected]>

* fixed incorrect skip condition for backward fused attn test

Signed-off-by: Alp Dener <[email protected]>

---------

Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: Alp Dener <[email protected]>
cyanguwa and denera authored Feb 28, 2024
1 parent 0404095 commit b8eea8a
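Note on the neg_inf values mentioned in the commit message: the JAX tests replace -inf in the attention mask with a large finite negative so that reduced-precision softmax does not produce NaNs. The snippet below is only an illustrative sketch of that idea (function and argument names are assumed, not the TE/JAX implementation):

    import jax
    import jax.numpy as jnp

    def masked_softmax(logits, mask):
        # A large finite negative fill value still underflows to ~0 after softmax,
        # but avoids the inf-related NaNs that true -inf can trigger in fp16/bf16.
        neg_inf = -2.0**15 if logits.dtype == jnp.float16 else -2.0**27
        return jax.nn.softmax(jnp.where(mask, logits, neg_inf), axis=-1)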
Showing 9 changed files with 508 additions and 266 deletions.
221 changes: 143 additions & 78 deletions tests/jax/test_fused_attn.py

Large diffs are not rendered by default.

43 changes: 39 additions & 4 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -85,6 +85,7 @@ def __init__(
attn_bias_type: str,
alibi_type: str = "none",
num_layers: int = 1,
bias_shape: str = "1hss",
):
self.batch_size = batch_size
self.num_heads = num_heads
@@ -100,6 +101,7 @@ def __init__(
self.alibi_type = alibi_type
self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
self.num_layers = num_layers
self.bias_shape = bias_shape

def _is_fused_attention_supported(
config: ModelConfig,
@@ -379,6 +381,31 @@ def test_dpa_bias(dtype, model_configs, model):
"""Test DotProductAttention module with different bias types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)

model_configs_bias_shapes = {
# test: b, h, hg, d, sq, skv, p,
"bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0,
# mask, bias, bias_shape,
"no_mask", "post_scale_bias", bias_shape='11ss'),
"bias_1_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0,
"no_mask", "post_scale_bias", bias_shape='1hss'),
"bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,
"no_mask", "post_scale_bias", bias_shape='b1ss'),
"bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0,
"no_mask", "post_scale_bias", bias_shape='bhss'),
"bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,
"causal", "alibi", bias_shape='1hss', alibi_type='custom'),
"bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0,
"causal", "alibi", bias_shape='bhss', alibi_type='custom'),
}

@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias_shapes])
@pytest.mark.parametrize("model", model_configs_bias_shapes.keys())
def test_dpa_bias_shapes(dtype, model_configs, model):
"""Test DotProductAttention module with different bias types and shapes"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)
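The bias_shape strings above encode which bias dimensions are broadcast: 'b' and 'h' stand for the actual batch and head sizes, '1' marks a broadcast dimension, and the trailing 'ss' is always [sq, skv]. A small illustrative helper (not part of the test file) makes the mapping explicit:

    import torch

    def bias_shape_to_dims(bias_shape, b, h, sq, skv):
        # "11ss" -> [1, 1, sq, skv], "1hss" -> [1, h, sq, skv],
        # "b1ss" -> [b, 1, sq, skv], "bhss" -> [b, h, sq, skv]
        lead = {"1": 1, "b": b, "h": h}
        return [lead[bias_shape[0]], lead[bias_shape[1]], sq, skv]

    bias = torch.randn(bias_shape_to_dims("b1ss", 4, 24, 128, 128), dtype=torch.bfloat16)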

model_configs_swa = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
@@ -510,10 +537,13 @@ def _run_dot_product_attention(
window_size, attention_mask = None, None

alibi_slopes = None
if config.attn_bias_type == "alibi":
if config.alibi_type == "custom":
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
if config.bias_shape == "1hss":
alibi_slopes = torch.randn(
config.num_heads).abs().to(dtype=torch.float32, device="cuda")
if config.bias_shape == "bhss":
alibi_slopes = torch.randn(
config.batch_size, config.num_heads).abs().to(dtype=torch.float32, device="cuda")

# Create input tensors
dim_to_num = {
@@ -527,6 +557,7 @@ def _run_dot_product_attention(
'tg' : cu_seqlens_kv[-1],
'3' : 3,
'2' : 2,
'1' : 1,
}
inp = []
for i,layout in enumerate(qkv_layout.split('_')):
@@ -566,8 +597,12 @@ def _run_dot_product_attention(
if config.attn_bias_type in ['no_bias', 'alibi']:
bias = None
if config.attn_bias_type == 'post_scale_bias':
bias = torch.randn(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv,
dtype=dtype, device="cuda")
shape = '_'.join(config.bias_shape)
shape = shape.replace('_s_s', '_sq_skv')
tensor_shape = [dim_to_num[j] for j in shape.split('_')]
bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
if config.bias_shape != '1hss':
bias.requires_grad = False

# Create RNG
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
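The custom ALiBi slopes above are drawn either per head ([h], bias shape '1hss') or per batch and head ([b, h], bias shape 'bhss'). As an illustration only (the exact formula applied by the fused kernel is not part of this diff), one common way to expand such slopes into an additive bias is:

    import torch

    def alibi_bias_from_slopes(slopes, sq, skv):
        # slopes: [h] or [b, h]; returns [1, h, sq, skv] or [b, h, sq, skv].
        # The relative position (key_pos - query_pos) is <= 0 under a causal mask,
        # so positive slopes penalize attention to distant past tokens.
        rel = torch.arange(skv).view(1, skv) - torch.arange(sq).view(sq, 1)
        if slopes.dim() == 1:
            slopes = slopes.unsqueeze(0)  # [h] -> [1, h]
        return slopes[:, :, None, None] * rel.to(slopes.dtype)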
@@ -72,6 +72,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
FADescriptor_v1 descriptor{b, h,
hg, s_q,
s_kv, d,
bias_b, bias_h,
scaling_factor, is_training,
dropout_probability, layout,
bias_type, mask_type,
@@ -316,6 +317,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
FADescriptor_v1 descriptor{b, h,
hg, s_q,
s_kv, d,
bias_b, bias_h,
scaling_factor, true,
dropout_probability, layout,
bias_type, mask_type,
@@ -426,7 +428,12 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_dim({bias_b, bias_h, s_q, s_kv})
.set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
sdpa_backward_options.set_bias(bias);
sdpa_backward_options.set_dbias(dBias);
// shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s]
// are not supported for dbias calculation but they are
// supported for forward bias calculation
if ((bias_b == 1) && (bias_h == h)) {
sdpa_backward_options.set_dbias(dBias);
}
}

if (is_padding) {
@@ -541,7 +548,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl(

if (is_bias) {
variant_pack[bias] = devPtrBias;
variant_pack[dBias] = devPtrdBias;
if ((bias_b == 1) && (bias_h == h)) {
variant_pack[dBias] = devPtrdBias;
} else {
variant_pack[dBias] = nullptr;
}
}

if (is_padding) {
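The backward hunks above gate the bias gradient on the bias shape: the fused backward pass only produces dBias for a [1, h, s_q, s_kv] bias, so the [1, 1, s, s], [b, 1, s, s] and [b, h, s, s] shapes apply the bias in the forward pass only. A framework-side mirror of that rule could look like this (illustrative sketch with assumed names, not TE API):

    import torch

    def maybe_enable_dbias(bias: torch.Tensor, num_heads: int) -> torch.Tensor:
        # Only the [1, h, s_q, s_kv] layout is differentiable through the fused
        # backward pass; every other broadcastable shape is treated as a constant.
        bias_b, bias_h = bias.shape[0], bias.shape[1]
        bias.requires_grad_(bias_b == 1 and bias_h == num_heads)
        return bias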
5 changes: 4 additions & 1 deletion transformer_engine/common/fused_attn/utils.h
@@ -103,6 +103,8 @@ struct FADescriptor_v1 {
std::int64_t s_q;
std::int64_t s_kv;
std::int64_t d;
std::int64_t bias_b;
std::int64_t bias_h;
float attnScale;
bool isTraining;
float dropoutProbability;
@@ -112,11 +114,12 @@ struct FADescriptor_v1 {
cudnn_frontend::DataType_t tensor_type;

bool operator<(const FADescriptor_v1 &rhs) const {
return std::tie(b, h, hg, s_q, s_kv, d,
return std::tie(b, h, hg, s_q, s_kv, d, bias_b, bias_h,
attnScale, isTraining, dropoutProbability,
layout, mask_type, bias_type, tensor_type)
< std::tie(
rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d,
rhs.bias_b, rhs.bias_h,
rhs.attnScale, rhs.isTraining,
rhs.dropoutProbability, rhs.layout,
rhs.mask_type, rhs.bias_type,
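Adding bias_b and bias_h to FADescriptor_v1 and to its operator< makes the bias shape part of the key under which cuDNN execution plans are cached, so calls that differ only in bias shape cannot reuse each other's graphs. In spirit (a hypothetical Python sketch, not the actual C++ cache):

    _plan_cache = {}

    def get_plan(b, h, hg, s_q, s_kv, d, bias_b, bias_h, build_plan):
        # The bias dimensions are part of the lookup key; without them a graph
        # built for a [1, 1, s, s] bias could be reused for a [b, h, s, s] bias.
        key = (b, h, hg, s_q, s_kv, d, bias_b, bias_h)
        if key not in _plan_cache:
            _plan_cache[key] = build_plan(key)
        return _plan_cache[key]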
@@ -44,7 +44,7 @@ enum NVTE_QKV_Layout {
};

/*! \enum NVTE_QKV_Layout_Group
* \brief QKV layout groups
* \brief QKV layout groups
*/
enum NVTE_QKV_Layout_Group {
/*! 3HD QKV layouts, i.e. BS3HD, SB3HD, T3HD */