14
14
from jax .experimental import mesh_utils
15
15
16
16
import transformer_engine .jax as te
17
- from transformer_engine .jax .cpp_extensions import gemm_impl
17
+ from transformer_engine .jax .cpp_extensions import gemm_impl , copy_into_overlap_buffer
18
18
from transformer_engine .jax .gemm import (
19
19
initialize_comm_gemm_overlaps ,
20
20
destroy_comm_gemm_overlaps ,
124
124
if myrank == 0 :
125
125
print (
126
126
f"{ myrank } : INPUTS { lhs .shape } x { rhs .shape } \n "
127
- + f"{ myrank } : LHS sharding: { lhs .sharding } \n "
128
- + f"{ myrank } : RHS sharding: { rhs .sharding } \n " ,
127
+ + f"{ myrank } : LHS sharding: { lhs .sharding . spec } \n "
128
+ + f"{ myrank } : RHS sharding: { rhs .sharding . spec } \n " ,
129
129
flush = True ,
130
130
)
131
131
132
132
133
133
@jax .jit
134
134
def te_gemm (A , B ):
135
+ copy_into_overlap_buffer (A , overlap_name , True )
135
136
return gemm_impl (
136
137
A ,
137
138
jax .lax .with_sharding_constraint (B , weight_no_fsdp_sharding ),
@@ -145,10 +146,9 @@ def te_gemm(A, B):
145
146
146
147
if myrank == 0 :
147
148
print (
148
- f"{ myrank } : { 'AG -> GEMM' if args .comm_type == 'AG' else 'GEMM -> RS' } OUTPUTS:\n "
149
- + f"{ myrank } : GEMM output: { output .shape } | { output .sharding } \n "
150
- + f"{ myrank } : { 'Gathered LHS' if args .comm_type == 'AG' else 'Scattered output:' } : "
151
- + f"{ extra_out .shape } | { extra_out .sharding } \n " ,
149
+ f"{ myrank } : { 'AG -> GEMM' if args .comm_type == 'AG' else 'GEMM -> RS' } OUTPUT "
150
+ + f"{ output .shape } \n "
151
+ + f"{ myrank } : Sharding: { output .sharding .spec } \n " ,
152
152
flush = True ,
153
153
)
154
154
0 commit comments