
Commit b90b638

erhoo82, timmoon10, and ksivaman authored
Provide pre-computed max sequence to remove unnecessary kernels and D2H copies (#555)
* Provide pre-computed max sequence to remove unnecessary kernels and D2H copies

Signed-off-by: Sangkug Lym <[email protected]>

* Tweak comments

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Sangkug Lym <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
1 parent cd798c9 commit b90b638
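The optimization is easiest to see outside the diff. In the thd (variable-sequence-length) path, the maximum sequence lengths were previously derived from cu_seqlens_q/cu_seqlens_kv on every forward pass: a subtraction kernel, a max-reduction kernel, and a .item() call that forces a device-to-host copy and synchronization. Callers that already know these maxima (for example, from how the batch was packed) can now pass them as plain Python ints. A minimal sketch of the two paths; the tensor contents and device below are illustrative assumptions, not taken from the library:

import torch

# Cumulative sequence lengths for a packed batch of three sequences.
cu_seqlens_q = torch.tensor([0, 128, 384, 512], dtype=torch.int32, device="cuda")

# Fallback path (still used when max_seqlen_q is None): two small kernels
# plus a blocking device-to-host copy from .item().
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
max_seqlen_q = seqlens_q.max().item()

# Pre-computed path: the data pipeline typically knows this number already,
# so it can be passed to the attention module as a host-side int and the
# kernels and copy above are skipped.
max_seqlen_q = 256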

File tree

1 file changed: +35 / -13 lines changed


transformer_engine/pytorch/attention.py

Lines changed: 35 additions & 13 deletions
@@ -1304,6 +1304,8 @@ def forward(
         qkv_layout: str = "sbh3d",
         cu_seqlens_q: Optional[torch.Tensor] = None,
         cu_seqlens_kv: Optional[torch.Tensor] = None,
+        max_seqlen_q: Optional[int] = None,
+        max_seqlen_kv: Optional[int] = None,
         attn_mask_type: str = "causal",
         window_size: Optional[Tuple[int, int]] = None,
         cp_group: Optional[dist_group_type] = None,
@@ -1346,10 +1348,10 @@ def forward(
                 for x in (query_layer, key_layer, value_layer)]
 
         global _cu_seqlens_q, _cu_seqlens_kv, _indices_q, _indices_kv
-        batch_size, max_seqlen_q, max_seqlen_kv = (
-            query_layer.shape[0], query_layer.shape[1], key_layer.shape[1])
+        batch_size = query_layer.shape[0]
 
         if qkv_format in ['sbhd', 'bshd']:
+            max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
             if not context_parallel:
                 # [b * s, h, d]
                 query_layer, key_layer, value_layer = [
@@ -1422,10 +1424,12 @@ def forward(
             ), "flash-attn v2 is required for variable sequence length support!"
             assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
             ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
-            seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
-            seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
-            max_seqlen_q = seqlens_q.max().item()
-            max_seqlen_kv = seqlens_kv.max().item()
+            if max_seqlen_q is None:
+                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
+                max_seqlen_q = seqlens_q.max().item()
+            if max_seqlen_kv is None:
+                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
+                max_seqlen_kv = seqlens_kv.max().item()
 
         if context_parallel:
             assert (
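The old computation is kept as a fallback and now only runs when the caller does not supply the value. A caller that wants to supply it can do the equivalent reduction once, on host-side (CPU) cu_seqlens tensors, when the batch is assembled; the helper below is a hypothetical sketch of that, not part of the library:

import torch
from typing import Tuple

def precompute_max_seqlens(cu_seqlens_q: torch.Tensor,
                           cu_seqlens_kv: torch.Tensor) -> Tuple[int, int]:
    # Mirrors the fallback in the hunk above, but is intended to run once on
    # CPU tensors at batch-construction time so the forward pass receives ints.
    max_seqlen_q = int((cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max())
    max_seqlen_kv = int((cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]).max())
    return max_seqlen_q, max_seqlen_kv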
@@ -1754,6 +1758,8 @@ def forward(
         qkv_layout: str = "sbh3d",
         cu_seqlens_q: Optional[torch.Tensor] = None,
         cu_seqlens_kv: Optional[torch.Tensor] = None,
+        max_seqlen_q: Optional[int] = None,
+        max_seqlen_kv: Optional[int] = None,
         attn_mask_type: str = "causal",
         attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
         fused_attention_backend:
@@ -2104,6 +2110,8 @@ def forward(
         qkv_format: Optional[str] = None,
         cu_seqlens_q: Optional[torch.Tensor] = None,
         cu_seqlens_kv: Optional[torch.Tensor] = None,
+        max_seqlen_q: Optional[int] = None,
+        max_seqlen_kv: Optional[int] = None,
         attn_mask_type: Optional[str] = None,
         window_size: Optional[Tuple[int, int]] = None,
         checkpoint_core_attention: bool = False,
@@ -2176,6 +2184,12 @@ def forward(
         cu_seqlens_kv: Optional[torch.Tensor], default = `None`
             Cumulative sum of sequence lengths in a batch for `key_layer` and `value_layer`,
             with shape [batch_size + 1] and dtype torch.int32.
+        max_seqlen_q: Optional[int], default = `None`
+            Maximum sequence length in `query_layer`.
+            Calculated from `cu_seqlens_q` if not provided.
+        max_seqlen_kv: Optional[int], default = `None`
+            Maximum sequence length in `key_layer` and `value_layer`.
+            Calculated from `cu_seqlens_kv` if not provided.
         attn_mask_type: {`no_mask`, `padding`, `causal`, `padding,causal`, `causal,padding`,
             `arbitrary`}, default = `None`. Type of attention mask passed into
             softmax operation. 'padding,causal' and 'causal,padding' are equivalent.
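Based on the docstring above, a caller with packed thd inputs can pass the pre-computed maxima directly. A rough usage sketch, assuming a CUDA device with flash-attn v2 available and illustrative sizes; the constructor arguments follow the usual DotProductAttention(num_attention_heads, kv_channels, ...) form, but check the exact signature in your version:

import torch
from transformer_engine.pytorch import DotProductAttention

heads, head_dim = 16, 64
attn = DotProductAttention(heads, head_dim, attn_mask_type="padding")

# Packed (thd) batch of three sequences with lengths 128, 256, 128.
cu_seqlens = torch.tensor([0, 128, 384, 512], dtype=torch.int32, device="cuda")
q = torch.randn(512, heads, head_dim, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

out = attn(q, k, v,
           qkv_format="thd",
           cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens,
           max_seqlen_q=256, max_seqlen_kv=256)  # pre-computed: no extra kernels or D2H copy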
@@ -2238,10 +2252,12 @@ def forward(
             assert (cu_seqlens_q.dtype == torch.int32
                     and cu_seqlens_kv.dtype == torch.int32
                     ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
-            seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
-            seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
-            max_seqlen_q = seqlens_q.max().item()
-            max_seqlen_kv = seqlens_kv.max().item()
+            if max_seqlen_q is None:
+                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
+                max_seqlen_q = seqlens_q.max().item()
+            if max_seqlen_kv is None:
+                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
+                max_seqlen_kv = seqlens_kv.max().item()
 
         if qkv_format in ['sbhd', 'bshd']:
             assert (all(len(x.shape) == 4 for x in (query_layer, key_layer, value_layer))
@@ -2405,7 +2421,9 @@ def forward(
                         window_size=window_size,
                         cp_group=self.cp_group,
                         cp_global_ranks=self.cp_global_ranks,
-                        cp_stream=self.cp_stream)
+                        cp_stream=self.cp_stream,
+                        max_seqlen_q=max_seqlen_q,
+                        max_seqlen_kv=max_seqlen_kv)
 
         assert (
             self.cp_group is None or get_distributed_world_size(self.cp_group) == 1
@@ -2428,7 +2446,9 @@ def forward(
                     fused_attention_backend = fused_attention_backend,
                     core_attention_bias_type = core_attention_bias_type,
                     core_attention_bias = core_attention_bias,
-                    fast_zero_fill = fast_zero_fill)
+                    fast_zero_fill = fast_zero_fill,
+                    max_seqlen_q=max_seqlen_q,
+                    max_seqlen_kv=max_seqlen_kv)
                 return self.fused_attention(query_layer, key_layer, value_layer,
                     qkv_layout = qkv_layout,
                     cu_seqlens_q = cu_seqlens_q,
@@ -2438,7 +2458,9 @@ def forward(
                     fused_attention_backend = fused_attention_backend,
                     core_attention_bias_type = core_attention_bias_type,
                     core_attention_bias = core_attention_bias,
-                    fast_zero_fill = fast_zero_fill)
+                    fast_zero_fill = fast_zero_fill,
+                    max_seqlen_q=max_seqlen_q,
+                    max_seqlen_kv=max_seqlen_kv)
 
             if _NVTE_DEBUG:
                 print("[DotProductAttention]: using unfused DPA")
