Fix seq_dim in CP implementation (#1264)
fix seq_dim in CP implementation

Signed-off-by: Xiaowei Ren <[email protected]>
xrennvidia authored Oct 17, 2024
1 parent 12f30ea commit a488b8b
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/attention.py
@@ -2534,6 +2534,8 @@ def backward(ctx, dout):
 
         causal = "causal" in ctx.attn_mask_type
         padding = "padding" in ctx.attn_mask_type
+
+        seq_dim = None
         if ctx.qkv_format in ["bshd", "sbhd"]:
             seq_dim = ctx.qkv_format.index("s")
             qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:]
@@ -2580,7 +2582,6 @@ def backward(ctx, dout):
         fused_attn_qkv_dtype = None
         fused_attn_dqkv_dtype = None
         amax_per_step = None
-        seq_dim = None
         dout_fp8_dtype = None
         if ctx.fp8:
             if ctx.use_fused_attention:
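For context, a minimal sketch of the pattern this diff applies: the `seq_dim = None` default is hoisted above the `qkv_format` check, and the later re-initialization (the removed line among the FP8 setup) is dropped, so the value computed from `ctx.qkv_format.index("s")` is not clobbered and `seq_dim` is always defined for formats that skip the branch. The function and arguments below are hypothetical and illustrative only; this is not the Transformer Engine code itself.

from typing import Optional


def backward_sketch(qkv_format: str, fp8: bool) -> Optional[int]:
    # Initialize once, before any conditional assignment (mirrors the added
    # `seq_dim = None` above the qkv_format check).
    seq_dim = None
    if qkv_format in ("bshd", "sbhd"):
        seq_dim = qkv_format.index("s")  # 1 for "bshd", 0 for "sbhd"

    # ... FP8-related setup would go here; re-assigning `seq_dim = None` at
    # this point (the removed line) would overwrite the value computed above.
    if fp8:
        pass

    return seq_dim


assert backward_sketch("bshd", fp8=False) == 1
assert backward_sketch("sbhd", fp8=True) == 0
assert backward_sketch("thd", fp8=False) is None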
