diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 6b153fd3c1..28c1b45ffa 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -2528,12 +2528,13 @@ def backward(ctx, dout):
         recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
         batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
 
-        (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6]
-        (fp8_fwd_scales, fp8_fwd_scale_invs) = ctx.saved_tensors[6:8]
-        cu_seqlens_q_per_step = ctx.saved_tensors[8 : 8 + cp_size]
-        cu_seqlens_kv_per_step = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2]
-        rng_states = ctx.saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3]
-        attn_biases = ctx.saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4]
+        (*saved_tensors,) = ctx.saved_tensors
+        (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = saved_tensors[:6]
+        (fp8_fwd_scales, fp8_fwd_scale_invs) = saved_tensors[6:8]
+        cu_seqlens_q_per_step = saved_tensors[8 : 8 + cp_size]
+        cu_seqlens_kv_per_step = saved_tensors[8 + cp_size : 8 + cp_size * 2]
+        rng_states = saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3]
+        attn_biases = saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4]
 
         causal = "causal" in ctx.attn_mask_type
         padding = "padding" in ctx.attn_mask_type
@@ -3577,11 +3578,12 @@ def backward(ctx, dout):
         cp_size = get_distributed_world_size(ctx.cp_group)
         rank = get_distributed_rank(ctx.cp_group)
 
-        (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = ctx.saved_tensors[:5]
-        cu_seqlens_kv_per_step = ctx.saved_tensors[5:7]
-        out_per_step = ctx.saved_tensors[7:9]
-        softmax_lse_per_step = ctx.saved_tensors[9:11]
-        rng_states = ctx.saved_tensors[11:13]
+        (*saved_tensors,) = ctx.saved_tensors
+        (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5]
+        cu_seqlens_kv_per_step = saved_tensors[5:7]
+        out_per_step = saved_tensors[7:9]
+        softmax_lse_per_step = saved_tensors[9:11]
+        rng_states = saved_tensors[11:13]
 
         kv_seq_range_per_step = ctx.kv_seq_range_per_step
         window_size_per_step = ctx.window_size_per_step
@@ -4056,12 +4058,11 @@ def backward(ctx, dout):
         # pylint: disable=missing-function-docstring
         cp_size = get_distributed_world_size(ctx.cp_group)
 
-        q, k, v, out = ctx.saved_tensors[:4]
-        cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = ctx.saved_tensors[
-            4:8
-        ]
-        fp8_fwd_scales, fp8_fwd_scale_invs = ctx.saved_tensors[8:10]
-        aux_ctx_tensors = ctx.saved_tensors[10:]
+        (*saved_tensors,) = ctx.saved_tensors
+        q, k, v, out = saved_tensors[:4]
+        cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = saved_tensors[4:8]
+        fp8_fwd_scales, fp8_fwd_scale_invs = saved_tensors[8:10]
+        aux_ctx_tensors = saved_tensors[10:]
 
         qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
         causal = "causal" in ctx.attn_mask_type
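
Not part of the patch above: a minimal standalone sketch of the pattern all three hunks apply, namely reading ctx.saved_tensors once into a local list and taking every subsequent slice from that local instead of indexing the ctx.saved_tensors property repeatedly. The _ToyFunc class and its forward/backward math are hypothetical and exist only for illustration; torch.autograd.Function, ctx.save_for_backward, and ctx.saved_tensors are the real PyTorch APIs involved.

    # Sketch only (hypothetical _ToyFunc, not from this PR): unpack
    # ctx.saved_tensors once, then slice the resulting local list.
    import torch


    class _ToyFunc(torch.autograd.Function):
        @staticmethod
        def forward(ctx, a, b, c):
            ctx.save_for_backward(a, b, c)
            return a * b + c

        @staticmethod
        def backward(ctx, grad_out):
            # Single access to the saved_tensors property; saved_tensors is
            # now a plain local list, so the slices below do not re-read ctx.
            (*saved_tensors,) = ctx.saved_tensors
            a, b = saved_tensors[:2]
            # Gradients of a * b + c w.r.t. a, b, c.
            return grad_out * b, grad_out * a, grad_out


    x, y, z = (torch.randn(4, requires_grad=True) for _ in range(3))
    _ToyFunc.apply(x, y, z).sum().backward()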