Skip to content

Commit 8fe3942

Browse files
committed
added missing copy of AG+GEMM input into comm buffer
Signed-off-by: Alp Dener <[email protected]>
1 parent ec2d5ae commit 8fe3942

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

examples/jax/comm_gemm_overlap/comm_gemm_overlap.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from jax.experimental import mesh_utils
1515

1616
import transformer_engine.jax as te
17-
from transformer_engine.jax.cpp_extensions import gemm_impl
17+
from transformer_engine.jax.cpp_extensions import gemm_impl, copy_into_overlap_buffer
1818
from transformer_engine.jax.gemm import (
1919
initialize_comm_gemm_overlaps,
2020
destroy_comm_gemm_overlaps,
@@ -124,14 +124,15 @@
124124
if myrank == 0:
125125
print(
126126
f"{myrank}: INPUTS {lhs.shape} x {rhs.shape}\n"
127-
+ f"{myrank}: LHS sharding: {lhs.sharding}\n"
128-
+ f"{myrank}: RHS sharding: {rhs.sharding}\n",
127+
+ f"{myrank}: LHS sharding: {lhs.sharding.spec}\n"
128+
+ f"{myrank}: RHS sharding: {rhs.sharding.spec}\n",
129129
flush=True,
130130
)
131131

132132

133133
@jax.jit
134134
def te_gemm(A, B):
135+
copy_into_overlap_buffer(A, overlap_name, True)
135136
return gemm_impl(
136137
A,
137138
jax.lax.with_sharding_constraint(B, weight_no_fsdp_sharding),
@@ -145,10 +146,9 @@ def te_gemm(A, B):
145146

146147
if myrank == 0:
147148
print(
148-
f"{myrank}: {'AG -> GEMM' if args.comm_type == 'AG' else 'GEMM -> RS'} OUTPUTS:\n"
149-
+ f"{myrank}: GEMM output: {output.shape} | {output.sharding}\n"
150-
+ f"{myrank}: {'Gathered LHS' if args.comm_type == 'AG' else 'Scattered output:'}: "
151-
+ f"{extra_out.shape} | {extra_out.sharding}\n",
149+
f"{myrank}: {'AG -> GEMM' if args.comm_type == 'AG' else 'GEMM -> RS'} OUTPUT "
150+
+ f"{output.shape}\n"
151+
+ f"{myrank}: Sharding: {output.sharding.spec}\n",
152152
flush=True,
153153
)
154154

0 commit comments

Comments
 (0)