
Commit b0c3fe1

[minor] add a checking around local_state_dict (#1040)

Authored by min-xu-ai and flying-x
Co-authored-by: Min Xu <[email protected]>
1 parent: 16fba4c

1 file changed: +10 -0 lines changed

fairscale/nn/data_parallel/fully_sharded_data_parallel.py

Lines changed: 10 additions & 0 deletions
@@ -962,6 +962,11 @@ def local_state_dict(self, *args: Any, **kwargs: Any) -> Any:
         so the resulting state_dict can only be loaded after the Module has been
         wrapped with FSDP.
         """
+        # Check state, specifically, we shouldn't be in SUMMON_FULL_PARAMS since
+        # that will produce full state, not sharded state.
+        self.assert_state(
+            [TrainingState.IDLE, TrainingState.FORWARD, TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST]
+        )
         with contextlib.ExitStack() as stack:
             # Tell any nested FSDP instances not to auto summon full params.
             for module in self.modules():  # includes self
@@ -1025,6 +1030,11 @@ def load_local_state_dict(
         self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
     ) -> NamedTuple:
         """Load a local (sharded) state_dict."""
+        # Check state, specifically, we shouldn't be in SUMMON_FULL_PARAMS since
+        # that will load full state, not sharded state.
+        self.assert_state(
+            [TrainingState.IDLE, TrainingState.FORWARD, TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST]
+        )
         with contextlib.ExitStack() as stack:
             # Tell any nested FSDP instances not to auto summon full params.
             for module in self.modules():  # includes self
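In effect, the commit makes both the sharded save and load paths refuse to run while full parameters are summoned, since that would capture or restore unsharded state. Below is a minimal single-process sketch of the behavior being guarded, assuming fairscale's FSDP runs on CPU with a gloo process group here; the toy nn.Linear model, the rendezvous address, and the printed message are illustrative, not part of the commit:

```python
import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# fairscale's FSDP needs torch.distributed initialized; a world_size=1 gloo
# group is assumed sufficient for this CPU sketch.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

model = FSDP(torch.nn.Linear(8, 8))

# Allowed: the wrapper is in TrainingState.IDLE, so a sharded (local)
# state_dict is produced and can be restored through the matching load path.
local_sd = model.local_state_dict()
model.load_local_state_dict(local_sd, strict=True)

# Guarded after this commit: inside summon_full_params() the wrapper is in
# TrainingState.SUMMON_FULL_PARAMS, so local_state_dict() would capture the
# full (unsharded) parameters; the new assert_state(...) call raises instead.
try:
    with model.summon_full_params():
        model.local_state_dict()
except Exception as err:  # raised by the failed state assertion
    print("state check tripped:", err)

dist.destroy_process_group()
```

Before this change, the second call would have silently returned full state under a method whose contract promises sharded state; the assertion turns that misuse into an immediate, visible error.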

0 commit comments
