[C/PyTorch] Add support for multi-latent attention (MLA) #1039

Merged: 19 commits, Aug 6, 2024
137 changes: 95 additions & 42 deletions tests/pytorch/fused_attn/test_fused_attn.py
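The core of this change is that the test helper `ModelConfig` now carries separate query/key and value head dimensions (`head_dim_qk`, `head_dim_v`), with `head_dim_v` defaulting to `head_dim_qk` so existing non-MLA configs keep working. A minimal sketch of how the updated helper is used, mirroring the `mla_2_0` config added below (illustrative only, not code from the PR):

```python
# Illustrative sketch of the updated ModelConfig helper (see the diff hunks below).
# head_dim_v defaults to head_dim_qk, so non-MLA configs are unaffected.
config = ModelConfig(
    batch_size=2,
    num_heads=24,
    num_gqa_groups=24,
    head_dim_qk=128,            # query/key head dimension
    max_seqlen_q=2048,
    max_seqlen_kv=2048,
    dropout_p=0.0,
    attn_mask_type="causal",
    attn_bias_type="no_bias",
    head_dim_v=64,              # value head dimension differs from QK: MLA-style shapes
)
assert config.hidden_size == config.num_heads * config.head_dim_qk         # 24 * 128
assert config.hidden_size_kv == config.num_gqa_groups * config.head_dim_v  # 24 * 64
```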
@@ -77,12 +77,13 @@ def __init__(
batch_size: int,
num_heads: int,
num_gqa_groups: int,
- head_dim: int,
+ head_dim_qk: int,
max_seqlen_q: int,
max_seqlen_kv: int,
dropout_p: float,
attn_mask_type: str,
attn_bias_type: str,
+ head_dim_v: int = None,
alibi_type: str = "none",
num_layers: int = 1,
bias_shape: str = "1hss",
@@ -91,9 +92,10 @@ def __init__(
self.batch_size = batch_size
self.num_heads = num_heads
self.num_gqa_groups = num_gqa_groups
- self.head_dim = head_dim
- self.hidden_size = num_heads * head_dim
- self.hidden_size_kv = num_gqa_groups * head_dim
+ self.head_dim_qk = head_dim_qk
+ self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v
+ self.hidden_size = num_heads * head_dim_qk
+ self.hidden_size_kv = num_gqa_groups * self.head_dim_v
self.max_seqlen_q = max_seqlen_q
self.max_seqlen_kv = max_seqlen_kv
self.dropout_p = dropout_p
@@ -137,7 +139,11 @@ def _get_attention_backends(
)
core_attention_bias_requires_grad = False
# d=256 is supported by cuDNN 9.0+ for inference but not training
- if config.attn_bias_type == "post_scale_bias" and config.head_dim <= 128:
+ if (
+ config.attn_bias_type == "post_scale_bias"
+ and config.head_dim_qk <= 128
+ and config.head_dim_v <= 128
+ ):
core_attention_bias_requires_grad = True

fused_attn_backends = []
@@ -153,7 +159,8 @@ def test():
num_gqa_groups=config.num_gqa_groups,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
- head_dim=config.head_dim,
+ head_dim_qk=config.head_dim_qk,
+ head_dim_v=config.head_dim_v,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes_shape,
@@ -218,11 +225,12 @@ def test_dot_product_attention(
if dtype == torch.bfloat16:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
config = model_configs[model]
+ is_mla = config.head_dim_qk != config.head_dim_v
if qkv_layout is None:
if config.attn_type == "self":
qkv_layout = "sb3hd"
qkv_layout = "sb3hd" if not is_mla else "sbhd_sbhd_sbhd"
else:
qkv_layout = "sbhd_sb2hd"
qkv_layout = "bshd_bs2hd" if not is_mla else "bshd_bshd_bshd"
if "3" in qkv_layout and config.attn_type == "cross":
pytest.skip("No need to test this layout for cross attention")

@@ -241,14 +249,17 @@ def test_dot_product_attention(
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
# manually pads and unpads the input and output of FlashAttention for testing purposes
- if pad_between_seqs:
+ if pad_between_seqs and not (
+ config.max_seqlen_q != config.max_seqlen_kv
+ and config.attn_mask_type in ["causal", "padding_causal"]
+ ):
flash_attn_supported = True

# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")

- is_training = config.head_dim <= 128
+ is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128
# UnfusedDotProductAttention backend
if unfused_attn_supported:
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
@@ -343,6 +354,38 @@ def test_dpa_checkpoint(dtype, model_configs, model):
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


+ model_configs_mla = {
+ # test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
+ "mla_1_0": ModelConfig(
+ 8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias", head_dim_v=128
+ ), # self , 0
+ "mla_1_1": ModelConfig(
+ 4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
+ ), # cross, 0
+ "mla_2_0": ModelConfig(
+ 2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64
+ ), # self , 1
+ "mla_2_1": ModelConfig(
+ 1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64
+ ), # cross, 1
+ "mla_3_0": ModelConfig(
+ 8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64
+ ), # inference
+ "mla_3_1": ModelConfig(
+ 8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128
+ ), # inference
+ }
+
+
+ @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
+ @pytest.mark.parametrize("dtype", param_types)
+ @pytest.mark.parametrize("model_configs", [model_configs_mla])
+ @pytest.mark.parametrize("model", model_configs_mla.keys())
+ def test_dpa_mla(dtype, model_configs, model):
+ """Test DotProductAttention module with Multi-Latent Attention (MLA)"""
+ test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
+
+
model_configs_mask = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
@@ -586,14 +629,16 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
@pytest.mark.parametrize("qkv_layout", qkv_layouts_thd)
def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
- pad_between_seqs = False
- test_dot_product_attention(
- dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
- )
pad_between_seqs = True
test_dot_product_attention(
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
)
+ if get_cudnn_version() >= (9, 3, 0):
+ # cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run
+ pad_between_seqs = False
+ test_dot_product_attention(
+ dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
+ )


def _run_dot_product_attention(
@@ -736,7 +781,8 @@ def _run_dot_product_attention(
"skv": config.max_seqlen_kv,
"h": config.num_heads,
"hg": config.num_gqa_groups,
"d": config.head_dim,
"dqk": config.head_dim_qk,
"dv": config.head_dim_v,
"t": cu_seqlens_q_after_pad[-1],
"tg": cu_seqlens_kv_after_pad[-1],
"3": 3,
@@ -753,12 +799,16 @@ def _run_dot_product_attention(
layout = layout.replace("s", "skv")
layout = layout.replace("h", "hg")
layout = layout.replace("t", "tg")
+ if i == 2:
+ layout = layout.replace("d", "dv")
+ else:
+ layout = layout.replace("d", "dqk")
tensor_shape = [dim_to_num[j] for j in layout.split("_")]
tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
tensor_orig = tensor
if qkv_format == "thd" and pad_between_seqs:
tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
if layout in ["t_h_d", "t_3_h_d", "t_h_3_d"]:
if layout in ["t_h_dqk", "t_3_h_dqk", "t_h_3_dqk"]:
for i in range(1, config.batch_size + 1):
valid_range = (
cu_seqlens_q_after_pad[i - 1],
@@ -772,7 +822,7 @@ def _run_dot_product_attention(
tensor_orig = torch.cat(
[tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
)
if layout in ["tg_hg_d", "tg_2_hg_d", "tg_hg_2_d"]:
if layout in ["tg_hg_dqk", "tg_2_hg_dqk", "tg_hg_2_dqk", "tg_hg_dv"]:
for i in range(1, config.batch_size + 1):
valid_range = (
cu_seqlens_kv_after_pad[i - 1],
@@ -811,13 +861,14 @@ def _run_dot_product_attention(
# Create output gradient
qkv_format_kv = "_".join(qkv_format)
qkv_format_kv = qkv_format_kv.replace("s", "sq")
qkv_format_kv = qkv_format_kv.replace("d", "dv")
out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda")
out_grad_orig = out_grad
if qkv_format == "thd" and pad_between_seqs:
out_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
if qkv_format_kv == "t_h_d":
if qkv_format_kv == "t_h_dv":
for i in range(1, config.batch_size + 1):
valid_range = (
cu_seqlens_q_after_pad[i - 1],
@@ -851,7 +902,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
# Set up model
block = DotProductAttention(
config.num_heads,
- config.head_dim,
+ config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
@@ -906,9 +957,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
if backend == "FusedAttention":
if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
- q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
- k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
- v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
+ if is_training:
+ q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
+ k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
+ v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
for i in range(1, config.batch_size + 1):
valid_range_q = (
cu_seqlens_q_after_pad[i - 1],
@@ -919,15 +971,16 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
)
out_orig = torch.cat([out_orig, out[valid_range_q[0] : valid_range_q[1]]], dim=0)
- q_grad_orig = torch.cat(
- [q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0
- )
- k_grad_orig = torch.cat(
- [k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
- )
- v_grad_orig = torch.cat(
- [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
- )
+ if is_training:
+ q_grad_orig = torch.cat(
+ [q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0
+ )
+ k_grad_orig = torch.cat(
+ [k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
+ )
+ v_grad_orig = torch.cat(
+ [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
+ )
if is_training:
return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig)
else:
@@ -1168,7 +1221,7 @@ def _run_transformer_layer(
# Create RoPE
rotary_pos_emb = None
if RoPE:
- PE = RotaryPositionEmbedding(dim=config.head_dim)
+ PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")

# Set up model
@@ -1183,7 +1236,7 @@ def _run_transformer_layer(
init_method=init_method,
output_layer_init_method=output_layer_init_method,
layer_number=layer_number,
- kv_channels=config.head_dim,
+ kv_channels=config.head_dim_qk,
self_attn_mask_type=config.attn_mask_type,
tp_group=None,
tp_size=1,
@@ -1356,7 +1409,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
mha = MultiheadAttention(
hidden_size=config.hidden_size,
num_attention_heads=config.num_heads,
- kv_channels=config.head_dim,
+ kv_channels=config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
layer_number=1,
@@ -1387,7 +1440,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"skv": config.max_seqlen_kv,
"h": config.num_heads,
"hg": config.num_gqa_groups,
"d": config.head_dim,
"d": config.head_dim_qk,
"t": cu_seqlens_q[-1],
"tg": cu_seqlens_kv[-1],
"3": 3,
@@ -1531,7 +1584,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
with fp8_model_init(enabled=fp8_dpa):
dpa = DotProductAttention(
config.num_heads,
- config.head_dim,
+ config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
sequence_parallel=False,
@@ -1560,7 +1613,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"skv": config.max_seqlen_kv,
"h": config.num_heads,
"hg": config.num_gqa_groups,
"d": config.head_dim,
"d": config.head_dim_qk,
"t": cu_seqlens_q[-1],
"tg": cu_seqlens_kv[-1],
"3": 3,
@@ -1732,7 +1785,7 @@ def _run_custom_mha_fp8(dtype, config, backend):
inp = 0.0001 * torch.randint(
-100,
100,
- (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim),
+ (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim_qk),
dtype=dtype,
device="cuda",
requires_grad=True,
@@ -1743,7 +1796,7 @@ def _run_custom_mha_fp8(dtype, config, backend):

out_grad = 0.01 * torch.randn(
config.batch_size * config.max_seqlen_q,
- config.num_heads * config.head_dim,
+ config.num_heads * config.head_dim_qk,
dtype=dtype,
device="cuda",
)
@@ -1766,7 +1819,7 @@ def _run_custom_mha_fp8(dtype, config, backend):
return (
out.view(config.batch_size, config.max_seqlen_q, -1),
dqkv.view(
- config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim
+ config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk
).contiguous(),
)

@@ -1809,7 +1862,7 @@ def get_dummy_cuda_rng_tracker():

block = DotProductAttention(
config.num_heads,
- config.head_dim,
+ config.head_dim_qk,
attention_dropout=config.dropout_p,
sequence_parallel=False,
tp_size=1,
@@ -2105,7 +2158,7 @@ def __init__(self, config, params_dtype: torch.dtype = torch.float32):
self.p_dropout = config.dropout_p
self.h = config.num_heads
self.hidden_size = config.hidden_size
- self.head_dim = config.head_dim
+ self.head_dim = config.head_dim_qk
self.fast_zero_fill = True
self.mask_type = config.attn_mask_type

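Because the MLA configs give Q/K and V different head sizes, `_run_dot_product_attention` now distinguishes `dqk` from `dv` when it expands a QKV layout string into per-tensor shapes, and MLA models use non-packed layouts such as `sbhd_sbhd_sbhd`, since V cannot share a packed tensor with K when their last dimensions differ. A standalone sketch of that expansion, with illustrative dimension values (this is not the test code itself):

```python
# Illustrative sketch of the layout-string expansion used by the test above.
# Dimension values are examples; the substitution rules mirror the diff.
dim_to_num = {"b": 2, "sq": 2048, "skv": 2048, "h": 24, "hg": 24, "dqk": 128, "dv": 64}

def layout_to_shapes(qkv_layout: str) -> list:
    shapes = []
    for i, layout in enumerate(qkv_layout.split("_")):
        layout = "_".join(layout)                  # e.g. "sbhd" -> "s_b_h_d"
        if i == 0:
            layout = layout.replace("s", "sq")     # query tensor uses max_seqlen_q
        else:
            layout = layout.replace("s", "skv")    # key/value tensors use max_seqlen_kv
            layout = layout.replace("h", "hg")     # and the number of GQA groups
        # only the value tensor (i == 2) gets head_dim_v; Q and K get head_dim_qk
        layout = layout.replace("d", "dv" if i == 2 else "dqk")
        shapes.append([dim_to_num[j] for j in layout.split("_")])
    return shapes

print(layout_to_shapes("sbhd_sbhd_sbhd"))
# [[2048, 2, 24, 128], [2048, 2, 24, 128], [2048, 2, 24, 64]]
```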
2 changes: 1 addition & 1 deletion tests/pytorch/test_onnx_export.py
@@ -1083,7 +1083,7 @@ def test_export_core_attention(

model = te.attention.DotProductAttention(
num_attention_heads=num_attention_heads,
- kv_channels=kv_channels,
+ k_channels=kv_channels,
attention_dropout=0.5,
qkv_format=qkv_format,
attn_mask_type=attn_mask_type,
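The ONNX export test follows the same split: the head size it previously passed as `kv_channels` is now passed as `k_channels`. A hedged usage sketch built only from parameters visible in this diff; whether a separate argument carries a differing value head size is not shown here and is left as an assumption:

```python
import transformer_engine.pytorch as te

# Sketch only: the k_channels keyword comes from the diff above; numeric values are examples.
# For MLA-style shapes the value head size would presumably be passed separately
# (e.g. via a v_channels argument), but that is an assumption, not shown in this diff.
model = te.attention.DotProductAttention(
    num_attention_heads=16,
    k_channels=64,              # query/key head dimension
    attention_dropout=0.5,
    qkv_format="bshd",
    attn_mask_type="no_mask",
)
```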