@@ -1304,6 +1304,8 @@ def forward(
         qkv_layout: str = "sbh3d",
         cu_seqlens_q: Optional[torch.Tensor] = None,
         cu_seqlens_kv: Optional[torch.Tensor] = None,
+        max_seqlen_q: Optional[int] = None,
+        max_seqlen_kv: Optional[int] = None,
         attn_mask_type: str = "causal",
         window_size: Optional[Tuple[int, int]] = None,
         cp_group: Optional[dist_group_type] = None,
@@ -1346,10 +1348,10 @@ def forward(
                 for x in (query_layer, key_layer, value_layer)]

         global _cu_seqlens_q, _cu_seqlens_kv, _indices_q, _indices_kv
-        batch_size, max_seqlen_q, max_seqlen_kv = (
-            query_layer.shape[0], query_layer.shape[1], key_layer.shape[1])
+        batch_size = query_layer.shape[0]

         if qkv_format in ['sbhd', 'bshd']:
+            max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
             if not context_parallel:
                 # [b * s, h, d]
                 query_layer, key_layer, value_layer = [
@@ -1422,10 +1424,12 @@ def forward(
             ), "flash-attn v2 is required for variable sequence length support!"
             assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
             ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
-            seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
-            seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
-            max_seqlen_q = seqlens_q.max().item()
-            max_seqlen_kv = seqlens_kv.max().item()
+            if max_seqlen_q is None:
+                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
+                max_seqlen_q = seqlens_q.max().item()
+            if max_seqlen_kv is None:
+                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
+                max_seqlen_kv = seqlens_kv.max().item()

         if context_parallel:
             assert (
@@ -1754,6 +1758,8 @@ def forward(
         qkv_layout: str = "sbh3d",
         cu_seqlens_q: Optional[torch.Tensor] = None,
         cu_seqlens_kv: Optional[torch.Tensor] = None,
+        max_seqlen_q: Optional[int] = None,
+        max_seqlen_kv: Optional[int] = None,
         attn_mask_type: str = "causal",
         attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
         fused_attention_backend:
@@ -2104,6 +2110,8 @@ def forward(
         qkv_format: Optional[str] = None,
         cu_seqlens_q: Optional[torch.Tensor] = None,
         cu_seqlens_kv: Optional[torch.Tensor] = None,
+        max_seqlen_q: Optional[int] = None,
+        max_seqlen_kv: Optional[int] = None,
         attn_mask_type: Optional[str] = None,
         window_size: Optional[Tuple[int, int]] = None,
         checkpoint_core_attention: bool = False,
@@ -2176,6 +2184,12 @@ def forward(
         cu_seqlens_kv: Optional[torch.Tensor], default = `None`
                    Cumulative sum of sequence lengths in a batch for `key_layer` and `value_layer`,
                    with shape [batch_size + 1] and dtype torch.int32.
+        max_seqlen_q: Optional[int], default = `None`
+                   Maximum sequence length in `query_layer`.
+                   Calculated from `cu_seqlens_q` if not provided.
+        max_seqlen_kv: Optional[int], default = `None`
+                   Maximum sequence length in `key_layer` and `value_layer`.
+                   Calculated from `cu_seqlens_kv` if not provided.
         attn_mask_type: {`no_mask`, `padding`, `causal`, `padding,causal`, `causal,padding`,
                        `arbitrary`}, default = `None`. Type of attention mask passed into
                        softmax operation. 'padding,causal' and 'causal,padding' are equivalent.
@@ -2238,10 +2252,12 @@ def forward(
             assert (cu_seqlens_q.dtype == torch.int32
                     and cu_seqlens_kv.dtype == torch.int32
                     ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
-            seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
-            seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
-            max_seqlen_q = seqlens_q.max().item()
-            max_seqlen_kv = seqlens_kv.max().item()
+            if max_seqlen_q is None:
+                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
+                max_seqlen_q = seqlens_q.max().item()
+            if max_seqlen_kv is None:
+                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
+                max_seqlen_kv = seqlens_kv.max().item()

         if qkv_format in ['sbhd', 'bshd']:
             assert (all(len(x.shape) == 4 for x in (query_layer, key_layer, value_layer))
@@ -2405,7 +2421,9 @@ def forward(
                     window_size=window_size,
                     cp_group=self.cp_group,
                     cp_global_ranks=self.cp_global_ranks,
-                    cp_stream=self.cp_stream)
+                    cp_stream=self.cp_stream,
+                    max_seqlen_q=max_seqlen_q,
+                    max_seqlen_kv=max_seqlen_kv)

         assert (
             self.cp_group is None or get_distributed_world_size(self.cp_group) == 1
@@ -2428,7 +2446,9 @@ def forward(
                     fused_attention_backend=fused_attention_backend,
                     core_attention_bias_type=core_attention_bias_type,
                     core_attention_bias=core_attention_bias,
-                    fast_zero_fill=fast_zero_fill)
+                    fast_zero_fill=fast_zero_fill,
+                    max_seqlen_q=max_seqlen_q,
+                    max_seqlen_kv=max_seqlen_kv)
             return self.fused_attention(query_layer, key_layer, value_layer,
                     qkv_layout=qkv_layout,
                     cu_seqlens_q=cu_seqlens_q,
@@ -2438,7 +2458,9 @@ def forward(
                     fused_attention_backend=fused_attention_backend,
                     core_attention_bias_type=core_attention_bias_type,
                     core_attention_bias=core_attention_bias,
-                    fast_zero_fill=fast_zero_fill)
+                    fast_zero_fill=fast_zero_fill,
+                    max_seqlen_q=max_seqlen_q,
+                    max_seqlen_kv=max_seqlen_kv)

         if _NVTE_DEBUG:
             print("[DotProductAttention]: using unfused DPA")
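
For context, a minimal usage sketch of the new arguments (not part of the diff): with packed `thd` inputs the caller usually knows the per-sequence lengths already, so it can pass `max_seqlen_q`/`max_seqlen_kv` up front and skip the `cu_seqlens.max().item()` device sync inside `forward()`. The constructor arguments, tensor shapes, dtypes, and values below are illustrative assumptions, not taken from this change.

import torch
from transformer_engine.pytorch import DotProductAttention

# Module configuration: heads/head_dim and the constructor arguments are assumed values.
heads, head_dim = 16, 64
attn = DotProductAttention(heads, head_dim, attn_mask_type="padding")

# A packed ("thd") batch of three sequences with lengths 5, 7, and 4 tokens.
seqlens = torch.tensor([5, 7, 4], dtype=torch.int32, device="cuda")
cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0, dtype=torch.int32)
max_seqlen = 7  # longest sequence, known when the batch was packed

total_tokens = int(cu_seqlens[-1].item())  # 16
q = torch.randn(total_tokens, heads, head_dim, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# Passing max_seqlen_q/max_seqlen_kv means forward() no longer has to derive them
# from cu_seqlens via seqlens.max().item() on every call.
out = attn(q, k, v,
           qkv_format="thd",
           cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens,
           max_seqlen_q=max_seqlen, max_seqlen_kv=max_seqlen)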