@@ -962,6 +962,11 @@ def local_state_dict(self, *args: Any, **kwargs: Any) -> Any:
         so the resulting state_dict can only be loaded after the Module has been
         wrapped with FSDP.
         """
+        # Check state; specifically, we shouldn't be in SUMMON_FULL_PARAMS since
+        # that will produce full state, not sharded state.
+        self.assert_state(
+            [TrainingState.IDLE, TrainingState.FORWARD, TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST]
+        )
         with contextlib.ExitStack() as stack:
             # Tell any nested FSDP instances not to auto summon full params.
             for module in self.modules():  # includes self
@@ -1025,6 +1030,11 @@ def load_local_state_dict(
         self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
     ) -> NamedTuple:
         """Load a local (sharded) state_dict."""
+        # Check state; specifically, we shouldn't be in SUMMON_FULL_PARAMS since
+        # that will load full state, not sharded state.
+        self.assert_state(
+            [TrainingState.IDLE, TrainingState.FORWARD, TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST]
+        )
         with contextlib.ExitStack() as stack:
             # Tell any nested FSDP instances not to auto summon full params.
             for module in self.modules():  # includes self
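
A minimal sketch of the behavior this diff enforces, assuming fairscale's `FullyShardedDataParallel` API (`local_state_dict()` and the `summon_full_params()` context manager) and using a single-process gloo group purely for illustration: calling `local_state_dict()` while full params are summoned now trips the `assert_state()` guard instead of silently returning unsharded state.

```python
# Minimal sketch (not part of the diff), assuming fairscale's
# FullyShardedDataParallel API and a single-process gloo group.
import os

import torch.distributed as dist
import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP

# Single-process "distributed" setup so FSDP can create its process group.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = FSDP(nn.Linear(8, 8))

sharded = model.local_state_dict()  # fine: training state is IDLE

with model.summon_full_params():
    # Inside this context the training state is SUMMON_FULL_PARAMS, so the
    # new assert_state() check rejects the call rather than returning full
    # (unsharded) state. The exact exception type depends on how
    # assert_state() reports failures.
    try:
        model.local_state_dict()
    except (AssertionError, ValueError):
        print("local_state_dict() is rejected under summon_full_params()")

dist.destroy_process_group()
```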