Commit 617e1de
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

Signed-off-by: eljandoubi <[email protected]>
pre-commit-ci[bot] authored and eljandoubi committed Oct 16, 2024
1 parent 5e6cf35 commit 617e1de
Showing 4 changed files with 20 additions and 9 deletions.

tests/pytorch/fused_attn/run_fused_attn_with_cp.py (2 additions, 1 deletion)
@@ -222,7 +222,8 @@ def run_dpa_with_cp(
         seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q_.device)
         q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
         q_, k_, v_, dout_ = [
-            x.reshape(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) for x in [q_, k_, v_, dout_]
+            x.reshape(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :])
+            for x in [q_, k_, v_, dout_]
         ]
     elif qkv_format == "thd":
         seq_idx_q = tex.thd_get_partitioned_indices(
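
(Not part of the commit, just for orientation: the reshape being re-wrapped above merges the two
chunk dimensions back into a single sequence axis. A minimal PyTorch sketch, with made-up shapes
and a bshd layout assumed.)

    import torch

    seq_dim = 1                         # sequence axis for a bshd layout (assumed)
    x = torch.randn(4, 2, 8, 16, 64)    # e.g. [b, 2, s//2, np, hn] after index_select

    # keep dims before seq_dim, merge dims seq_dim and seq_dim + 1, keep the rest
    merged = x.reshape(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :])
    assert merged.shape == (4, 16, 16, 64)  # [b, s, np, hn]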

tests/pytorch/fused_attn/test_fused_attn.py (3 additions, 1 deletion)
@@ -1845,7 +1845,9 @@ def _run_ref_mha_f16(dtype, config, backend):
     cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
     cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
     out_grad = (
-        torch.load("out_grad.pt").to(device="cuda").reshape(config.batch_size, config.max_seqlen_q, -1)
+        torch.load("out_grad.pt")
+        .to(device="cuda")
+        .reshape(config.batch_size, config.max_seqlen_q, -1)
     )

 _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()

tests/pytorch/test_numerics.py (3 additions, 1 deletion)
@@ -233,7 +233,9 @@ def forward(
         value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)

         # change view [b * np, sq, sk]
-        attention_probs = attention_probs.reshape(output_size[0] * output_size[1], output_size[2], -1)
+        attention_probs = attention_probs.reshape(
+            output_size[0] * output_size[1], output_size[2], -1
+        )

         # matmul: [b * np, sq, hn]
         context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
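
(Not from the diff, just for orientation: the reshape and bmm in this hunk are the standard
attention-weighted sum over values. A small sketch with made-up sizes standing in for b, np, sq,
sk, hn.)

    import torch

    b, np_, sq, sk, hn = 2, 4, 8, 8, 16
    attention_probs = torch.softmax(torch.randn(b, np_, sq, sk), dim=-1)

    # change view [b * np, sq, sk], as in the hunk
    attention_probs = attention_probs.reshape(b * np_, sq, -1)
    value_layer = torch.randn(sk, b * np_, hn)  # [sk, b*np, hn]

    # [b*np, sq, sk] @ [b*np, sk, hn] -> [b*np, sq, hn]
    context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
    assert context_layer.shape == (b * np_, sq, hn)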

transformer_engine/pytorch/attention.py (12 additions, 6 deletions)
@@ -1600,7 +1600,9 @@ def flash_attn_a2a_communicate(
                 )
                 # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn]
                 # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn]
-                a2a_outputs[i - 2] = x.reshape(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :])
+                a2a_outputs[i - 2] = x.reshape(
+                    *x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]
+                )
             if i < len(a2a_inputs):
                 x = a2a_inputs[i]
                 # [b, s, np, hn] -> [b, s, cp, np//cp, hn]

@@ -1786,7 +1788,9 @@ def forward(
         if causal:
             if qkv_format == "bshd":
                 # [b, s, np, hn] -> [b, 2, s//2, np, hn]
-                q, k, v = [x.reshape(x.shape[0], 2, x.shape[1] // 2, *x.shape[2:]) for x in [q, k, v]]
+                q, k, v = [
+                    x.reshape(x.shape[0], 2, x.shape[1] // 2, *x.shape[2:]) for x in [q, k, v]
+                ]
             elif qkv_format == "sbhd":
                 # [s, b, np, hn] -> [2, s//2, b, np, hn]
                 q, k, v = [x.reshape(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]]
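
(For context, not from the commit: the reshape in this hunk splits the sequence axis of a bshd
tensor into two half-sequence chunks, matching the shape comment. A minimal sketch with arbitrary
sizes.)

    import torch

    q = torch.randn(2, 16, 4, 64)  # [b, s, np, hn]

    # [b, s, np, hn] -> [b, 2, s//2, np, hn]
    q_split = q.reshape(q.shape[0], 2, q.shape[1] // 2, *q.shape[2:])
    assert q_split.shape == (2, 2, 8, 4, 64)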

@@ -4914,7 +4918,9 @@ def forward(
         value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)

         # change view [b * np, sq, sk]
-        attention_probs = attention_probs.reshape(output_size[0] * output_size[1], output_size[2], -1)
+        attention_probs = attention_probs.reshape(
+            output_size[0] * output_size[1], output_size[2], -1
+        )

         # matmul: [b * np, sq, hn]
         context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

@@ -5985,9 +5991,9 @@ def forward(
                     "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, "
                     f"but found {qkv_layout}."
                 )
-                q_fp8 = cast_to_fp8(q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward).reshape(
-                    q.shape
-                )
+                q_fp8 = cast_to_fp8(
+                    q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
+                ).reshape(q.shape)
                 kv_c = kv.reshape(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                 kv_fp8 = cast_to_fp8(
                     kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward
