1 parent 2231cc8 commit 4bb1812
transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -286,7 +286,7 @@ def gemm(
     ), "SPLIT_PIPELINED_RS requires extra output tensor"
     # Disable the overlap between GEMM chunks at ampere and below
     major, _ = torch.cuda.get_device_capability()
-    overlap_gemm_chunks = True if major >= 9 else False
+    overlap_gemm_chunks = major >= 9
     args = tuple(
         args
         + (
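For context, the simplification works because a comparison such as `major >= 9` already evaluates to a `bool`, so wrapping it in `True if ... else False` is redundant. A minimal sketch of the equivalence (the hard-coded `major` value is illustrative; in the commit it comes from `torch.cuda.get_device_capability()`):

```python
# Illustrative stand-in for torch.cuda.get_device_capability()[0],
# the compute-capability major version (e.g. 9 on Hopper, 8 on Ampere).
major = 8

verbose = True if major >= 9 else False  # pre-commit style
concise = major >= 9                     # post-commit style

# Both expressions produce the same bool for any value of major.
assert verbose == concise
assert isinstance(concise, bool)
```

On Ampere (major 8) both forms evaluate to False, so GEMM-chunk overlap stays disabled there, matching the comment in the diff; the behavior of the code is unchanged.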