 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.autograd import Variable
 
 from olmo_core.distributed.tensors import ShardedFlatParameter
 from olmo_core.stream import Stream
@@ -322,7 +323,7 @@ def clip_grad_norm_(self, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
         nonsharded_params: Set[nn.Parameter] = set()
         grads: List[torch.Tensor] = []
         for param in self.parameters():
-            if param.grad is None:
+            if param.grad is None or param.grad.numel() == 0:
                 continue
 
             if isinstance(param, ShardedFlatParameter):
@@ -394,7 +395,11 @@ def _lazy_init(self):
             self.state.forward_execution_order.append(self)
             return
 
-        log.debug("Completing lazy initialization from root FSDP for %s...", self.module.__class__.__name__)
+        log.debug(
+            "Completing lazy initialization from root FSDP for %s (%s)...",
+            self.module.__class__.__name__,
+            id(self.module),
+        )
 
         # Initialize streams.
         self.state.compute_stream = Stream.default(self.device)
@@ -494,7 +499,7 @@ def _shard(self):
 
         This should only be called once at initialization.
         """
-        log.debug("Sharding %s...", self.module.__class__.__name__)
+        log.debug("Sharding %s (%s)...", self.module.__class__.__name__, id(self.module))
 
         params_with_grads: List[nn.Parameter] = []
         params_with_grads_fqns: List[str] = []
@@ -568,7 +573,7 @@ def _unshard(
 
         kwargs = dict(cast=cast, set_grads=set_grads, recurse=recurse, rank0_only=rank0_only)
 
-        log.debug("Unsharding %s...", self.module.__class__.__name__)
+        log.debug("Unsharding %s (%s)...", self.module.__class__.__name__, id(self.module))
         self.state.params_prefetched = True
 
         # NOTE: `unshard_stream` should wait on current stream (usually `compute_stream` / `default_stream`)
@@ -600,7 +605,11 @@ def _unshard(
     def _prefetch(self, prefetch_from: deque[FSDP], **kwargs):
         for module in self._deque_from(prefetch_from):
             log.debug(
-                "Prefetching %s from %s...", module.module.__class__.__name__, self.module.__class__.__name__
+                "Prefetching %s (%s) from %s (%s)...",
+                module.module.__class__.__name__,
+                id(module.module),
+                self.module.__class__.__name__,
+                id(self.module),
             )
             module._unshard(**kwargs)
 
@@ -611,7 +620,7 @@ def _reshard(self, writeback: bool = False, recurse: bool = False):
         """
         kwargs = dict(writeback=writeback, recurse=recurse)
 
-        log.debug("Resharding %s...", self.module.__class__.__name__)
+        log.debug("Resharding %s (%s)...", self.module.__class__.__name__, id(self.module))
         self.state.params_prefetched = False
 
         for handle in self.state.flat_param_handles:
@@ -637,7 +646,7 @@ def _reduce_scatter_grads(self):
 
         grad_reduce_dtype: Optional[torch.dtype] = self.precision.reduce_dtype or self.precision.param_dtype
         with self.state.reduce_stream(wait_stream=self.state.current_stream):
-            log.debug("Reduce-scattering grads for %s", self.module.__class__.__name__)
+            log.debug("Reduce-scattering grads for %s (%s)", self.module.__class__.__name__, id(self.module))
             for handle in self.state.flat_param_handles:
                 handle.reduce_scatter_grads_(grad_reduce_dtype=grad_reduce_dtype)
 
@@ -659,13 +668,16 @@ def _deque_from(self, prefetch_queue: deque[FSDP]) -> Generator[FSDP, None, None]:
     @torch.no_grad()
     def _pre_backward_hook(self, *unused: Any):
         del unused
-        log.debug("Running pre-backward hook for %s...", self.module.__class__.__name__)
+        log.debug("Running pre-backward hook for %s (%s)...", self.module.__class__.__name__, id(self.module))
 
         # Remove all pre backward hooks for this FSDP instance since they all do the same thing.
         for handle in self.state.pre_backward_hook_handles:
             handle.remove()
         self.state.pre_backward_hook_handles.clear()
 
+        if self.is_root:
+            self._register_post_backward_final_hook()
+
         # Unshard parameters in place.
         self._unshard(set_grads=True)
 
@@ -684,10 +696,12 @@ def _register_pre_backward_hook(self, x: torch.Tensor):
         self.state.pre_backward_hook_handles.append(handle)
 
     def _register_pre_backward_hooks(self, output: Any):
-        log.debug("Registering pre-backward hooks for %s...", self.module.__class__.__name__)
+        log.debug("Registering pre-backward hooks for %s (%s)...", self.module.__class__.__name__, id(self.module))
         # Clear existing hooks if there are any.
         if self.state.pre_backward_hook_handles:
-            log.debug("Removing old pre-backward hooks for %s...", self.module.__class__.__name__)
+            log.debug(
+                "Removing old pre-backward hooks for %s (%s)...", self.module.__class__.__name__, id(self.module)
+            )
             for handle in self.state.pre_backward_hook_handles:
                 handle.remove()
             self.state.pre_backward_hook_handles.clear()
@@ -699,29 +713,19 @@ def _register_pre_backward_hooks(self, output: Any):
     @torch.no_grad()
     def _post_backward_hook(self, param_name: str, *unused: Any):
         del unused
-        log.debug("Running post-backward hook for %s.%s...", self.module.__class__.__name__, param_name)
         self.state.post_backward_hook_handles.pop(param_name).remove()
 
         # If there are still more handles then there are still more post-backward hooks to be ran
         # in the current FSDP node. Only the last handle should do the work.
         if self.state.post_backward_hook_handles:
             return
 
+        log.debug("Running post-backward hook for %s (%s)", self.module.__class__.__name__, id(self.module))
+
         # NOTE: reshard *before* reducing grads to correctly handle precision settings.
         self._reshard()
         self._reduce_scatter_grads()
 
-        # The root FSDP instance needs to do some final cleanup.
-        if not self.is_root:
-            return
-
-        # Mark backward execution order as finalized.
-        self.state.backward_execution_order_finalized = True
-
-        # Wait for unsharding and reducing streams to complete so the model is not left in a bad
-        # state before grad clipping, optimizer step, or whatever else.
-        self.state.current_stream.wait_stream(self.state.reduce_stream)
-
     def _register_post_backward_hook(self, param_name: str, param: ShardedFlatParameter):
         # Force creation of a `grad_fn` in order to register a hook that will run *after* this param's
         # backward pass.
@@ -733,13 +737,42 @@ def _register_post_backward_hook(self, param_name: str, param: ShardedFlatParameter):
         self.state.post_backward_hook_handles[param_name] = handle
 
     def _register_post_backward_hooks(self):
-        log.debug("Registering post-backward hooks for %s...", self.module.__class__.__name__)
+        log.debug(
+            "Registering post-backward hooks for %s (%s)...", self.module.__class__.__name__, id(self.module)
+        )
         # Clear existing hooks if there are any.
         if self.state.post_backward_hook_handles:
-            log.debug("Removing old post-backward hooks for %s...", self.module.__class__.__name__)
+            log.debug(
+                "Removing old post-backward hooks for %s (%s)...", self.module.__class__.__name__, id(self.module)
+            )
             for handle in self.state.post_backward_hook_handles.values():
                 handle.remove()
             self.state.post_backward_hook_handles.clear()
         for param_name, param in self._managed_named_parameters():
             if param.requires_grad:
                 self._register_post_backward_hook(param_name, param)
+
+    @torch.no_grad()
+    def _post_backward_final_hook(self):
+        if not self.is_root:
+            return
+
+        log.debug("Running post-backward final hook for %s (%s)", self.module.__class__.__name__, id(self.module))
+
+        # Mark backward execution order as finalized.
+        self.state.backward_execution_order_finalized = True
+        for child in self._fsdp_children(recurse=True):
+            child.state.backward_execution_order_finalized = True
+
+        # Wait for unsharding and reducing streams to complete so the model is not left in a bad
+        # state before grad clipping, optimizer step, or whatever else.
+        self.state.current_stream.wait_stream(self.state.reduce_stream)
+
+    def _register_post_backward_final_hook(self):
+        if not self.is_root:
+            return
+
+        log.debug(
+            "Registering post-backward final hook for %s (%s)...", self.module.__class__.__name__, id(self.module)
+        )
+        Variable._execution_engine.queue_callback(self._post_backward_final_hook)
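
For context (not part of the diff): the new `_register_post_backward_final_hook` relies on `Variable._execution_engine.queue_callback`, a private autograd API that schedules a callable to run once the current backward pass finishes. The minimal, standalone sketch below (no FSDP involved) illustrates that behavior under that assumption: a callback queued from inside a backward hook fires exactly once, after the rest of the backward pass, which is what lets the root FSDP instance defer its stream synchronization and execution-order bookkeeping until every per-parameter post-backward hook has run.

import torch
from torch.autograd import Variable

ran = []

def final_callback():
    # Runs once, after the whole backward pass has completed.
    ran.append("done")

x = torch.ones(3, requires_grad=True)
y = (x * 2).sum()

# Tensor hooks fire during backward; queuing the callback from one of them mirrors
# how the root FSDP instance queues `_post_backward_final_hook` from its pre-backward hook.
y.register_hook(lambda grad: Variable._execution_engine.queue_callback(final_callback))

y.backward()
print(ran)  # ['done'] -- the callback ran after backward finished, not inside the hook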