Commit 38460c3

revamp to get full mask
Signed-off-by: Charlene Yang <[email protected]>
1 parent 3ed8322

File tree

1 file changed: +64 -36 lines changed


transformer_engine/pytorch/attention.py

Lines changed: 64 additions & 36 deletions
@@ -1024,59 +1024,82 @@ def swap_key_value_dict(self, batch_indices):
 
 
 @torch.no_grad()
-def get_swa_mask(
-    window_size: Tuple[int, int],
+def get_full_mask(
     max_seqlen_q: int,
     max_seqlen_kv: int,
     attn_mask_type: str = "no_mask",
-    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
+    attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,
+    window_size: Tuple[int, int] = None,
+    attention_type: str = "self",
+    bottom_right_alignment: bool = True,
 ) -> torch.Tensor:
     """
-    Convert sliding window `window_size` to an equivalent "`arbitrary`" mask. Requirements for the
-    shapes of `attention_mask` given an `attn_mask_type` are the same as in DotProductAttention.
-    For "`causal`" and "`padding_causal`" mask types, the sliding window diagonal is aligned to the
-    top left corner of the softmax matrix; for others, the bottom right corner. Note that when padding
-    is applied, the bottom right corner comes from the [actual_seqlen_q[i], actual_seqlen_kv[i]] matrix,
-    for each batch i, not the [max_seqlen_q, max_seqlen_kv] matrix.::
+    Get full attention mask in [..., max_seqlen_q, max_seqlen_kv] shape, based on `attn_mask_type`,
+    `attention_mask`, and `window_size`. For sliding window attention, the diagonal alignment depends
+    on both `attn_mask_type` and `bottom_right_alignment`, as detailed below.::
 
         attn_mask_type              output shape                                       diagonal alignment
         --------------------------------------------------------------------------------------------
-        no_mask                     [1, 1, max_seqlen_q, max_seqlen_kv]                bottom right
-        causal                      [1, 1, max_seqlen_q, max_seqlen_kv]                top left
-        causal_bottom_right         [1, 1, max_seqlen_q, max_seqlen_kv]                bottom right
-        padding                     [batch_size, 1, max_seqlen_q, max_seqlen_kv]       bottom right, based on
-                                                                                       actual sequence lengths
-        padding_causal              [batch_size, 1, max_seqlen_q, max_seqlen_kv]       top left
-        padding_causal_bottom_right [batch_size, 1, max_seqlen_q, max_seqlen_kv]       bottom right, based on
-                                                                                       actual sequence lengths
-        arbitrary                   same as attention_mask                             bottom right
+        no_mask                     [1, 1, max_seqlen_q, max_seqlen_kv]                follow bottom_right_alignment
+        causal                      [1, 1, max_seqlen_q, max_seqlen_kv]                always top left
+        causal_bottom_right         [1, 1, max_seqlen_q, max_seqlen_kv]                always bottom right
+        padding                     [batch_size, 1, max_seqlen_q, max_seqlen_kv]       follow bottom_right_alignment
+        padding_causal              [batch_size, 1, max_seqlen_q, max_seqlen_kv]       always top left
+        padding_causal_bottom_right [batch_size, 1, max_seqlen_q, max_seqlen_kv]       always bottom right
+        arbitrary                   same as attention_mask                             follow bottom_right_alignment
+
+    .. note::
+
+        For "padding_bottom_right" mask, or "padding" mask with `bottom_right_alignment` = True, the bottom right
+        diagonal comes from the bottom right corner of the [actual_seqlens_q[i], actual_seqlens_kv[i]] matrix,
+        i = 0,...,batch_size-1, not the [max_seqlen_q, max_seqlen_kv] matrix. For example, with max_seqlen_q = 4,
+        max_seqlen_kv = 4, attn_mask_type = "padding", attention_type = "cross", and attention_mask = (
+        [[False, False, True, True], [False, False, False, False]],
+        [[False, False, False, True], [False, True, True, True]]), the returned full attention mask has [2, 4, 4]
+        shape and is,::
+
+            [[[False, False, False, True],
+              [False, False, False, True],
+              [ True,  True,  True,  True],
+              [ True,  True,  True,  True]],
+             [[False, True, True, True],
+              [False, True, True, True],
+              [False, True, True, True],
+              [False, True, True, True]]]
 
     Parameters
     ----------
-    window_size: Tuple[int, int]
-                Sliding window size for local attention, where query at position i attends to keys
-                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
-                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
-                window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
-                map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
-                `attn_mask_type`.
     max_seqlen_q: int
                 Maximum sequence length for queries.
     max_seqlen_kv: int
                 Maximum sequence length for keys and values.
     attn_mask_type: str, default = `no_mask`
                 Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`",
                 "`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"}
-    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
+    attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                 default = `None`
-                Boolean tensor(s) used to mask out attention softmax input.
+                Boolean tensor(s) used to mask out attention softmax input. Please see DotProductAttention
+                for the requirements of `attention_mask` for different `attn_mask_type`s.
+    window_size: Tuple[int, int], default = `None`
+                Sliding window size for local attention, where query at position i attends to keys
+                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
+                window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
+                map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
+                `attn_mask_type`.
+    attention_type: str, default = "self"
+                Attention type, {"self", "cross"}
+    bottom_right_alignment: bool, default = `True`
+                Whether to align the diagonal of the sliding window attention to the bottom right (`True`)
+                or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly
+                specifies "causal" or "causal_bottom_right".
 
     Returns
     ----------
-    attn_mask_type: str, default = `no_mask`
-                New attention mask type "arbitrary".
+    attn_mask_type: str
+                For sliding window attention (>=0, >0), "arbitrary"; otherwise, the same as input `attn_mask_type`
     attention_mask: torch.Tensor
-                Result after combining input mask and sliding window mask.
+                The full attention mask based on `attn_mask_type`, `attention_mask` and `window_size`
     actual_seqlens_q: torch.Tensor
                 For padding masks, the actual sequence lengths for queries, in shape [batch_size].
                 For other masks, `None`.
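
For reference, the combined [2, 4, 4] padding mask in the docstring's note can be reproduced with plain PyTorch broadcasting. This is a minimal sketch, not code added by this commit; the names q_mask and kv_mask are illustrative:

    import torch

    # Per-query and per-key padding masks from the docstring example (True = masked out).
    q_mask = torch.tensor([[False, False, True, True],
                           [False, False, False, False]])   # [batch_size, max_seqlen_q]
    kv_mask = torch.tensor([[False, False, False, True],
                            [False, True, True, True]])     # [batch_size, max_seqlen_kv]

    # Broadcast q_mask along the key/value dimension and kv_mask along the query
    # dimension, then OR them into a [batch_size, 1, max_seqlen_q, max_seqlen_kv] mask.
    full_mask = torch.logical_or(q_mask[:, None, :, None], kv_mask[:, None, None, :])
    print(full_mask.shape)   # torch.Size([2, 1, 4, 4])
    print(full_mask[:, 0])   # matches the [2, 4, 4] example in the note above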
@@ -1101,7 +1124,7 @@ def get_swa_mask(
     actual_seqlens_q = None
     actual_seqlens_kv = None
     if "padding" in attn_mask_type:
-        if max_seqlen_q == max_seqlen_kv:
+        if attention_type == "self":
             attention_mask = torch.logical_or(
                 attention_mask.squeeze(1).unsqueeze(3), attention_mask
             )
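
The self-attention branch in this hunk ORs a single [batch_size, 1, 1, max_seqlen] padding mask with its query-dimension transpose. A small sketch of that shape manipulation, assuming the mask arrives in that layout (True = masked out); the tensor values here are made up for illustration:

    import torch

    attention_mask = torch.tensor([[[[False, False, True, True]]],
                                   [[[False, False, False, True]]]])   # [2, 1, 1, 4]

    full_mask = torch.logical_or(
        attention_mask.squeeze(1).unsqueeze(3),  # [2, 1, 4, 1]: varies along the query dim
        attention_mask,                          # [2, 1, 1, 4]: varies along the key dim
    )
    print(full_mask.shape)  # torch.Size([2, 1, 4, 4])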
@@ -1119,13 +1142,16 @@ def get_swa_mask(
     ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv)
     swa_left = None
     swa_right = None
-    if attn_mask_type in ["no_mask", "causal_bottom_right", "arbitrary"]:
+    if attn_mask_type == "causal_bottom_right" or (
+        attn_mask_type in ["no_mask", "arbitrary"] and bottom_right_alignment):
         swa_left = mask + max_seqlen_kv - max_seqlen_q - window_size[0]
         swa_right = mask + max_seqlen_kv - max_seqlen_q + window_size[1]
-    elif attn_mask_type in ["causal", "padding_causal"]:
+    elif attn_mask_type in ["causal", "padding_causal"] or (
+        attn_mask_type in ["no_mask", "padding", "arbitrary"] and not bottom_right_alignment):
         swa_left = mask - window_size[0]
         swa_right = mask + window_size[1]
-    elif attn_mask_type in ["padding", "padding_causal_bottom_right"]:
+    elif attn_mask_type == "padding_causal_bottom_right" or (
+        attn_mask_type == "padding" and bottom_right_alignment):
         batch_size = attention_mask.shape[0]
         swa_left = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
             actual_seqlens_kv - actual_seqlens_q - window_size[0]
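
The swa_left/swa_right offsets above bound the band of keys each query may attend to. As a rough CPU illustration of the top-left-aligned arithmetic, simplified to 2-D and assuming window_size = (2, 0), i.e. each query sees itself and the two previous keys (the committed code works on [1, 1, max_seqlen_q, max_seqlen_kv] CUDA tensors):

    import torch

    max_seqlen_q, max_seqlen_kv, window_size = 4, 4, (2, 0)
    # mask[i, j] = i - j, the offset of key j relative to query i
    mask = torch.arange(max_seqlen_q).view(max_seqlen_q, 1) \
         - torch.arange(max_seqlen_kv).view(1, max_seqlen_kv)
    swa_left = mask - window_size[0]
    swa_right = mask + window_size[1]
    # A position is inside the window when i - window_size[0] <= j <= i + window_size[1],
    # i.e. swa_left <= 0 and swa_right >= 0; everything else is masked out (True).
    swa_mask = torch.logical_or(swa_left > 0, swa_right < 0)
    print(swa_mask.int())
    # tensor([[0, 1, 1, 1],
    #         [0, 0, 1, 1],
    #         [0, 0, 0, 1],
    #         [1, 0, 0, 0]])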
@@ -4821,8 +4847,10 @@ def forward(
             key_layer.shape[0],
         )
 
-        attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_swa_mask(
-            window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask
+        attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_full_mask(
+            max_seqlen_q, max_seqlen_kv, attn_mask_type=attn_mask_type,
+            attention_mask=attention_mask, window_size=window_size,
+            attention_type=self.attention_type,
         )
 
         batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
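
A hypothetical standalone call mirroring the updated forward() code, assuming Transformer Engine with this commit is installed and a CUDA device is available; the tensors reuse the padding example from the docstring note, and the mask layout follows the DotProductAttention padding-mask convention:

    import torch
    from transformer_engine.pytorch.attention import get_full_mask

    q_mask = torch.tensor([[False, False, True, True],
                           [False, False, False, False]], device="cuda")
    kv_mask = torch.tensor([[False, False, False, True],
                            [False, True, True, True]], device="cuda")

    # Padding masks for cross attention are passed as a (q, kv) tuple of
    # [batch_size, 1, 1, max_seqlen] tensors (see DotProductAttention).
    attn_mask_type, full_mask, actual_seqlens_q, actual_seqlens_kv = get_full_mask(
        max_seqlen_q=4,
        max_seqlen_kv=4,
        attn_mask_type="padding",
        attention_mask=(q_mask[:, None, None, :], kv_mask[:, None, None, :]),
        window_size=(-1, -1),          # no sliding window
        attention_type="cross",
    )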
