Skip to content

Commit 2231cc8

Browse files
committed
TP communication overlap: enable the overlap between GEMM chunk at Hopper BF16
Signed-off-by: Sangkug Lym <[email protected]>
1 parent 095b27d commit 2231cc8

File tree

1 file changed

+4
-1
lines changed
  • transformer_engine/pytorch/cpp_extensions

1 file changed

+4
-1
lines changed

transformer_engine/pytorch/cpp_extensions/gemm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,10 +284,13 @@ def gemm(
284284
assert (
285285
extra_output_tensor is not None
286286
), "SPLIT_PIPELINED_RS requires extra output tensor"
287+
# Disable the overlap between GEMM chunks at ampere and below
288+
major, _ = torch.cuda.get_device_capability()
289+
overlap_gemm_chunks = True if major >= 9 else False
287290
args = tuple(
288291
args
289292
+ (
290-
False,
293+
overlap_gemm_chunks,
291294
extra_output_tensor,
292295
)
293296
)

0 commit comments

Comments
 (0)