@@ -61,11 +61,6 @@ class FlatParamHandle:
     the same shape as the sharded version of ``params_data``.
     """
 
-    params_sharded_grad_tmp: Optional[torch.Tensor] = None
-    """
-    Temporary storage for the local consolidated sharded grads during the reduce-scatter.
-    """
-
     process_group: Optional[dist.ProcessGroup] = None
 
     device: Optional[torch.device] = None
@@ -254,7 +249,8 @@ def pre_unshard_(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = F
 
         # Cast sharded ``params_data`` to ``dtype``.
         if dtype is not None:
-            self.params_sharded_data_lp = self.params_data.sharded_data.to(dtype)
+            self.params_sharded_data_lp = self.params_data.sharded_chunk(all_params_unsharded_data)
+            self.params_sharded_data_lp.copy_(self.params_data.sharded_data)
 
         # Initialize unsharded, padded gradient.
         if set_grads and self.params_unsharded_grad is None:
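
The new version of this hunk avoids allocating a separate low-precision tensor with ``.to(dtype)``: it takes this rank's chunk of the already-allocated unsharded buffer (``all_params_unsharded_data``) and casts into it with ``copy_()``. A minimal sketch of that pattern in plain PyTorch, with hypothetical sizes and rank standing in for the handle's real fields, looks like this:

    import torch

    world_size, shard_numel = 4, 1024   # hypothetical sizes
    rank = 1                            # hypothetical rank

    # Full-precision shard owned by this rank (stands in for ``params_data.sharded_data``).
    sharded_data_fp32 = torch.randn(shard_numel, dtype=torch.float32)

    # Unsharded low-precision buffer that the all-gather will fill
    # (stands in for ``all_params_unsharded_data``).
    all_params_unsharded = torch.empty(world_size * shard_numel, dtype=torch.bfloat16)

    # View this rank's chunk of the buffer and cast the fp32 shard into it in place;
    # ``copy_`` performs the dtype conversion, so no extra low-precision tensor is created.
    sharded_chunk_lp = all_params_unsharded.chunk(world_size)[rank]
    sharded_chunk_lp.copy_(sharded_data_fp32)
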
@@ -364,10 +360,6 @@ def pre_reduce_scatter_grads_(
             Stream.current(self.device).record_for(self.params_unsharded_grad)
             self.params_unsharded_grad = self.params_unsharded_grad.to(dtype=grad_reduce_dtype)
 
-        self.params_sharded_grad_tmp = torch.empty(
-            self.params_data.sharded_shape, dtype=self.params_unsharded_grad.dtype, device=self.device
-        )
-
     def reduce_scatter_grads_(
         self, grad_dtype: Optional[torch.dtype] = None, grad_reduce_dtype: Optional[torch.dtype] = None
     ):
@@ -380,11 +372,8 @@ def reduce_scatter_grads_(
 
         if not self._ran_pre_reduce_scatter_grads:
             self.pre_reduce_scatter_grads_(grad_dtype=grad_dtype, grad_reduce_dtype=grad_reduce_dtype)
-            assert self.params_sharded_grad_tmp is not None
         else:
-            assert self.params_sharded_grad_tmp is not None
             Stream.current(self.device).record_for(self.params_unsharded_grad)
-            Stream.current(self.device).record_for(self.params_sharded_grad_tmp)
 
         self._ran_pre_reduce_scatter_grads = False
 
@@ -397,23 +386,40 @@ def reduce_scatter_grads_(
         if dist.get_backend() == dist.Backend.NCCL:
             # Get chunks corresponding to each rank.
             grad_chunks = self.params_data.chunk_unsharded(self.params_unsharded_grad)
-            dist.reduce_scatter(self.params_sharded_grad_tmp, grad_chunks, group=self.process_group)
+            dist.reduce_scatter(
+                grad_chunks[get_rank(group=self.process_group)], grad_chunks, group=self.process_group
+            )
         else:
             dist.all_reduce(self.params_unsharded_grad, group=self.process_group)
-            self.params_sharded_grad_tmp.copy_(self.params_data.sharded_chunk(self.params_unsharded_grad))
 
-        # Deallocate the unsharded padded grad.
-        # NOTE: Since we're potentially using a separate stream for this reduce-scatter, we need to make
-        # sure `params_unsharded_grad` is not deallocated before the reduce-scatter finishes.
-        Stream.current(self.device).record_for(self.params_unsharded_grad)
-        self.params_unsharded_grad = None
+    def post_reduce_scatter_grads_(
+        self, grad_dtype: Optional[torch.dtype] = None, grad_reduce_dtype: Optional[torch.dtype] = None
+    ):
+        """
+        Finalize sharded gradients after the reduce-scatter.
+        """
+        grad_dtype = grad_dtype or self.params_data.dtype
+        grad_reduce_dtype = grad_reduce_dtype or grad_dtype
+
+        assert self.params_unsharded_grad is not None
+        new_sharded_grad = self.params_data.sharded_chunk(self.params_unsharded_grad)
 
-        # Cast the reduce-scatter target to the right dtype, potentially accumulating it into
-        # the existing gradient.
+        # Cast the new sharded gradient to the right dtype, potentially accumulating it into
+        # the existing sharded gradient.
         if self.params_sharded_grad is None:
-            self.params_sharded_grad = self.params_sharded_grad_tmp.to(grad_dtype)
+            if new_sharded_grad.dtype == grad_dtype:
+                self.params_sharded_grad = new_sharded_grad.clone()
+            else:
+                self.params_sharded_grad = new_sharded_grad.to(grad_dtype)
         else:
-            self.params_sharded_grad.add_(self.params_sharded_grad_tmp)
+            self.params_sharded_grad.add_(new_sharded_grad)
+
+        # Deallocate the unsharded padded grad.
+        # NOTE: Since we're potentially using a separate stream here, we need to make
+        # sure `params_unsharded_grad` is not deallocated before this finishes.
+        Stream.current(self.device).record_for(self.params_unsharded_grad)
+        self.params_unsharded_grad = None
+        del new_sharded_grad
 
         # At this point each param will be sharded again, and we set the grad for each param as a view
         # into the sharded grad.
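
In the NCCL branch above, the output tensor passed to ``dist.reduce_scatter`` is simply this rank's own entry in the input list (``grad_chunks[get_rank(...)]``), so the reduced shard lands directly in the unsharded gradient buffer and the old ``params_sharded_grad_tmp`` staging tensor is no longer needed. A standalone sketch of that pattern, assuming an NCCL process group is already initialized and ``flat_grad`` is a padded unsharded gradient whose length is a multiple of the world size, might look like:

    import torch
    import torch.distributed as dist

    def reduce_scatter_into_own_chunk(flat_grad: torch.Tensor) -> torch.Tensor:
        """Reduce-scatter ``flat_grad`` so each rank ends up holding its reduced chunk.

        Hypothetical helper: assumes ``flat_grad`` is padded to a multiple of the
        world size and that an NCCL process group has been initialized.
        """
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        # Split the flat gradient into one equal chunk per rank (these are views, not copies).
        grad_chunks = list(flat_grad.chunk(world_size))
        # Use this rank's own chunk as the output buffer, so the reduced result is
        # written straight into the unsharded buffer with no temporary tensor.
        dist.reduce_scatter(grad_chunks[rank], grad_chunks)
        return grad_chunks[rank]

With the staging tensor gone, the dtype cast, accumulation, and deallocation work moves into the new ``post_reduce_scatter_grads_`` phase, so a caller can launch the collective from ``reduce_scatter_grads_`` on a separate stream and finalize the sharded gradient later.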