Skip to content

Commit

Permalink
fixed logic to remove FSDP sharding
Browse files Browse the repository at this point in the history
Signed-off-by: Alp Dener <[email protected]>
  • Loading branch information
denera committed Nov 21, 2024
1 parent 26f74a5 commit 718c03d
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions transformer_engine/jax/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,25 +49,44 @@ def mirror_dim(dim, ndims):

def remove_fsdp_specs(pspecs):
fsdp_resource = global_mesh_resource().fsdp_resource
if fsdp_resource is None:
return list(pspecs).copy()

new_pspecs = []
for spec in pspecs:
if spec is None:
new_pspecs.append(None)
elif fsdp_resource not in spec:
new_pspecs.append(spec)

elif isinstance(spec, Iterable) and not isinstance(spec, str):
new_spec = []
for s in spec:
if s != fsdp_resource:
if s == fsdp_resource:
new_spec.append(None)
else:
new_spec.append(s)

if len(new_spec) > 1:
new_pspecs.append(new_spec)
elif len(new_spec) == 1:
new_pspecs.append(new_spec[0])
else:
new_pspecs.append(None)

elif isinstance(spec, str):
if spec == fsdp_resource:
new_pspecs.append(None)
else:
new_pspecs.append(spec)

else:
new_pspecs.append(None)
new_pspecs.append(spec)

assert len(new_pspecs) == len(pspecs), (
"Length of partition specs changed when removing FSDP sharding!\n"
+ f"Original: {pspecs}\n"
+ f"Filtered: {new_pspecs}\n"
)

return new_pspecs


Expand Down

0 comments on commit 718c03d

Please sign in to comment.