
Commit 71f696b

make sure Fp8 weight buffers are sharded at the end of the backward pass and gathered before forward
Signed-off-by: Alp Dener <[email protected]>
1 parent: 23c0cd5

3 files changed: +69, -5 lines changed


transformer_engine/pytorch/module/layernorm_linear.py

Lines changed: 19 additions & 1 deletion
@@ -172,6 +172,9 @@ def forward(
             weight_fp8 = weight
             weight_t_fp8 = None
         elif update_fp8_weights:
+            # Gather Fp8 weight buffers if needed
+            if fsdp_group is not None and weight_fp8._data.shape != weight.data.shape:
+                _fsdp_gather_tensors(fsdp_group, [weight.data.shape], weight_fp8)
             # Need to cast weights to FP8
             weight_fp8 = Float8Tensor(
                 data=weight_fp8._data,
@@ -181,6 +184,12 @@
             if (is_grad_enabled
                 or (is_fp8_activation_recompute_enabled()
                     and not in_fp8_activation_recompute_phase())):
+                # Gather Fp8 transposed-weight buffers if needed
+                if (fsdp_group is not None
+                    and weight_t_fp8._data.shape != reversed(weight.data.shape)):
+                    _fsdp_gather_tensors(fsdp_group,
+                                         [tuple(reversed(weight.data.shape))],
+                                         weight_t_fp8)
                 tex.fp8_cast_transpose_fused(
                     weight,
                     fp8_meta["scaling_fwd"],
@@ -261,12 +270,12 @@ def forward(
                 rsigma.activation_offloading = True
                 ln_out.activation_offloading = True

+        # Scatter intermediate/activation tensors saved for the backward pass
         ctx.fsdp_group = fsdp_group
         ctx.fsdp_shapes = _fsdp_scatter_tensors(
             fsdp_group,
             mu,
             rsigma,
-            weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
             weight_t_fp8,
             ln_out
         )
@@ -314,6 +323,9 @@ def forward(
         # [*, in_features] -> [*, out_features] except first dimension changes for SP
         out = out.view(-1, *inp.shape[1:-1], out.shape[-1])

+        # Scatter Fp8 weight buffers
+        _fsdp_scatter_tensors(fsdp_group, weight_fp8, weight_fp8)
+
         if return_layernorm_output:
             return out, ln_out_return.view_as(inp)
         return out
@@ -338,6 +350,7 @@ def backward(
             fwd_scale_inverses,
         ) = ctx.saved_tensors

+        # Gather intermediate/activation tensors if needed
        _fsdp_gather_tensors(
            ctx.fsdp_group,
            ctx.fsdp_shapes,
@@ -575,6 +588,8 @@ def backward(
                 ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
             )
             dbeta = None
+            clear_tensor_data(mu)
+            clear_tensor_data(rsigma)

         if not ctx.use_bias:
             grad_bias = None
@@ -600,6 +615,9 @@
         else:
             wgrad = None

+        # Scatter fp8 transposed-weight buffers
+        _fsdp_scatter_tensors(ctx.fsdp_group, weight_t_fp8)
+
         return (
             dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
             dgamma,
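The backward pass above also drops mu and rsigma as soon as the layernorm weight/bias gradients are computed, so their memory is released before backward returns. A minimal sketch of that idea, with a hypothetical release_tensor stand-in (clear_tensor_data's actual implementation is not shown in this diff) and assumed shapes:

import torch


def release_tensor(t: torch.Tensor) -> None:
    # Hypothetical stand-in: point the tensor at an empty buffer so its
    # original storage can be freed early.
    t.data = torch.empty(0, dtype=t.dtype, device=t.device)


mu = torch.randn(4096)      # saved layernorm mean (assumed shape)
rsigma = torch.randn(4096)  # saved layernorm inverse std (assumed shape)

# ... dgamma / dbeta would be computed from mu and rsigma here ...

release_tensor(mu)
release_tensor(rsigma)
assert mu.numel() == 0 and rsigma.numel() == 0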

transformer_engine/pytorch/module/layernorm_mlp.py

Lines changed: 32 additions & 0 deletions
@@ -216,6 +216,17 @@ def forward(
             fc1_weight_t_fp8 = None
             fc2_weight_t_fp8 = None
         elif update_fp8_weights:
+            # Gather Fp8 weight buffers if needed
+            if fsdp_group is not None:
+                weights_to_gather = []
+                gather_shapes = []
+                if fc1_weight_fp8._data.shape != fc1_weight.data.shape:
+                    weights_to_gather.append(fc1_weight_fp8)
+                    gather_shapes.append(fc1_weight.data.shape)
+                if fc2_weight_fp8._data.shape != fc2_weight.data.shape:
+                    weights_to_gather.append(fc2_weight_fp8)
+                    gather_shapes.append(fc2_weight.data.shape)
+                _fsdp_gather_tensors(fsdp_group, gather_shapes, weights_to_gather)
             # Need to cast weights to FP8
             fc1_weight_fp8 = Float8Tensor(
                 data=fc1_weight_fp8._data,
@@ -230,6 +241,17 @@
             if (is_grad_enabled
                 or (is_fp8_activation_recompute_enabled()
                     and not in_fp8_activation_recompute_phase())):
+                # Gather Fp8 transposed-weight buffers if needed
+                if fsdp_group is not None:
+                    weights_to_gather = []
+                    gather_shapes = []
+                    if fc1_weight_t_fp8._data.shape != reversed(fc1_weight.data.shape):
+                        weights_to_gather.append(fc1_weight_t_fp8)
+                        gather_shapes.append(tuple(reversed(fc1_weight.data.shape)))
+                    if fc2_weight_t_fp8._data.shape != reversed(fc2_weight.data.shape):
+                        weights_to_gather.append(fc2_weight_t_fp8)
+                        gather_shapes.append(tuple(reversed(fc2_weight.data.shape)))
+                    _fsdp_gather_tensors(fsdp_group, gather_shapes, weights_to_gather)
                 # Fused cast-transpose kernels
                 tex.fp8_cast_transpose_fused(
                     fc1_weight,
@@ -473,6 +495,10 @@
                 fc1_out.activation_offloading = True
                 gelu_out.activation_offloading = True

+        # Scatter Fp8 weight buffers
+        _fsdp_scatter_tensors(fsdp_group, fc1_weight_fp8, fc2_weight_fp8)
+
+        # Scatter intermediate/activation tensors saved for the backward pass
         ctx.fsdp_group = fsdp_group
         ctx.fsdp_shapes = _fsdp_scatter_tensors(
             fsdp_group,
@@ -1000,6 +1026,8 @@ def backward(
                 ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
             )
             dbeta = None
+            clear_tensor_data(mu)
+            clear_tensor_data(rsigma)

         if fc1_weight.requires_grad:
             # Handle custom DDP from mcore.
@@ -1043,6 +1071,10 @@
         else:
             fc2_wgrad = None

+        # Scatter Fp8 transposed-weight buffers
+        _fsdp_scatter_tensors(ctx.fsdp_group, fc1_weight_t_fp8)
+        _fsdp_scatter_tensors(ctx.fsdp_group, fc2_weight_t_fp8)
+
         return (
             dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
             dgamma,
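The fc1/fc2 hunks batch the work: only the buffers whose local data no longer matches the full weight shape are still sharded, and those are collected into parallel shape/tensor lists so a single gather call can handle both. A minimal sketch of that selection logic, with assumed shapes and a comment standing in for the actual _fsdp_gather_tensors call:

import torch


def needs_gather(fp8_data: torch.Tensor, full_shape) -> bool:
    # A sharded buffer is a flat slice, so its shape differs from the weight's.
    return fp8_data.shape != torch.Size(full_shape)


fc1_shape, fc2_shape = (4096, 1024), (1024, 4096)  # assumed weight shapes
fc1_fp8_data = torch.zeros(fc1_shape[0] * fc1_shape[1] // 4, dtype=torch.uint8)  # still sharded
fc2_fp8_data = torch.zeros(fc2_shape, dtype=torch.uint8)                         # already full-size

weights_to_gather, gather_shapes = [], []
for data, shape in ((fc1_fp8_data, fc1_shape), (fc2_fp8_data, fc2_shape)):
    if needs_gather(data, shape):
        weights_to_gather.append(data)
        gather_shapes.append(shape)

# One collective for however many buffers actually need it; the diff passes
# these lists to _fsdp_gather_tensors(fsdp_group, gather_shapes, weights_to_gather).
assert gather_shapes == [fc1_shape] and len(weights_to_gather) == 1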

transformer_engine/pytorch/module/linear.py

Lines changed: 18 additions & 4 deletions
@@ -154,6 +154,9 @@ def forward(
             weight_fp8 = weight
             weight_t_fp8 = None
         elif update_fp8_weights:
+            # Gather Fp8 weight buffers if needed
+            if fsdp_group is not None and weight_fp8._data.shape != weight.data.shape:
+                _fsdp_gather_tensors(fsdp_group, [weight.data.shape], weight_fp8)
             # Need to cast weights to FP8
             weight_fp8 = Float8Tensor(
                 data=weight_fp8._data,
@@ -163,6 +166,12 @@
             if (is_grad_enabled
                 or (is_fp8_activation_recompute_enabled()
                     and not in_fp8_activation_recompute_phase())):
+                # Gather Fp8 transposed-weight buffers if needed
+                if (fsdp_group is not None
+                    and weight_t_fp8._data.shape != reversed(weight.data.shape)):
+                    _fsdp_gather_tensors(fsdp_group,
+                                         [tuple(reversed(weight.data.shape))],
+                                         weight_t_fp8)
                 fp8_cast_transpose_fused(
                     weight,
                     fp8_meta["scaling_fwd"],
@@ -290,13 +299,12 @@
         if saved_inputmat is not None:
             saved_inputmat.activation_offloading = True

-        fwd_scale_inverses = fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None
+        # Scatter intermediate/activation tensors saved for the backward pass
         ctx.fsdp_group = fsdp_group
         ctx.fsdp_shapes = _fsdp_scatter_tensors(
             fsdp_group,
             saved_inputmat,    # None if fp8 == False
             saved_inputmat_t,  # None if fp8 == False AND not is_grad_enabled
-            weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
             weight_t_fp8 if fp8 else None,
         )

@@ -306,7 +314,7 @@
             weight,
             weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
             weight_t_fp8 if fp8 else None,
-            fwd_scale_inverses
+            fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None
         )

         ctx.activation_dtype = activation_dtype
@@ -335,6 +343,9 @@
         elif parallel_mode == "row" and tensor_parallel:
             out, _ = allreduce(out, tp_group)

+        # Scatter Fp8 weight buffers
+        _fsdp_scatter_tensors(fsdp_group, weight_fp8)
+
         # [*, in_features] -> [*, out_features] except first dimension changes for SP
         return out.view(-1, *inp.shape[1:-1], out.shape[-1])

@@ -355,11 +366,11 @@ def backward(
             fwd_scale_inverses,
         ) = ctx.saved_tensors

+        # Gather intermediate/activation tensors if needed
        _fsdp_gather_tensors(ctx.fsdp_group,
                             ctx.fsdp_shapes,
                             inputmat,
                             inputmat_t,
-                            main_grad,
                             weight_t_fp8)

        if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
@@ -539,6 +550,9 @@
         else:
             wgrad = None

+        # Scatter fp8 transposed-weight buffers
+        _fsdp_scatter_tensors(ctx.fsdp_group, weight_t_fp8)
+
         return (
             wgrad,
             None,
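Both linear.py and layernorm_linear.py build the target shape for the transposed-weight gather by reversing the weight shape. A small illustration of why the diff materializes it with tuple(...): reversed() on a torch.Size yields an iterator, not a shape, so it has to be converted before it can be used or compared as a concrete shape. The shapes below are assumptions.

import torch

weight_shape = torch.Size([1024, 4096])  # assumed [out_features, in_features]

weight_t_shape = tuple(reversed(weight_shape))
print(weight_t_shape)                                          # (4096, 1024)
print(torch.Size(weight_t_shape) == torch.Size([4096, 1024]))  # True

# The raw iterator is not itself a shape; shape comparisons only match after
# it has been materialized into a tuple.
print(torch.Size([4096, 1024]) == reversed(weight_shape))      # False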
