diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index b80e2fe9fa..68f645a7f5 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -108,6 +108,7 @@ _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4") _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1") _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7") +_flash_attn_2_6_0_plus = _flash_attn_version >= PkgVersion("2.6.0") _flash_attn_3_plus = False _use_flash_attn_3 = False _flash_attn_3_installation_steps = """\ @@ -135,13 +136,19 @@ from flashattn_hopper.flash_attn_interface import ( flash_attn_varlen_func as flash_attn_varlen_func_v3, ) + from flashattn_hopper.flash_attn_interface import ( + _flash_attn_varlen_forward as flash_attn_varlen_fwd_v3, + ) + from flashattn_hopper.flash_attn_interface import ( + _flash_attn_varlen_backward as flash_attn_varlen_bwd_v3, + ) _use_flash_attn_3 = True if _flash_attn_version >= _flash_attn_version_required: from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func - from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward - from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward + from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as flash_attn_varlen_fwd + from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as flash_attn_varlen_bwd from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd _attention_backends = { @@ -460,9 +467,6 @@ def get_attention_backend( ) use_unfused_attention = False if context_parallel and use_flash_attention: - if _use_flash_attn_3: - logger.debug("Disabling FlashAttention 3 for context parallelism") - _use_flash_attn_3 = False if fp8 and fp8_meta["recipe"].fp8_dpa: logger.debug( "Disabling FlashAttention as it does not support context parallelism with FP8" @@ -1362,12 +1366,15 @@ def flash_attn_p2p_communicate( def flash_attn_fwd_out_correction( out: torch.Tensor, out_per_step: torch.Tensor, - seq_dim: int, softmax_lse: torch.Tensor, softmax_lse_per_step: torch.Tensor, + movedim_src: int, + movedim_dst: int, ): """Merge partial outputs of each step in Attention with context parallelism""" - softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim) + softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim( + movedim_src, movedim_dst + ) softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) out_corrected = out_per_step * softmax_lse_corrected_exp out.add_(out_corrected) @@ -1693,13 +1700,25 @@ def forward( *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) ) assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8" - fa_optional_forward_kwargs = {} - if _flash_attn_2_3_plus: - fa_optional_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) - if _flash_attn_2_4_plus: - fa_optional_forward_kwargs["alibi_slopes"] = None - if _flash_attn_2_5_7_plus: - fa_optional_forward_kwargs["block_table"] = None + + softmax_lse_in_packed_format = not use_fused_attention and ( + _flash_attn_2_6_0_plus or _use_flash_attn_3 + ) + if not use_fused_attention: + fa_forward_kwargs = {"softmax_scale": softmax_scale} + if _use_flash_attn_3: + flash_attn_fwd = flash_attn_varlen_fwd_v3 + fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) + else: + flash_attn_fwd = 
flash_attn_varlen_fwd + fa_forward_kwargs["dropout_p"] = dropout_p + fa_forward_kwargs["return_softmax"] = False + if _flash_attn_2_3_plus: + fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) + if _flash_attn_2_4_plus: + fa_forward_kwargs["alibi_slopes"] = None + if _flash_attn_2_5_7_plus: + fa_forward_kwargs["block_table"] = None # Flash Attn inputs q_inputs = [None, None] @@ -1840,16 +1859,7 @@ def forward( q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) - ( - _, - _, - _, - _, - out_per_step[i], - softmax_lse_per_step[i], - _, - rng_states[i], - ) = _flash_attn_forward( + fa_outputs = flash_attn_fwd( q_inputs[i % 2], kv_inputs[i % 2][0], kv_inputs[i % 2][1], @@ -1857,12 +1867,13 @@ def forward( cu_seqlens_kv_per_step[i], max_seqlen_q, max_seqlen_kv, - dropout_p, - softmax_scale, causal=True, - return_softmax=False, - **fa_optional_forward_kwargs, + **fa_forward_kwargs, ) + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[7] elif i <= rank: if pad_between_seqs_q: cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( @@ -1952,18 +1963,9 @@ def forward( kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous() # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn] kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) - if _flash_attn_2_3_plus: - fa_optional_forward_kwargs["window_size"] = (-1, -1) - ( - _, - _, - _, - _, - out_per_step[i], - softmax_lse_per_step[i], - _, - rng_states[i], - ) = _flash_attn_forward( + if _use_flash_attn_3 or _flash_attn_2_3_plus: + fa_forward_kwargs["window_size"] = (-1, -1) + fa_outputs = flash_attn_fwd( q_inputs[i % 2], kv_inputs[i % 2][0], kv_inputs[i % 2][1], @@ -1971,12 +1973,13 @@ def forward( cu_seqlens_kv_per_step[i], max_seqlen_q, max_seqlen_kv // 2, - dropout_p, - softmax_scale, causal=False, - return_softmax=False, - **fa_optional_forward_kwargs, + **fa_forward_kwargs, ) + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[7] else: if pad_between_seqs_q: cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( @@ -2075,18 +2078,9 @@ def forward( ) # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) - if _flash_attn_2_3_plus: - fa_optional_forward_kwargs["window_size"] = (-1, -1) - ( - _, - _, - _, - _, - out_per_step[i], - softmax_lse_per_step[i], - _, - rng_states[i], - ) = _flash_attn_forward( + if _use_flash_attn_3 or _flash_attn_2_3_plus: + fa_forward_kwargs["window_size"] = (-1, -1) + fa_outputs = flash_attn_fwd( q_inputs[i % 2], kv_inputs[i % 2][0], kv_inputs[i % 2][1], @@ -2094,12 +2088,13 @@ def forward( cu_seqlens_kv_per_step[i], max_seqlen_q // 2, max_seqlen_kv, - dropout_p, - softmax_scale, causal=False, - return_softmax=False, - **fa_optional_forward_kwargs, + **fa_forward_kwargs, ) + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[7] else: if pad_between_seqs_q: cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( @@ -2167,16 +2162,7 @@ def forward( q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) # [2, b, sk, np, hn] -> [2, b*sk, np, hn] kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) - ( - _, - _, - _, - _, - out_per_step[i], - softmax_lse_per_step[i], - _, - rng_states[i], - ) = 
_flash_attn_forward( + fa_outputs = flash_attn_fwd( q_inputs[i % 2], kv_inputs[i % 2][0], kv_inputs[i % 2][1], @@ -2184,12 +2170,13 @@ def forward( cu_seqlens_kv_per_step[i], max_seqlen_q, max_seqlen_kv, - dropout_p, - softmax_scale, causal=False, - return_softmax=False, - **fa_optional_forward_kwargs, + **fa_forward_kwargs, ) + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[7] if i > 0: # wait until fwd restuls correction of last step is done @@ -2199,6 +2186,11 @@ def forward( if use_fused_attention: # [b, np, sq, 1] -> [b, np, sq] softmax_lse_per_step[i - 1].squeeze_(-1) + if qkv_format != "thd" and softmax_lse_in_packed_format: + # [np, t] -> [np, b, sq] + softmax_lse_per_step[i - 1] = softmax_lse_per_step[i - 1].view( + q.shape[-2], q.shape[0], -1 + ) with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): if fp8: @@ -2213,7 +2205,8 @@ def forward( out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape) softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) if causal and qkv_format != "thd": - # [b, np, sq] -> [b, np, 2, sq//2] + # [b, np, sq] -> [b, np, 2, sq//2] lse not in packed format + # [np, b, sq] -> [np, b, 2, sq//2] lse in packed format softmax_lse_ = softmax_lse.view( *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2 ) @@ -2227,7 +2220,7 @@ def forward( softmax_lse, softmax_lse_per_step[i - 1], cu_seqlens_q_padded, - max_seqlen_q, + softmax_lse_in_packed_format, ) else: flash_attn_fwd_softmax_lse_correction( @@ -2253,9 +2246,10 @@ def forward( flash_attn_fwd_out_correction( out.view(*out_per_step[i].shape), out_per_step[i], - seq_dim, softmax_lse, softmax_lse_per_step[i], + 0 if softmax_lse_in_packed_format else 2, + 2 if softmax_lse_in_packed_format else seq_dim, ) elif qkv_format == "thd": tex.thd_out_correction( @@ -2265,15 +2259,17 @@ def forward( softmax_lse_per_step[i], cu_seqlens_q_padded, False, + softmax_lse_in_packed_format, ) else: if qkv_format in ["bshd", "sbhd"]: flash_attn_fwd_out_correction( out_, out_per_step[i], - seq_dim, softmax_lse_[..., 1, :], softmax_lse_per_step[i], + 0 if softmax_lse_in_packed_format else 2, + 2 if softmax_lse_in_packed_format else seq_dim, ) elif qkv_format == "thd": tex.thd_out_correction( @@ -2283,8 +2279,12 @@ def forward( softmax_lse_per_step[i], cu_seqlens_q_padded, True, + softmax_lse_in_packed_format, ) + if qkv_format != "thd" and softmax_lse_in_packed_format: + # [np, b, sq] -> [np, t] + softmax_lse = softmax_lse.view(softmax_lse.shape[0], -1) kv = p2p_comm_buffers[-1] if qkv_format == "bshd": out = out.view(out.shape[0], -1, *out.shape[-2:]) @@ -2430,10 +2430,14 @@ def backward(ctx, dout): else: attn_dbias = None + softmax_lse_in_packed_format = not ctx.use_fused_attention and ( + _flash_attn_2_6_0_plus or _use_flash_attn_3 + ) + if causal: - if ctx.qkv_format == "thd": + if ctx.qkv_format == "thd" or softmax_lse_in_packed_format: softmax_lse_ = tex.thd_read_second_half_lse( - softmax_lse, cu_seqlens_q_padded, ctx.max_seqlen_q + softmax_lse, cu_seqlens_q_padded, softmax_lse_in_packed_format ) else: # [b, np, sq] -> [b, np, 2, sq//2] @@ -2526,11 +2530,18 @@ def backward(ctx, dout): dout = dout.view(*q.shape) send_recv_reqs = [] - fa_optional_backward_kwargs = {} - if _flash_attn_2_4_plus: - fa_optional_backward_kwargs["alibi_slopes"] = None - if _flash_attn_2_4_1_plus: - fa_optional_backward_kwargs["deterministic"] = ctx.deterministic + if not ctx.use_fused_attention: + fa_backward_kwargs = 
{"softmax_scale": ctx.softmax_scale} + if _use_flash_attn_3: + flash_attn_bwd = flash_attn_varlen_bwd_v3 + fa_backward_kwargs["deterministic"] = ctx.deterministic + else: + flash_attn_bwd = flash_attn_varlen_bwd + fa_backward_kwargs["dropout_p"] = ctx.dropout_p + if _flash_attn_2_4_plus: + fa_backward_kwargs["alibi_slopes"] = None + if _flash_attn_2_4_1_plus: + fa_backward_kwargs["deterministic"] = ctx.deterministic for i in range(cp_size): # wait until KV is received @@ -2639,9 +2650,11 @@ def backward(ctx, dout): # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] out_ = out.view(-1, *out.shape[-2:]) dout_ = dout.view(-1, *dout.shape[-2:]) - if _flash_attn_2_3_plus: - fa_optional_backward_kwargs["window_size"] = (-1, 0) - _flash_attn_backward( + if _use_flash_attn_3 or _flash_attn_2_3_plus: + fa_backward_kwargs["window_size"] = (-1, 0) + if not _use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + flash_attn_bwd( dout_, q_, kv_[0], @@ -2655,11 +2668,8 @@ def backward(ctx, dout): cu_seqlens_kv_per_step[cp_size - i - 1], ctx.max_seqlen_q, ctx.max_seqlen_kv, - ctx.dropout_p, - ctx.softmax_scale, - True, - rng_state=rng_states[cp_size - i - 1], - **fa_optional_backward_kwargs, + causal=True, + **fa_backward_kwargs, ) elif i >= (cp_size - rank - 1): if ctx.use_fused_attention: @@ -2733,9 +2743,11 @@ def backward(ctx, dout): # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] out_ = out.view(-1, *out.shape[-2:]) dout_ = dout.view(-1, *dout.shape[-2:]) - if _flash_attn_2_3_plus: - fa_optional_backward_kwargs["window_size"] = (-1, -1) - _flash_attn_backward( + if _use_flash_attn_3 or _flash_attn_2_3_plus: + fa_backward_kwargs["window_size"] = (-1, -1) + if not _use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + flash_attn_bwd( dout_, q_, kv_[0], @@ -2749,11 +2761,8 @@ def backward(ctx, dout): cu_seqlens_kv_per_step[cp_size - i - 1], ctx.max_seqlen_q, ctx.max_seqlen_kv // 2, - ctx.dropout_p, - ctx.softmax_scale, - False, - rng_state=rng_states[cp_size - i - 1], - **fa_optional_backward_kwargs, + causal=False, + **fa_backward_kwargs, ) else: if ctx.use_fused_attention: @@ -2833,9 +2842,11 @@ def backward(ctx, dout): # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:]) dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:]) - if _flash_attn_2_3_plus: - fa_optional_backward_kwargs["window_size"] = (-1, -1) - _flash_attn_backward( + if _use_flash_attn_3 or _flash_attn_2_3_plus: + fa_backward_kwargs["window_size"] = (-1, -1) + if not _use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + flash_attn_bwd( dout_, q_, kv_[0], @@ -2849,11 +2860,8 @@ def backward(ctx, dout): cu_seqlens_kv_per_step[cp_size - i - 1], ctx.max_seqlen_q // 2, ctx.max_seqlen_kv, - ctx.dropout_p, - ctx.softmax_scale, - False, - rng_state=rng_states[cp_size - i - 1], - **fa_optional_backward_kwargs, + causal=False, + **fa_backward_kwargs, ) else: if ctx.use_fused_attention: @@ -2897,9 +2905,11 @@ def backward(ctx, dout): # [b, sq, np, hn] -> [b*sq, np, hn] out_ = out.view(-1, *out.shape[-2:]) dout_ = dout.view(-1, *dout.shape[-2:]) - if _flash_attn_2_3_plus: - fa_optional_backward_kwargs["window_size"] = (-1, -1) - _flash_attn_backward( + if _use_flash_attn_3 or _flash_attn_2_3_plus: + fa_backward_kwargs["window_size"] = (-1, -1) + if not _use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + flash_attn_bwd( dout_, q_, kv_[0], @@ -2913,11 +2923,8 
@@ def backward(ctx, dout): cu_seqlens_kv_per_step[cp_size - i - 1], ctx.max_seqlen_q, ctx.max_seqlen_kv, - ctx.dropout_p, - ctx.softmax_scale, - False, - rng_state=rng_states[cp_size - i - 1], - **fa_optional_backward_kwargs, + causal=False, + **fa_backward_kwargs, ) if ctx.fp8: @@ -3251,11 +3258,19 @@ def forward( assert ( use_fused_attention or _flash_attn_2_3_plus ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" - fa_optional_forward_kwargs = {} - if _flash_attn_2_4_plus: - fa_optional_forward_kwargs["alibi_slopes"] = None - if _flash_attn_2_5_7_plus: - fa_optional_forward_kwargs["block_table"] = None + + if not use_fused_attention: + fa_forward_kwargs = {"softmax_scale": softmax_scale} + if _use_flash_attn_3: + flash_attn_fwd = flash_attn_varlen_fwd_v3 + else: + flash_attn_fwd = flash_attn_varlen_fwd + fa_forward_kwargs["dropout_p"] = dropout_p + fa_forward_kwargs["return_softmax"] = False + if _flash_attn_2_4_plus: + fa_forward_kwargs["alibi_slopes"] = None + if _flash_attn_2_5_7_plus: + fa_forward_kwargs["block_table"] = None assert qkv_format != "thd", f"{qkv_format} format is not supported!" qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format @@ -3353,23 +3368,22 @@ def forward( ) else: q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]] - _, _, _, _, out_per_step[i], softmax_lse_per_step[i], _, rng_states[i] = ( - _flash_attn_forward( - q_, - k_, - v_, - cu_seqlens_q, - cu_seqlens_kv_per_step[i], - max_seqlen_q, - max_seqlen_kv_, - dropout_p, - softmax_scale, - causal=causal, - return_softmax=False, - window_size=window_size_per_step[i], - **fa_optional_forward_kwargs, - ) + fa_outputs = flash_attn_fwd( + q_, + k_, + v_, + cu_seqlens_q, + cu_seqlens_kv_per_step[i], + max_seqlen_q, + max_seqlen_kv_, + causal=causal, + window_size=window_size_per_step[i], + **fa_forward_kwargs, ) + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not _use_flash_attn_3: + rng_states[i] = fa_outputs[7] if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): @@ -3459,11 +3473,18 @@ def backward(ctx, dout): local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] - fa_optional_backward_kwargs = {} - if _flash_attn_2_4_plus: - fa_optional_backward_kwargs["alibi_slopes"] = None - if _flash_attn_2_4_1_plus: - fa_optional_backward_kwargs["deterministic"] = ctx.deterministic + if not ctx.use_fused_attention: + fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} + if _use_flash_attn_3: + flash_attn_bwd = flash_attn_varlen_bwd_v3 + fa_backward_kwargs["deterministic"] = ctx.deterministic + else: + flash_attn_bwd = flash_attn_varlen_bwd + fa_backward_kwargs["dropout_p"] = ctx.dropout_p + if _flash_attn_2_4_plus: + fa_backward_kwargs["alibi_slopes"] = None + if _flash_attn_2_4_1_plus: + fa_backward_kwargs["deterministic"] = ctx.deterministic for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): @@ -3513,7 +3534,9 @@ def backward(ctx, dout): dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ torch.empty_like(x) for x in [q_, k_, v_] ] - _flash_attn_backward( + if not _use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[i] + flash_attn_bwd( dout_, q_, k_, @@ -3527,12 +3550,9 @@ def backward(ctx, dout): cu_seqlens_kv_per_step[i], ctx.max_seqlen_q, max_seqlen_kv, - ctx.dropout_p, - ctx.softmax_scale, - "causal" in ctx.attn_mask_type, + causal="causal" in ctx.attn_mask_type, window_size=window_size_per_step[i], - rng_state=rng_states[i], - **fa_optional_backward_kwargs, + 
**fa_backward_kwargs, ) # [b*sq//2, np, hn] -> [b, sq//2, np, hn] dq_per_step[i] = dq_per_step[i].view(dq[:, i].shape) @@ -3655,13 +3675,22 @@ def forward( or use_fused_attention or _flash_attn_2_3_plus ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" - fa_optional_forward_kwargs = {} - if _flash_attn_2_3_plus: - fa_optional_forward_kwargs["window_size"] = window_size - if _flash_attn_2_4_plus: - fa_optional_forward_kwargs["alibi_slopes"] = None - if _flash_attn_2_5_7_plus: - fa_optional_forward_kwargs["block_table"] = None + + if not use_fused_attention: + fa_forward_kwargs = {"softmax_scale": softmax_scale} + if _use_flash_attn_3: + flash_attn_fwd = flash_attn_varlen_fwd_v3 + fa_forward_kwargs["window_size"] = window_size + else: + flash_attn_fwd = flash_attn_varlen_fwd + fa_forward_kwargs["dropout_p"] = dropout_p + fa_forward_kwargs["return_softmax"] = False + if _flash_attn_2_3_plus: + fa_forward_kwargs["window_size"] = window_size + if _flash_attn_2_4_plus: + fa_forward_kwargs["alibi_slopes"] = None + if _flash_attn_2_5_7_plus: + fa_forward_kwargs["block_table"] = None assert ( q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0 @@ -3756,16 +3785,7 @@ def forward( else: # [b, cp*s, np//cp, hn] -> [b*cp*s, np//cp, hn] q, k, v = [x.view(-1, *x.shape[-2:]) for x in [q, k, v]] - ( - _, - _, - _, - _, - out, - softmax_lse, - _, - rng_state, - ) = _flash_attn_forward( + fa_outputs = flash_attn_fwd( q, k, v, @@ -3773,12 +3793,11 @@ def forward( cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, - dropout_p, - softmax_scale, causal=causal, - return_softmax=False, - **fa_optional_forward_kwargs, + **fa_forward_kwargs, ) + out, softmax_lse = fa_outputs[4], fa_outputs[5] + rng_state = fa_outputs[7] if not _use_flash_attn_3 else None aux_ctx_tensors = [softmax_lse, rng_state] # [b*cp*s, np//cp, hn] -> [b, cp*s, np//cp, hn] out = out.view(batch_size, -1, *out.shape[-2:]) @@ -3943,13 +3962,21 @@ def backward(ctx, dout): [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True ) - fa_optional_backward_kwargs = {} - if _flash_attn_2_3_plus: - fa_optional_backward_kwargs["window_size"] = ctx.window_size - if _flash_attn_2_4_plus: - fa_optional_backward_kwargs["alibi_slopes"] = None - if _flash_attn_2_4_1_plus: - fa_optional_backward_kwargs["deterministic"] = ctx.deterministic + if not ctx.use_fused_attention: + fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} + if _use_flash_attn_3: + flash_attn_bwd = flash_attn_varlen_bwd_v3 + fa_backward_kwargs["window_size"] = ctx.window_size + fa_backward_kwargs["deterministic"] = ctx.deterministic + else: + flash_attn_bwd = flash_attn_varlen_bwd + fa_backward_kwargs["dropout_p"] = ctx.dropout_p + if _flash_attn_2_3_plus: + fa_backward_kwargs["window_size"] = ctx.window_size + if _flash_attn_2_4_plus: + fa_backward_kwargs["alibi_slopes"] = None + if _flash_attn_2_4_1_plus: + fa_backward_kwargs["deterministic"] = ctx.deterministic if ctx.use_fused_attention: dq, dk, dv, _ = fused_attn_bwd( @@ -3981,7 +4008,9 @@ def backward(ctx, dout): softmax_lse, rng_state = aux_ctx_tensors out, dout = [x.view(-1, *x.shape[-2:]) for x in [out, dout]] dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] - _flash_attn_backward( + if not _use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_state + flash_attn_bwd( dout, q, k, @@ -3995,11 +4024,8 @@ def backward(ctx, dout): cu_seqlens_kv, ctx.max_seqlen_q, ctx.max_seqlen_kv, - ctx.dropout_p, - ctx.softmax_scale, - causal, - rng_state=rng_state, - 
**fa_optional_backward_kwargs, + causal=causal, + **fa_backward_kwargs, ) dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index c797208e06..a0ebf6faa7 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -433,14 +433,14 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s int half_idx); void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step, - const at::Tensor &cu_seqlens, int total_tokens); + const at::Tensor &cu_seqlens, bool lse_packed); at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens, - int total_tokens); + bool lse_packed); void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse, const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens, - bool only_second_half); + bool only_second_half, bool lse_packed); void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step, const at::Tensor &cu_seqlens, const std::string &first_half, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index b2968a688d..8088a2b8f1 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -1464,9 +1464,9 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s * Support THD format for Context Parallel: softmax_lse related operations **************************************************************************************************/ -template +template __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch, - int num_heads, int max_seqlen) { + int num_heads, int total_tokens) { extern __shared__ int cu_seqlens_s[]; for (int i = threadIdx.x; i <= batch; i += blockDim.x) { cu_seqlens_s[i] = cu_seqlens[i] / 2; @@ -1480,12 +1480,18 @@ __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) { int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { - size_t row = static_cast(seq_id) * num_heads + head_id; - int col = token_id - cu_seqlens_s[seq_id]; - int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + size_t idx, half_idx; + if constexpr (lse_packed) { + idx = head_id * total_tokens + token_id + cu_seqlens_s[seq_id + 1]; + half_idx = head_id * total_tokens / 2 + token_id; + } else { + size_t row = static_cast(seq_id) * num_heads + head_id; + int col = token_id - cu_seqlens_s[seq_id]; + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - size_t idx = row * max_seqlen + col + seq_len; - size_t half_idx = row * max_seqlen / 2 + col; + idx = row * total_tokens + col + seq_len; + half_idx = row * total_tokens / 2 + col; + } Functor::run(lse, half_lse, idx, half_idx); } @@ -1504,32 +1510,53 @@ struct LseCorrectionFunctor { }; void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step, - const at::Tensor &cu_seqlens, int total_tokens) { + const at::Tensor &cu_seqlens, bool lse_packed) { NVTE_CHECK(lse.scalar_type() == at::ScalarType::Double); NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float); NVTE_CHECK(cu_seqlens.scalar_type() 
== at::ScalarType::Int); - - NVTE_CHECK(lse.dim() == 3); - NVTE_CHECK(lse_per_step.dim() == 3); NVTE_CHECK(cu_seqlens.dim() == 1); - int batch = lse.size(0); - int num_heads = lse.size(1); - int max_seqlen = lse.size(2); + int batch, num_heads, total_tokens; + + if (lse_packed) { + NVTE_CHECK(lse.dim() == 2); + NVTE_CHECK(lse_per_step.dim() == 2); + + batch = cu_seqlens.size(0) - 1; + num_heads = lse.size(0); + total_tokens = lse.size(1); + + NVTE_CHECK(lse_per_step.size(0) == num_heads); + NVTE_CHECK(lse_per_step.size(1) == total_tokens / 2); + } else { + NVTE_CHECK(lse.dim() == 3); + NVTE_CHECK(lse_per_step.dim() == 3); + + batch = lse.size(0); + num_heads = lse.size(1); + total_tokens = lse.size(2); - NVTE_CHECK(lse_per_step.size(0) == batch); - NVTE_CHECK(lse_per_step.size(1) == num_heads); - NVTE_CHECK(lse_per_step.size(2) == max_seqlen / 2); - NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + NVTE_CHECK(lse_per_step.size(0) == batch); + NVTE_CHECK(lse_per_step.size(1) == num_heads); + NVTE_CHECK(lse_per_step.size(2) == total_tokens / 2); + NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + } constexpr unsigned int block = 256; unsigned int grid_x = (total_tokens / 2 + block - 1) / block; unsigned int grid_y = num_heads; dim3 grid = {grid_x, grid_y}; - thd_lse_kernel - <<>>( - lse.data_ptr(), lse_per_step.data_ptr(), cu_seqlens.data_ptr(), batch, - num_heads, max_seqlen); + if (lse_packed) { + thd_lse_kernel + <<>>( + lse.data_ptr(), lse_per_step.data_ptr(), cu_seqlens.data_ptr(), + batch, num_heads, total_tokens); + } else { + thd_lse_kernel + <<>>( + lse.data_ptr(), lse_per_step.data_ptr(), cu_seqlens.data_ptr(), + batch, num_heads, total_tokens); + } } struct ReadLseFunctor { @@ -1540,29 +1567,51 @@ struct ReadLseFunctor { }; at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens, - int total_tokens) { + bool lse_packed) { NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); - NVTE_CHECK(lse.dim() == 3); NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); NVTE_CHECK(cu_seqlens.dim() == 1); - int batch = lse.size(0); - int num_heads = lse.size(1); - int max_seqlen = lse.size(2); + int batch, num_heads, total_tokens; + std::vector shape; + + if (lse_packed) { + NVTE_CHECK(lse.dim() == 2); + + batch = cu_seqlens.size(0) - 1; + num_heads = lse.size(0); + total_tokens = lse.size(1); + + shape = {num_heads, total_tokens / 2}; + } else { + NVTE_CHECK(lse.dim() == 3); - NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + batch = lse.size(0); + num_heads = lse.size(1); + total_tokens = lse.size(2); + + NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + + shape = {batch, num_heads, total_tokens / 2}; + } - std::vector shape = {batch, num_heads, max_seqlen / 2}; at::Tensor half_lse = at::zeros(shape, at::CUDA(lse.scalar_type())); constexpr unsigned int block = 256; unsigned int grid_x = (total_tokens / 2 + block - 1) / block; unsigned int grid_y = num_heads; dim3 grid = {grid_x, grid_y}; - thd_lse_kernel - <<>>( - lse.data_ptr(), half_lse.data_ptr(), cu_seqlens.data_ptr(), batch, - num_heads, max_seqlen); + if (lse_packed) { + thd_lse_kernel + <<>>( + lse.data_ptr(), half_lse.data_ptr(), cu_seqlens.data_ptr(), batch, + num_heads, total_tokens); + } else { + thd_lse_kernel + <<>>( + lse.data_ptr(), half_lse.data_ptr(), cu_seqlens.data_ptr(), batch, + num_heads, total_tokens); + } return half_lse; } @@ -1571,10 +1620,10 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_ * Support THD format for Context Parallel: Out correction in 
forward **************************************************************************************************/ -template +template __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse, float *lse_per_step, int *cu_seqlens, int batch, - int num_heads, int dim_per_head, int max_seqlen) { + int num_heads, int dim_per_head, int lse_seqlen) { extern __shared__ int cu_seqlens_s[]; for (int i = threadIdx.x; i <= batch; i += blockDim.x) { cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1); @@ -1592,11 +1641,16 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { size_t idx, idx_per_step; - size_t row = static_cast(seq_id) * num_heads + head_id; - int col = token_id - cu_seqlens_s[seq_id]; - int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - idx = row * max_seqlen + col + seq_len * only_second_half; - idx_per_step = row * max_seqlen / (only_second_half + 1) + col; + if constexpr (lse_packed) { + idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1] * only_second_half; + idx_per_step = head_id * lse_seqlen / (only_second_half + 1) + token_id; + } else { + size_t row = static_cast(seq_id) * num_heads + head_id; + int col = token_id - cu_seqlens_s[seq_id]; + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + idx = row * lse_seqlen + col + seq_len * only_second_half; + idx_per_step = row * lse_seqlen / (only_second_half + 1) + col; + } float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]); idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half; @@ -1622,7 +1676,7 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float template static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse, const at::Tensor &lse_per_step, - const at::Tensor &cu_seqlens) { + const at::Tensor &cu_seqlens, bool lse_packed) { NVTE_CHECK(out.scalar_type() == out_per_step.scalar_type()); NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float); @@ -1631,17 +1685,30 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_ int total_tokens = out.size(0); int num_heads = out.size(1); int dim_per_head = out.size(2); - int batch = lse.size(0); - int max_seqlen = lse.size(2); NVTE_CHECK(out_per_step.size(0) == total_tokens / (only_second_half + 1)); NVTE_CHECK(out_per_step.size(1) == num_heads); NVTE_CHECK(out_per_step.size(2) == dim_per_head); - NVTE_CHECK(lse.size(1) == num_heads); - NVTE_CHECK(lse_per_step.size(0) == batch); - NVTE_CHECK(lse_per_step.size(1) == num_heads); - NVTE_CHECK(lse_per_step.size(2) == max_seqlen / (only_second_half + 1)); - NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + + int batch, lse_seqlen; + if (lse_packed) { + batch = cu_seqlens.size(0) - 1; + lse_seqlen = total_tokens; + + NVTE_CHECK(lse.size(0) == num_heads); + NVTE_CHECK(lse.size(1) == lse_seqlen); + NVTE_CHECK(lse_per_step.size(0) == num_heads); + NVTE_CHECK(lse_per_step.size(1) == lse_seqlen / (only_second_half + 1)); + } else { + batch = lse.size(0); + lse_seqlen = lse.size(2); + + NVTE_CHECK(lse.size(1) == num_heads); + NVTE_CHECK(lse_per_step.size(0) == batch); + NVTE_CHECK(lse_per_step.size(1) == num_heads); + NVTE_CHECK(lse_per_step.size(2) == lse_seqlen / (only_second_half + 1)); + NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + } constexpr int tile = 16; constexpr int block = 
512; @@ -1649,39 +1716,53 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_ (static_cast(total_tokens) / (only_second_half + 1) * tile + block - 1) / block; dim3 grid = {grid_x, (unsigned int)num_heads}; - thd_out_correction_kernel - <<>>( - out.data_ptr(), out_per_step.data_ptr(), lse.data_ptr(), - lse_per_step.data_ptr(), cu_seqlens.data_ptr(), batch, num_heads, - dim_per_head, max_seqlen); + if (lse_packed) { + thd_out_correction_kernel + <<>>( + out.data_ptr(), out_per_step.data_ptr(), lse.data_ptr(), + lse_per_step.data_ptr(), cu_seqlens.data_ptr(), batch, num_heads, + dim_per_head, lse_seqlen); + } else { + thd_out_correction_kernel + <<>>( + out.data_ptr(), out_per_step.data_ptr(), lse.data_ptr(), + lse_per_step.data_ptr(), cu_seqlens.data_ptr(), batch, num_heads, + dim_per_head, lse_seqlen); + } } void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse, const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens, - bool only_second_half) { + bool only_second_half, bool lse_packed) { if (only_second_half) { if (out.scalar_type() == at::ScalarType::Half) { using dtype = at::Half; - thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens, + lse_packed); } else if (out.scalar_type() == at::ScalarType::BFloat16) { using dtype = at::BFloat16; - thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens, + lse_packed); } else if (out.scalar_type() == at::ScalarType::Float) { using dtype = float; - thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens, + lse_packed); } else { NVTE_ERROR("Unsupported dtype of out\n"); } } else { if (out.scalar_type() == at::ScalarType::Half) { using dtype = at::Half; - thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens, + lse_packed); } else if (out.scalar_type() == at::ScalarType::BFloat16) { using dtype = at::BFloat16; - thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens, + lse_packed); } else if (out.scalar_type() == at::ScalarType::Float) { using dtype = float; - thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens, + lse_packed); } else { NVTE_ERROR("Unsupported dtype of out\n"); }
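# A minimal sketch of the forward-call dispatch introduced above for the
# context-parallel flash-attention paths: FlashAttention 3 takes no dropout or
# rng state, so the kwargs are assembled per backend and the outputs are read
# positionally. The helper name `select_flash_attn_fwd` is illustrative only;
# the version flags and the `flash_attn_varlen_fwd` / `flash_attn_varlen_fwd_v3`
# aliases are the module-level ones defined at the top of this diff.
from typing import Callable, Dict, Tuple


def select_flash_attn_fwd(
    flash_attn_varlen_fwd: Callable,      # flash_attn 2.x _flash_attn_varlen_forward
    flash_attn_varlen_fwd_v3: Callable,   # FA3 _flash_attn_varlen_forward, if installed
    use_flash_attn_3: bool,
    fa_2_3_plus: bool,
    fa_2_4_plus: bool,
    fa_2_5_7_plus: bool,
    softmax_scale: float,
    dropout_p: float,
    causal: bool,
) -> Tuple[Callable, Dict]:
    kwargs = {"softmax_scale": softmax_scale}
    if use_flash_attn_3:
        fwd = flash_attn_varlen_fwd_v3
        kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
    else:
        fwd = flash_attn_varlen_fwd
        kwargs["dropout_p"] = dropout_p
        kwargs["return_softmax"] = False
        if fa_2_3_plus:
            kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
        if fa_2_4_plus:
            kwargs["alibi_slopes"] = None
        if fa_2_5_7_plus:
            kwargs["block_table"] = None
    return fwd, kwargs


# Per context-parallel step the call site then looks like:
#   fa_outputs = fwd(q_, k_, v_, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q,
#                    max_seqlen_kv, causal=..., **kwargs)
#   out_per_step[i]         = fa_outputs[4]
#   softmax_lse_per_step[i] = fa_outputs[5]
#   if not use_flash_attn_3:
#       rng_states[i]       = fa_outputs[7]   # rng state only exists on the 2.x path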
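# The backward pass mirrors the same split (sketch only, same assumptions as
# above): FlashAttention 3 takes `deterministic` directly and has no dropout,
# while the flash_attn 2.x path keeps `dropout_p`, the version-gated
# `alibi_slopes` / `deterministic` kwargs, and a per-step `rng_state`;
# `causal` and `window_size` are now always passed as keywords.
def select_flash_attn_bwd(
    flash_attn_varlen_bwd: Callable,      # flash_attn 2.x _flash_attn_varlen_backward
    flash_attn_varlen_bwd_v3: Callable,   # FA3 _flash_attn_varlen_backward, if installed
    use_flash_attn_3: bool,
    fa_2_4_plus: bool,
    fa_2_4_1_plus: bool,
    softmax_scale: float,
    dropout_p: float,
    deterministic: bool,
) -> Tuple[Callable, Dict]:
    kwargs = {"softmax_scale": softmax_scale}
    if use_flash_attn_3:
        bwd = flash_attn_varlen_bwd_v3
        kwargs["deterministic"] = deterministic
    else:
        bwd = flash_attn_varlen_bwd
        kwargs["dropout_p"] = dropout_p
        if fa_2_4_plus:
            kwargs["alibi_slopes"] = None
        if fa_2_4_1_plus:
            kwargs["deterministic"] = deterministic
    return bwd, kwargs


# Before each per-step call the diff additionally sets
#   kwargs["window_size"] = (-1, 0) or (-1, -1)   # if FA3 or flash_attn >= 2.3
#   kwargs["rng_state"]   = rng_states[step]      # flash_attn 2.x only
# and then calls `bwd(..., causal=..., **kwargs)` with the same positional
# tensors as before; dropout_p and softmax_scale now travel in kwargs.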
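# flash_attn 2.6+ and FA3 return the varlen softmax LSE in "packed" [np, t]
# layout rather than [b, np, sq], which is why the correction helpers above gain
# a `lse_packed` / `softmax_lse_in_packed_format` flag and why
# flash_attn_fwd_out_correction now takes explicit movedim source/destination
# dims (0 -> 2 in packed layout, 2 -> seq_dim otherwise). The function below is
# a hypothetical pure-PyTorch reference for tex.thd_read_second_half_lse that
# restates the index math of thd_lse_kernel for both layouts; it is not the code
# path Transformer Engine actually runs.
import torch


def read_second_half_lse_reference(
    lse: torch.Tensor, cu_seqlens: torch.Tensor, lse_packed: bool
) -> torch.Tensor:
    half_cu = (cu_seqlens // 2).tolist()   # sequence offsets in the halved token space
    batch = len(half_cu) - 1
    if lse_packed:
        num_heads, total_tokens = lse.shape            # [np, t]
        half = lse.new_zeros(num_heads, total_tokens // 2)
        for s in range(batch):
            begin, end = half_cu[s], half_cu[s + 1]
            # second half of sequence s lives at columns [begin + end, 2 * end)
            half[:, begin:end] = lse[:, begin + end : 2 * end]
    else:
        batch_size, num_heads, max_seqlen = lse.shape  # [b, np, sq]
        half = lse.new_zeros(batch_size, num_heads, max_seqlen // 2)
        for s in range(batch):
            half_len = half_cu[s + 1] - half_cu[s]
            # each sequence is padded to max_seqlen inside its own row
            half[s, :, :half_len] = lse[s, :, half_len : 2 * half_len]
    return half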