@@ -1024,59 +1024,82 @@ def swap_key_value_dict(self, batch_indices):
 
 
 @torch.no_grad()
-def get_swa_mask(
-    window_size: Tuple[int, int],
+def get_full_mask(
     max_seqlen_q: int,
     max_seqlen_kv: int,
     attn_mask_type: str = "no_mask",
-    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
+    attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,
+    window_size: Tuple[int, int] = None,
+    attention_type: str = "self",
+    bottom_right_alignment: bool = True,
 ) -> torch.Tensor:
     """
-    Convert sliding window `window_size` to an equivalent "`arbitrary`" mask. Requirements for the
-    shapes of `attention_mask` given an `attn_mask_type` are the same as in DotProductAttention.
-    For "`causal`" and "`padding_causal`" mask types, the sliding window diagonal is aligned to the
-    top left corner of the softmax matrix; for others, the bottom right corner. Note that when padding
-    is applied, the bottom right corner comes from the [actual_seqlen_q[i], actual_seqlen_kv[i]] matrix,
-    for each batch i, not the [max_seqlen_q, max_seqlen_kv] matrix.::
+    Get full attention mask in [..., max_seqlen_q, max_seqlen_kv] shape, based on `attn_mask_type`,
+    `attention_mask`, and `window_size`. For sliding window attention, the diagonal alignment depends
+    on both `attn_mask_type` and `bottom_right_alignment`, as detailed below.::
 
        attn_mask_type                 output shape                                  diagonal alignment
        --------------------------------------------------------------------------------------------
-       no_mask                        [1, 1, max_seqlen_q, max_seqlen_kv]           bottom right
-       causal                         [1, 1, max_seqlen_q, max_seqlen_kv]           top left
-       causal_bottom_right            [1, 1, max_seqlen_q, max_seqlen_kv]           bottom right
-       padding                        [batch_size, 1, max_seqlen_q, max_seqlen_kv]  bottom right, based on
-                                                                                    actual sequence lengths
-       padding_causal                 [batch_size, 1, max_seqlen_q, max_seqlen_kv]  top left
-       padding_causal_bottom_right    [batch_size, 1, max_seqlen_q, max_seqlen_kv]  bottom right, based on
-                                                                                    actual sequence lengths
-       arbitrary                      same as attention_mask                        bottom right
+       no_mask                        [1, 1, max_seqlen_q, max_seqlen_kv]           follow bottom_right_alignment
+       causal                         [1, 1, max_seqlen_q, max_seqlen_kv]           always top left
+       causal_bottom_right            [1, 1, max_seqlen_q, max_seqlen_kv]           always bottom right
+       padding                        [batch_size, 1, max_seqlen_q, max_seqlen_kv]  follow bottom_right_alignment
+       padding_causal                 [batch_size, 1, max_seqlen_q, max_seqlen_kv]  always top left
+       padding_causal_bottom_right    [batch_size, 1, max_seqlen_q, max_seqlen_kv]  always bottom right
+       arbitrary                      same as attention_mask                        follow bottom_right_alignment
+
+    .. note::
+
+        For "padding_causal_bottom_right" mask, or "padding" mask with `bottom_right_alignment` = True, the bottom
+        right diagonal comes from the bottom right corner of the [actual_seqlens_q[i], actual_seqlens_kv[i]] matrix,
+        i = 0,...,batch_size-1, not the [max_seqlen_q, max_seqlen_kv] matrix. For example, with max_seqlen_q = 4,
+        max_seqlen_kv = 4, attn_mask_type = "padding", attention_type = "cross", and attention_mask = (
+        [[False, False, True, True], [False, False, False, False]],
+        [[False, False, False, True], [False, True, True, True]]), the returned full attention mask has shape
+        [2, 4, 4] and is::
+
+            [[[False, False, False, True],
+              [False, False, False, True],
+              [ True,  True,  True, True],
+              [ True,  True,  True, True]],
+             [[False, True, True, True],
+              [False, True, True, True],
+              [False, True, True, True],
+              [False, True, True, True]]]
 
     Parameters
     ----------
-    window_size: Tuple[int, int]
-        Sliding window size for local attention, where query at position i attends to keys
-        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
-        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
-        window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
-        map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
-        `attn_mask_type`.
     max_seqlen_q: int
         Maximum sequence length for queries.
     max_seqlen_kv: int
         Maximum sequence length for keys and values.
     attn_mask_type: str, default = `no_mask`
         Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`",
         "`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"}
-    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
+    attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        default = `None`
-        Boolean tensor(s) used to mask out attention softmax input.
+        Boolean tensor(s) used to mask out attention softmax input. Please see DotProductAttention
+        for the requirements of `attention_mask` for different `attn_mask_type`s.
+    window_size: Tuple[int, int], default = `None`
+        Sliding window size for local attention, where query at position i attends to keys
+        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
+        window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
+        map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
+        `attn_mask_type`.
+    attention_type: str, default = "self"
+        Attention type, {"self", "cross"}
+    bottom_right_alignment: bool, default = `True`
+        Whether to align the diagonal of the sliding window attention to the bottom right (`True`)
+        or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly
+        specifies "causal" or "causal_bottom_right".
 
     Returns
     ----------
-    attn_mask_type: str, default = `no_mask`
-        New attention mask type "arbitrary".
+    attn_mask_type: str
+        For sliding window attention (>=0, >0), "arbitrary"; otherwise, the same as input `attn_mask_type`.
     attention_mask: torch.Tensor
-        Result after combining input mask and sliding window mask.
+        The full attention mask based on `attn_mask_type`, `attention_mask`, and `window_size`.
     actual_seqlens_q: torch.Tensor
         For padding masks, the actual sequence lengths for queries, in shape [batch_size].
         For other masks, `None`.
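As a sanity check on the note above, the cross-attention padding example can be reproduced with a few lines of plain PyTorch. This is a standalone sketch, not the library call itself: the names `mask_q` and `mask_kv` are illustrative, and it only demonstrates the logical-OR broadcast of per-query and per-key padding vectors, the same pattern the next hunk applies for self attention.

```python
import torch

# Per-query and per-key padding masks from the docstring example (True = masked out).
mask_q = torch.tensor([[False, False, True, True],
                       [False, False, False, False]])
mask_kv = torch.tensor([[False, False, False, True],
                        [False, True, True, True]])

# Broadcast [batch, q, 1] OR [batch, 1, kv] -> [batch, q, kv]: a position is masked
# if either its query row or its key column is padding.
full_mask = torch.logical_or(mask_q.unsqueeze(2), mask_kv.unsqueeze(1))

print(full_mask.shape)  # torch.Size([2, 4, 4]), matching the [2, 4, 4] mask shown in the note
```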
@@ -1101,7 +1124,7 @@ def get_swa_mask(
     actual_seqlens_q = None
     actual_seqlens_kv = None
     if "padding" in attn_mask_type:
-        if max_seqlen_q == max_seqlen_kv:
+        if attention_type == "self":
             attention_mask = torch.logical_or(
                 attention_mask.squeeze(1).unsqueeze(3), attention_mask
             )
@@ -1119,13 +1142,16 @@ def get_swa_mask(
     ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv)
     swa_left = None
     swa_right = None
-    if attn_mask_type in ["no_mask", "causal_bottom_right", "arbitrary"]:
+    if attn_mask_type == "causal_bottom_right" or (
+        attn_mask_type in ["no_mask", "arbitrary"] and bottom_right_alignment):
         swa_left = mask + max_seqlen_kv - max_seqlen_q - window_size[0]
         swa_right = mask + max_seqlen_kv - max_seqlen_q + window_size[1]
-    elif attn_mask_type in ["causal", "padding_causal"]:
+    elif attn_mask_type in ["causal", "padding_causal"] or (
+        attn_mask_type in ["no_mask", "padding", "arbitrary"] and not bottom_right_alignment):
         swa_left = mask - window_size[0]
         swa_right = mask + window_size[1]
-    elif attn_mask_type in ["padding", "padding_causal_bottom_right"]:
+    elif attn_mask_type == "padding_causal_bottom_right" or (
+        attn_mask_type == "padding" and bottom_right_alignment):
         batch_size = attention_mask.shape[0]
         swa_left = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
             actual_seqlens_kv - actual_seqlens_q - window_size[0]
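The `swa_left`/`swa_right` arithmetic above describes a diagonal band in the softmax matrix. Below is a minimal CPU-only sketch of the same band test for bottom-right alignment, using the window convention from the docstring (query i attends to keys in [i + seqlen_kv - seqlen_q - window_size[0], i + seqlen_kv - seqlen_q + window_size[1]]). The final comparison is illustrative; the code that actually consumes `swa_left`/`swa_right` lies outside this hunk.

```python
import torch

max_seqlen_q, max_seqlen_kv = 4, 6
window_size = (2, 0)  # look 2 keys back, 0 keys ahead

q_pos = torch.arange(max_seqlen_q).view(max_seqlen_q, 1)
kv_pos = torch.arange(max_seqlen_kv).view(1, max_seqlen_kv)

# Bottom-right alignment shifts the diagonal by (max_seqlen_kv - max_seqlen_q),
# so the last query lines up with the last key.
shift = max_seqlen_kv - max_seqlen_q
visible = (kv_pos >= q_pos + shift - window_size[0]) & (kv_pos <= q_pos + shift + window_size[1])
swa_mask = ~visible  # True = masked out, matching the boolean-mask convention above

print(swa_mask.int())
```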
@@ -4821,8 +4847,10 @@ def forward(
                 key_layer.shape[0],
             )
 
-            attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_swa_mask(
-                window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask
+            attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_full_mask(
+                max_seqlen_q, max_seqlen_kv, attn_mask_type=attn_mask_type,
+                attention_mask=attention_mask, window_size=window_size,
+                attention_type=self.attention_type,
             )
 
         batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
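For reference, here is a hedged sketch of calling the renamed helper with the keyword arguments used at the call site above. The import path is an assumption (the diff only shows attention.py internals, where the helper is called directly from `DotProductAttention.forward`), the helper builds its index tensors on "cuda" so a GPU is assumed, and the comments restate the docstring rather than asserting exact outputs.

```python
import torch
# Assumed import path; get_full_mask is defined at module level in the diff above.
from transformer_engine.pytorch.attention import get_full_mask

attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_full_mask(
    8,                            # max_seqlen_q
    8,                            # max_seqlen_kv
    attn_mask_type="causal",
    attention_mask=None,
    window_size=(3, 0),           # finite sliding window: 3 keys back, none ahead
    attention_type="self",
    bottom_right_alignment=True,  # ignored for an explicit "causal" mask
)

# Per the docstring: a [1, 1, 8, 8] boolean mask, with actual_seqlens_* None for non-padding masks.
print(attn_mask_type, attention_mask.shape, actual_seqlens_q, actual_seqlens_kv)
```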