diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 62ffec2cd6..5f8357a01b 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -2534,6 +2534,8 @@ def backward(ctx, dout):
 
         causal = "causal" in ctx.attn_mask_type
         padding = "padding" in ctx.attn_mask_type
+
+        seq_dim = None
         if ctx.qkv_format in ["bshd", "sbhd"]:
             seq_dim = ctx.qkv_format.index("s")
             qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:]
@@ -2580,7 +2582,6 @@ def backward(ctx, dout):
         fused_attn_qkv_dtype = None
         fused_attn_dqkv_dtype = None
         amax_per_step = None
-        seq_dim = None
         dout_fp8_dtype = None
         if ctx.fp8:
             if ctx.use_fused_attention:
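
Note on the change (the diff itself carries no description): judging from the hunk positions (2534 vs. 2580), the old `seq_dim = None` default sat after the `if ctx.qkv_format in ["bshd", "sbhd"]` branch that computes `seq_dim`. If that default executed on the same path, it would silently reset the computed value; and for formats that skip the branch (e.g. `"thd"`), any read of `seq_dim` before the default could raise `UnboundLocalError`. Hoisting the single initialization above the branch binds the name exactly once on every path. A minimal sketch of this scoping pattern, with simplified, hypothetical control flow (not the actual `backward()` logic):

```python
# Hypothetical sketch of the scoping fix above; names and control flow are
# simplified assumptions, not the actual TE backward() code.

def old_backward(qkv_format):
    if qkv_format in ["bshd", "sbhd"]:
        seq_dim = qkv_format.index("s")  # 0 for "sbhd", 1 for "bshd"
    # ... intervening work ...
    seq_dim = None  # late default: discards the value computed above
    return seq_dim

def new_backward(qkv_format):
    seq_dim = None  # bound once, up front, on every path
    if qkv_format in ["bshd", "sbhd"]:
        seq_dim = qkv_format.index("s")
    # ... intervening work; seq_dim is None only for formats like "thd" ...
    return seq_dim

assert old_backward("bshd") is None  # old placement: computed value lost
assert new_backward("bshd") == 1     # new placement: value preserved
assert new_backward("thd") is None   # non-"s" formats fall back to the default
```

With the new placement, downstream consumers of `seq_dim` only need to handle the `None` case for formats such as `"thd"`, rather than guarding against an unbound or clobbered local.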