[PyTorch] Fix get_swa_mask() for padding masks #1281

Merged: 11 commits, Dec 18, 2024
28 changes: 16 additions & 12 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -531,18 +531,22 @@ def test_dpa_bias_shapes(dtype, model_configs, model):

model_configs_swa = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
"swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
"swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"swa_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
"swa_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
"swa_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"swa_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"swa_1_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"swa_2_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"swa_2_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"swa_3_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"swa_4_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"swa_4_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"swa_5_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"swa_5_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"swa_6_0": ModelConfig(
4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"swa_6_1": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
),
}
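
The new `padding` / `padding_causal*` entries drive boolean padding masks through the unfused backend. As a minimal sketch (hypothetical lengths, not part of the test file), the [b, 1, 1, sq] mask that `attn_mask_type="padding"` implies for self-attention can be built as:

import torch

# Hypothetical valid lengths for a config like "swa_4_0" (b=4, sq=skv=2048).
batch_size, max_seqlen_q = 4, 2048
seqlens = torch.tensor([2048, 1536, 1024, 512])
# True marks padded (masked-out) positions, as DotProductAttention expects.
padding_mask = (
    torch.arange(max_seqlen_q).view(1, 1, 1, -1) >= seqlens.view(-1, 1, 1, 1)
)  # dtype torch.bool, shape [4, 1, 1, 2048]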


227 changes: 141 additions & 86 deletions transformer_engine/pytorch/attention.py
@@ -1024,61 +1024,157 @@ def swap_key_value_dict(self, batch_indices):


@torch.no_grad()
def get_swa_mask(
window_size: Tuple[int, int],
def get_full_mask(
max_seqlen_q: int,
max_seqlen_kv: int,
attn_mask_type: str = "no_mask",
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,
window_size: Tuple[int, int] = None,
attention_type: str = "self",
bottom_right_alignment: bool = True,
) -> torch.Tensor:
"""
Convert sliding window `window_size` to an equivalent "`arbitrary`" mask.
For "`causal`" mask type, the sliding window diagonal is aligned to the top left corner,
and for other mask types, the bottom right corner.
Get full attention mask in [..., max_seqlen_q, max_seqlen_kv] shape, based on `attn_mask_type`,
`attention_mask`, and `window_size`. For sliding window attention, the diagonal alignment depends
on both `attn_mask_type` and `bottom_right_alignment`, as detailed below.::

attn_mask_type output shape diagonal alignment
--------------------------------------------------------------------------------------------
no_mask [1, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment
causal [1, 1, max_seqlen_q, max_seqlen_kv] always top left
causal_bottom_right [1, 1, max_seqlen_q, max_seqlen_kv] always bottom right
padding [batch_size, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment
padding_causal [batch_size, 1, max_seqlen_q, max_seqlen_kv] always top left
padding_causal_bottom_right [batch_size, 1, max_seqlen_q, max_seqlen_kv] always bottom right
arbitrary same as attention_mask follow bottom_right_alignment

.. note::

For "padding_causal_bottom_right" mask, or "padding" mask with `bottom_right_alignment` = True, the bottom right
diagonal comes from the bottom right corner of the [actual_seqlens_q[i], actual_seqlens_kv[i]] matrix,
i = 0,...,batch_size-1, not the [max_seqlen_q, max_seqlen_kv] matrix. For example, with max_seqlen_q = 4,
max_seqlen_kv = 4, attn_mask_type = "padding", attention_type = "cross", and attention_mask = (
[[False, False, True, True], [False, False, False, False]],
[[False, False, False, True], [False, True, True, True]]), the returned full attention mask has
[2, 1, 4, 4] shape (the singleton head dimension is omitted in the display below) and is::

[[[False, False, False, True],
[False, False, False, True],
[ True, True, True, True],
[ True, True, True, True]],
[[False, True, True, True],
[False, True, True, True],
[False, True, True, True],
[False, True, True, True]]]

Parameters
----------
window_size: Tuple[int, int]
Sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
`attn_mask_type`.
max_seqlen_q: int
Maximum sequence length for queries.
max_seqlen_kv: int
Maximum sequence length for keys and values.
attn_mask_type: str, default = `no_mask`
Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`",
"`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"}
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
default = `None`
Boolean tensor(s) used to mask out attention softmax input.
Boolean tensor(s) used to mask out attention softmax input. Please see DotProductAttention
for the requirements of `attention_mask` for different `attn_mask_type`s.
window_size: Tuple[int, int], default = `None`
Sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
`attn_mask_type`.
attention_type: str, default = "self"
Attention type, {"self", "cross"}
bottom_right_alignment: bool, default = `True`
Whether to align the diagonal of the sliding window attention to the bottom right (`True`)
or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly
specifies "causal" or "causal_bottom_right".

Returns
----------
attn_mask_type: str
If `window_size` specifies an actual sliding window (i.e. anything other than `None`, (-1, -1) or (-1, 0)), "arbitrary"; otherwise, the same as the input `attn_mask_type`
attention_mask: torch.Tensor
Combined `attention_mask` (input) and sliding window attention mask.
The shape is [max_seqlen_q, max_seqlen_kv] when input `attention_mask` is None;
else, the same shape as input `attention_mask`.
The full attention mask based on `attn_mask_type`, `attention_mask` and `window_size`
actual_seqlens_q: Optional[torch.Tensor]
For padding masks, the actual sequence lengths for queries, in shape [batch_size];
for other masks, `None`.
actual_seqlens_kv: Optional[torch.Tensor]
For padding masks, the actual sequence lengths for keys and values, in shape [batch_size];
for other masks, `None`.
"""
mask = torch.ones(max_seqlen_q, max_seqlen_kv, dtype=torch.bool, device="cuda")
if attn_mask_type in ["causal"]:
left = window_size[0] if window_size[0] != -1 else max_seqlen_q
right = window_size[1] if window_size[1] != -1 else max_seqlen_q
mask_upper = torch.triu(mask, diagonal=-left)
mask_lower = torch.tril(mask_upper, diagonal=right)
else:
left = window_size[0] if window_size[0] != -1 else max_seqlen_kv
right = window_size[1] if window_size[1] != -1 else max_seqlen_kv
mask_upper = torch.triu(mask, diagonal=max_seqlen_kv - max_seqlen_q - left)
mask_lower = torch.tril(mask_upper, diagonal=max_seqlen_kv - max_seqlen_q + right)
attn_mask_type = "arbitrary"
mask = mask_lower.logical_not()
# perform basic checks
change_type = window_size is not None and (
window_size[0] != -1 or window_size[1] not in [-1, 0]
)
if window_size is None:
window_size = (-1, -1)
if "causal" in attn_mask_type:
window_size = (window_size[0], 0)
window_size = (
max_seqlen_kv if window_size[0] == -1 else window_size[0],
max_seqlen_q if window_size[1] == -1 else window_size[1],
)
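    # With the defaults resolved, window_size is now absolute: e.g. for
    # max_seqlen_q = max_seqlen_kv = 2048, "causal" with window_size=None ends up
    # as (2048, 0): look back over the whole sequence, never look ahead.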

# apply padding mask
actual_seqlens_q = None
actual_seqlens_kv = None
if "padding" in attn_mask_type:
if attention_type == "self":
attention_mask = torch.logical_or(
attention_mask.squeeze(1).unsqueeze(3), attention_mask
)
else:
attention_mask = torch.logical_or(
attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1]
)
m = attention_mask.logical_not()
actual_seqlens_q = m[:, 0, :, 0].sum(dim=1)
actual_seqlens_kv = m[:, 0, 0, :].sum(dim=1)
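        # actual_seqlens_* count the unmasked (False) entries per sequence, e.g. a
        # query mask row of [False, False, True, True] yields actual_seqlens_q[i] = 2.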

# apply SWA mask
mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
1, 1, max_seqlen_q, 1
) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv)
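    # mask[0, 0, i, j] now holds the relative offset i - j of query i vs. key j.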
swa_left = None
swa_right = None
if attn_mask_type == "causal_bottom_right" or (
attn_mask_type in ["no_mask", "arbitrary"] and bottom_right_alignment
):
swa_left = mask + max_seqlen_kv - max_seqlen_q - window_size[0]
swa_right = mask + max_seqlen_kv - max_seqlen_q + window_size[1]
elif attn_mask_type in ["causal", "padding_causal"] or (
attn_mask_type in ["no_mask", "padding", "arbitrary"] and not bottom_right_alignment
):
swa_left = mask - window_size[0]
swa_right = mask + window_size[1]
elif attn_mask_type == "padding_causal_bottom_right" or (
attn_mask_type == "padding" and bottom_right_alignment
):
batch_size = attention_mask.shape[0]
swa_left = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
actual_seqlens_kv - actual_seqlens_q - window_size[0]
).view(batch_size, 1, 1, 1)
swa_right = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
actual_seqlens_kv - actual_seqlens_q + window_size[1]
).view(batch_size, 1, 1, 1)
swa_mask = torch.logical_not(
torch.where(swa_left <= 0, 1, 0) - torch.where(swa_right < 0, 1, 0)
)
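    # swa_mask is True exactly where i - j falls outside [-window_size[1], window_size[0]]
    # (after diagonal alignment), i.e. True = masked out, same convention as attention_mask.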
if attention_mask is not None:
mask = torch.logical_and(attention_mask, mask)
return attn_mask_type, mask
attention_mask = torch.logical_or(swa_mask, attention_mask)
else:
attention_mask = swa_mask

# change mask type
if change_type:
attn_mask_type = "arbitrary"

return attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv
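
As a quick sanity check, the docstring example above can be reproduced with the new helper. A minimal sketch, assuming a CUDA device (the helper builds its index grids on "cuda") and that `get_full_mask` is importable from `transformer_engine.pytorch.attention`:

import torch
from transformer_engine.pytorch.attention import get_full_mask

# Cross-attention padding masks from the docstring example, in [b, 1, 1, s] shape.
mask_q = torch.tensor(
    [[False, False, True, True], [False, False, False, False]], device="cuda"
).view(2, 1, 1, 4)
mask_kv = torch.tensor(
    [[False, False, False, True], [False, True, True, True]], device="cuda"
).view(2, 1, 1, 4)

mask_type, full_mask, seqlens_q, seqlens_kv = get_full_mask(
    4,
    4,
    attn_mask_type="padding",
    attention_mask=(mask_q, mask_kv),
    attention_type="cross",
)
print(mask_type)            # "padding" (no real sliding window, so the type is unchanged)
print(full_mask.shape)      # torch.Size([2, 1, 4, 4])
print(seqlens_q.tolist())   # [2, 4]
print(seqlens_kv.tolist())  # [3, 1]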


@torch.no_grad()
@@ -4733,6 +4829,7 @@ def forward(
cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
attn_mask_type: str = "causal",
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
window_size: Optional[Tuple[int, int]] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
@@ -4752,53 +4849,15 @@ def forward(
query_layer.shape[0],
key_layer.shape[0],
)
if "padding" in attn_mask_type:
if self.attention_type == "self":
assert attention_mask.shape == (
batch_size,
1,
1,
max_seqlen_q,
), "attention_mask should be a single tensor with [b, 1, 1, sq] shape!"
attention_mask = torch.logical_or(
attention_mask.squeeze(1).unsqueeze(3), attention_mask
)
else:
assert (
len(attention_mask) == 2
and attention_mask[0].shape == (batch_size, 1, 1, max_seqlen_q)
and attention_mask[1].shape == (batch_size, 1, 1, max_seqlen_kv)
), (
"attention_mask should be a tuple of two tensors with shapes "
"[b, 1, 1, sq] and [b, 1, 1, skv]!"
)
attention_mask = torch.logical_or(
attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1]
)
mask = attention_mask.squeeze(1).logical_not()
actual_seqlens_q = mask[:, :, 0].sum(dim=1)
actual_seqlens_kv = mask[:, 0, :].sum(dim=1)
mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
1, 1, max_seqlen_q, 1
) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(
1, 1, 1, max_seqlen_kv
)
if attn_mask_type == "padding_causal":
attention_mask = torch.logical_or(
torch.where(mask.view(1, 1, max_seqlen_q, max_seqlen_kv) < 0, 1, 0),
attention_mask,
)
if attn_mask_type == "padding_causal_bottom_right":
attention_mask = torch.logical_or(
torch.where(
mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv)
+ (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1)
< 0,
1,
0,
),
attention_mask,
)

attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_full_mask(
max_seqlen_q,
max_seqlen_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
attention_type=self.attention_type,
)
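        # attention_mask now folds padding and sliding-window constraints into one
        # boolean mask (True = masked out), with per-sequence lengths returned alongside.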

batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16
@@ -8274,12 +8333,6 @@ def forward(
)

if use_unfused_attention:
if window_size is not None and (
window_size[0] != -1 or window_size[1] not in [-1, 0]
):
attn_mask_type, attention_mask = get_swa_mask(
window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask
)
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
self.unfused_attention,
Expand All @@ -8291,6 +8344,7 @@ def forward(
cu_seqlens_kv=cu_seqlens_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
Expand All @@ -8304,6 +8358,7 @@ def forward(
cu_seqlens_kv=cu_seqlens_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,