[C/PyTorch] Add support for multi-latent attention (MLA) (#1039)
* add multi-latent attention for DPA

Signed-off-by: Charlene Yang <[email protected]>
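
For reference, a minimal sketch of MLA-style usage with `DotProductAttention` after this change, mirroring the updated tests below: Q and K use one head size, V uses another, and the output width follows the V head size. The constructor call copies the test code (only the Q/K head size is passed; the V head size is assumed to be taken from the value tensor at forward time), so treat this as illustrative rather than the documented API.

```python
# Illustrative MLA sketch based on the updated tests; not the documented API.
import torch
from transformer_engine.pytorch import DotProductAttention

num_heads, head_dim_qk, head_dim_v = 16, 128, 64
seq, batch = 2048, 2

dpa = DotProductAttention(
    num_heads,
    head_dim_qk,                     # Q/K head size, as passed in the tests below
    num_gqa_groups=num_heads,
    attn_mask_type="causal",
    qkv_format="sbhd",
)

q = torch.randn(seq, batch, num_heads, head_dim_qk, dtype=torch.bfloat16, device="cuda")
k = torch.randn(seq, batch, num_heads, head_dim_qk, dtype=torch.bfloat16, device="cuda")
v = torch.randn(seq, batch, num_heads, head_dim_v, dtype=torch.bfloat16, device="cuda")

out = dpa(q, k, v)                   # expected shape: [seq, batch, num_heads * head_dim_v]
```

Note that interleaved layouts such as `sb3hd` cannot describe Q/K and V tensors with different head sizes, which is why the test defaults switch to `sbhd_sbhd_sbhd` / `bshd_bshd_bshd` for the MLA configurations.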

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix Jax/Paddle API

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint

Signed-off-by: Charlene Yang <[email protected]>

* fix typo in test script

Signed-off-by: Charlene Yang <[email protected]>

* fix too-many-boolean lint error

Signed-off-by: Charlene Yang <[email protected]>

* Revert "fix lint"

This reverts commit 67399a3.

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix stride check in get_qkv_layout

Signed-off-by: Charlene Yang <[email protected]>
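
For context, the stride check in `get_qkv_layout` decides whether q/k/v come from one interleaved buffer or from separate allocations when the layout string is inferred. The snippet below is only an illustration of that idea under assumed names, not the TransformerEngine implementation: it treats the tensors as views into a packed `sb3hd` buffer when they share storage, have identical strides, and sit one `h*d` slot apart.

```python
# Illustrative sketch only (assumed helper name); not the actual get_qkv_layout code.
import torch

def _looks_like_sb3hd(q, k, v):
    h, d = q.shape[-2], q.shape[-1]
    same_storage = (
        q.untyped_storage().data_ptr()
        == k.untyped_storage().data_ptr()
        == v.untyped_storage().data_ptr()
    )
    step = h * d  # distance between the q, k and v slots in a packed sb3hd buffer
    contiguous_slots = (
        k.storage_offset() - q.storage_offset() == step
        and v.storage_offset() - k.storage_offset() == step
    )
    return same_storage and contiguous_slots and q.stride() == k.stride() == v.stride()

qkv = torch.randn(128, 2, 3, 16, 64)          # a packed sb3hd buffer: [s, b, 3, h, d]
q, k, v = (qkv[:, :, i] for i in range(3))    # views into the same storage
assert _looks_like_sb3hd(q, k, v)
```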

* WIP: fix layout_thd tests

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP: debug info

Signed-off-by: Charlene Yang <[email protected]>

* fix merge conflict

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix thd pad_between_seqs=False/True tests

Signed-off-by: Charlene Yang <[email protected]>
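
For context, `pad_between_seqs` exercises the packed `thd` layout in which filler tokens are physically stored between sequences, so the tests keep two sets of offsets: cumulative valid lengths and cumulative padded lengths. A small standalone sketch of that bookkeeping (illustrative names, mirroring the slicing done in the updated test):

```python
# Sketch of the thd "pad between sequences" bookkeeping; names are illustrative.
import torch

seqlens = [5, 3, 7]    # true lengths of three sequences
pad_len = [3, 5, 1]    # filler tokens physically inserted after each sequence

cu_seqlens = [0]              # offsets of valid tokens (what the kernel should see)
cu_seqlens_after_pad = [0]    # offsets into the physically padded buffer
for s, p in zip(seqlens, pad_len):
    cu_seqlens.append(cu_seqlens[-1] + s)
    cu_seqlens_after_pad.append(cu_seqlens_after_pad[-1] + s + p)

packed = torch.randn(cu_seqlens_after_pad[-1], 16, 64)   # [t, h, d] with pads in between

# Recover the unpadded tokens the same way the test slices out valid ranges.
valid = torch.cat(
    [packed[start : start + s] for start, s in zip(cu_seqlens_after_pad[:-1], seqlens)],
    dim=0,
)
assert valid.shape[0] == cu_seqlens[-1]
```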

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
cyanguwa and pre-commit-ci[bot] authored Aug 6, 2024
1 parent 27c6342 commit 87939be
Showing 15 changed files with 343 additions and 250 deletions.
137 changes: 95 additions & 42 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -77,12 +77,13 @@ def __init__(
batch_size: int,
num_heads: int,
num_gqa_groups: int,
head_dim: int,
head_dim_qk: int,
max_seqlen_q: int,
max_seqlen_kv: int,
dropout_p: float,
attn_mask_type: str,
attn_bias_type: str,
head_dim_v: int = None,
alibi_type: str = "none",
num_layers: int = 1,
bias_shape: str = "1hss",
@@ -91,9 +92,10 @@ def __init__(
self.batch_size = batch_size
self.num_heads = num_heads
self.num_gqa_groups = num_gqa_groups
self.head_dim = head_dim
self.hidden_size = num_heads * head_dim
self.hidden_size_kv = num_gqa_groups * head_dim
self.head_dim_qk = head_dim_qk
self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v
self.hidden_size = num_heads * head_dim_qk
self.hidden_size_kv = num_gqa_groups * self.head_dim_v
self.max_seqlen_q = max_seqlen_q
self.max_seqlen_kv = max_seqlen_kv
self.dropout_p = dropout_p
@@ -137,7 +139,11 @@ def _get_attention_backends(
)
core_attention_bias_requires_grad = False
# d=256 is supported by cuDNN 9.0+ for inference but not training
if config.attn_bias_type == "post_scale_bias" and config.head_dim <= 128:
if (
config.attn_bias_type == "post_scale_bias"
and config.head_dim_qk <= 128
and config.head_dim_v <= 128
):
core_attention_bias_requires_grad = True

fused_attn_backends = []
@@ -153,7 +159,8 @@ def test():
num_gqa_groups=config.num_gqa_groups,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
head_dim=config.head_dim,
head_dim_qk=config.head_dim_qk,
head_dim_v=config.head_dim_v,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes_shape,
@@ -218,11 +225,12 @@ def test_dot_product_attention(
if dtype == torch.bfloat16:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
config = model_configs[model]
is_mla = config.head_dim_qk != config.head_dim_v
if qkv_layout is None:
if config.attn_type == "self":
qkv_layout = "sb3hd"
qkv_layout = "sb3hd" if not is_mla else "sbhd_sbhd_sbhd"
else:
qkv_layout = "sbhd_sb2hd"
qkv_layout = "bshd_bs2hd" if not is_mla else "bshd_bshd_bshd"
if "3" in qkv_layout and config.attn_type == "cross":
pytest.skip("No need to test this layout for cross attention")

@@ -241,14 +249,17 @@ def test_dot_product_attention(
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
# manually pads and unpads the input and output of FlashAttention for testing purposes
if pad_between_seqs:
if pad_between_seqs and not (
config.max_seqlen_q != config.max_seqlen_kv
and config.attn_mask_type in ["causal", "padding_causal"]
):
flash_attn_supported = True

# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")

is_training = config.head_dim <= 128
is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128
# UnfusedDotProductAttention backend
if unfused_attn_supported:
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
@@ -343,6 +354,38 @@ def test_dpa_checkpoint(dtype, model_configs, model):
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


model_configs_mla = {
# test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
"mla_1_0": ModelConfig(
8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias", head_dim_v=128
), # self , 0
"mla_1_1": ModelConfig(
4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
), # cross, 0
"mla_2_0": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64
), # self , 1
"mla_2_1": ModelConfig(
1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64
), # cross, 1
"mla_3_0": ModelConfig(
8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64
), # inference
"mla_3_1": ModelConfig(
8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128
), # inference
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_mla])
@pytest.mark.parametrize("model", model_configs_mla.keys())
def test_dpa_mla(dtype, model_configs, model):
"""Test DotProductAttention module with Multi-Latent Attention (MLA)"""
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


model_configs_mask = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
@@ -586,14 +629,16 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
@pytest.mark.parametrize("qkv_layout", qkv_layouts_thd)
def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
pad_between_seqs = False
test_dot_product_attention(
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
)
pad_between_seqs = True
test_dot_product_attention(
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
)
if get_cudnn_version() >= (9, 3, 0):
# cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run
pad_between_seqs = False
test_dot_product_attention(
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
)


def _run_dot_product_attention(
@@ -736,7 +781,8 @@ def _run_dot_product_attention(
"skv": config.max_seqlen_kv,
"h": config.num_heads,
"hg": config.num_gqa_groups,
"d": config.head_dim,
"dqk": config.head_dim_qk,
"dv": config.head_dim_v,
"t": cu_seqlens_q_after_pad[-1],
"tg": cu_seqlens_kv_after_pad[-1],
"3": 3,
@@ -753,12 +799,16 @@ def _run_dot_product_attention(
layout = layout.replace("s", "skv")
layout = layout.replace("h", "hg")
layout = layout.replace("t", "tg")
if i == 2:
layout = layout.replace("d", "dv")
else:
layout = layout.replace("d", "dqk")
tensor_shape = [dim_to_num[j] for j in layout.split("_")]
tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
tensor_orig = tensor
if qkv_format == "thd" and pad_between_seqs:
tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
if layout in ["t_h_d", "t_3_h_d", "t_h_3_d"]:
if layout in ["t_h_dqk", "t_3_h_dqk", "t_h_3_dqk"]:
for i in range(1, config.batch_size + 1):
valid_range = (
cu_seqlens_q_after_pad[i - 1],
@@ -772,7 +822,7 @@ def _run_dot_product_attention(
tensor_orig = torch.cat(
[tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
)
if layout in ["tg_hg_d", "tg_2_hg_d", "tg_hg_2_d"]:
if layout in ["tg_hg_dqk", "tg_2_hg_dqk", "tg_hg_2_dqk", "tg_hg_dv"]:
for i in range(1, config.batch_size + 1):
valid_range = (
cu_seqlens_kv_after_pad[i - 1],
@@ -811,13 +861,14 @@ def _run_dot_product_attention(
# Create output gradient
qkv_format_kv = "_".join(qkv_format)
qkv_format_kv = qkv_format_kv.replace("s", "sq")
qkv_format_kv = qkv_format_kv.replace("d", "dv")
out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda")
out_grad_orig = out_grad
if qkv_format == "thd" and pad_between_seqs:
out_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
if qkv_format_kv == "t_h_d":
if qkv_format_kv == "t_h_dv":
for i in range(1, config.batch_size + 1):
valid_range = (
cu_seqlens_q_after_pad[i - 1],
@@ -851,7 +902,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
# Set up model
block = DotProductAttention(
config.num_heads,
config.head_dim,
config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
@@ -906,9 +957,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
if backend == "FusedAttention":
if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
if is_training:
q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
for i in range(1, config.batch_size + 1):
valid_range_q = (
cu_seqlens_q_after_pad[i - 1],
@@ -919,15 +971,16 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
)
out_orig = torch.cat([out_orig, out[valid_range_q[0] : valid_range_q[1]]], dim=0)
q_grad_orig = torch.cat(
[q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0
)
k_grad_orig = torch.cat(
[k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
)
v_grad_orig = torch.cat(
[v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
)
if is_training:
q_grad_orig = torch.cat(
[q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0
)
k_grad_orig = torch.cat(
[k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
)
v_grad_orig = torch.cat(
[v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
)
if is_training:
return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig)
else:
@@ -1168,7 +1221,7 @@ def _run_transformer_layer(
# Create RoPE
rotary_pos_emb = None
if RoPE:
PE = RotaryPositionEmbedding(dim=config.head_dim)
PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")

# Set up model
Expand All @@ -1183,7 +1236,7 @@ def _run_transformer_layer(
init_method=init_method,
output_layer_init_method=output_layer_init_method,
layer_number=layer_number,
kv_channels=config.head_dim,
kv_channels=config.head_dim_qk,
self_attn_mask_type=config.attn_mask_type,
tp_group=None,
tp_size=1,
@@ -1356,7 +1409,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
mha = MultiheadAttention(
hidden_size=config.hidden_size,
num_attention_heads=config.num_heads,
kv_channels=config.head_dim,
kv_channels=config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
layer_number=1,
Expand Down Expand Up @@ -1387,7 +1440,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"skv": config.max_seqlen_kv,
"h": config.num_heads,
"hg": config.num_gqa_groups,
"d": config.head_dim,
"d": config.head_dim_qk,
"t": cu_seqlens_q[-1],
"tg": cu_seqlens_kv[-1],
"3": 3,
@@ -1531,7 +1584,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
with fp8_model_init(enabled=fp8_dpa):
dpa = DotProductAttention(
config.num_heads,
config.head_dim,
config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
sequence_parallel=False,
Expand Down Expand Up @@ -1560,7 +1613,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"skv": config.max_seqlen_kv,
"h": config.num_heads,
"hg": config.num_gqa_groups,
"d": config.head_dim,
"d": config.head_dim_qk,
"t": cu_seqlens_q[-1],
"tg": cu_seqlens_kv[-1],
"3": 3,
@@ -1732,7 +1785,7 @@ def _run_custom_mha_fp8(dtype, config, backend):
inp = 0.0001 * torch.randint(
-100,
100,
(config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim),
(config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim_qk),
dtype=dtype,
device="cuda",
requires_grad=True,
@@ -1743,7 +1796,7 @@ def _run_custom_mha_fp8(dtype, config, backend):

out_grad = 0.01 * torch.randn(
config.batch_size * config.max_seqlen_q,
config.num_heads * config.head_dim,
config.num_heads * config.head_dim_qk,
dtype=dtype,
device="cuda",
)
@@ -1766,7 +1819,7 @@ def _run_custom_mha_fp8(dtype, config, backend):
return (
out.view(config.batch_size, config.max_seqlen_q, -1),
dqkv.view(
config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim
config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk
).contiguous(),
)

@@ -1809,7 +1862,7 @@ def get_dummy_cuda_rng_tracker():

block = DotProductAttention(
config.num_heads,
config.head_dim,
config.head_dim_qk,
attention_dropout=config.dropout_p,
sequence_parallel=False,
tp_size=1,
@@ -2105,7 +2158,7 @@ def __init__(self, config, params_dtype: torch.dtype = torch.float32):
self.p_dropout = config.dropout_p
self.h = config.num_heads
self.hidden_size = config.hidden_size
self.head_dim = config.head_dim
self.head_dim = config.head_dim_qk
self.fast_zero_fill = True
self.mask_type = config.attn_mask_type

2 changes: 1 addition & 1 deletion tests/pytorch/test_onnx_export.py
@@ -1083,7 +1083,7 @@ def test_export_core_attention(

model = te.attention.DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=kv_channels,
k_channels=kv_channels,
attention_dropout=0.5,
qkv_format=qkv_format,
attn_mask_type=attn_mask_type,
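
The one-line change above tracks a rename in the `DotProductAttention` constructor: the head-size argument previously accepted as `kv_channels` is passed as `k_channels` in this commit, since the Q/K and V head sizes can now differ. A hedged before/after sketch (the `v_channels` argument shown for the MLA case is assumed from the rename, not confirmed by this diff):

```python
import transformer_engine.pytorch as te

# Before: a single head size shared by Q, K and V.
dpa_old = te.DotProductAttention(num_attention_heads=16, kv_channels=64)

# After this commit: the Q/K head size goes through k_channels; a separate
# value head size (v_channels here is an assumption) would cover the MLA case.
dpa_new = te.DotProductAttention(num_attention_heads=16, k_channels=64, v_channels=32)
```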