
Commit 4b126c7

min-xu-ai and flying-x authored

[feat] support a context for loading state_dict for FSDP (#1065)
* [fix]: add a context for supporting state_dict from a non-FSDP parent module
* formatting

Co-authored-by: Min Xu <[email protected]>

1 parent 3cc7fa8 commit 4b126c7
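
After this change, loading can be wrapped in the new no_pre_load_state_dict_hook context manager when, per the commit description, the state_dict was not produced by a root FSDP instance. A minimal sketch of the intended call pattern follows; it assumes an initialized torch.distributed process group, and the Parent module and checkpoint path are illustrative, not part of this commit.

import torch
import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import no_pre_load_state_dict_hook


class Parent(nn.Module):
    # Hypothetical module: a plain nn.Module root whose child is FSDP-wrapped,
    # so the root of the module tree is not an FSDP instance.
    def __init__(self) -> None:
        super().__init__()
        self.child = FSDP(nn.Linear(4, 4))  # assumes torch.distributed is initialized

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.child(x)


model = Parent()
state_dict = torch.load("checkpoint.pt")  # illustrative path

# Suppress the pre-load key rewriting while loading, as the new context manager's
# docstring describes for state_dicts not produced by a root FSDP instance.
with no_pre_load_state_dict_hook():
    model.load_state_dict(state_dict)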

File tree

4 files changed: +31 −4 lines changed

fairscale/nn/data_parallel/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -5,7 +5,13 @@
 
 from typing import List
 
-from .fully_sharded_data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState, auto_wrap_bn
+from .fully_sharded_data_parallel import (
+    FullyShardedDataParallel,
+    OffloadConfig,
+    TrainingState,
+    auto_wrap_bn,
+    no_pre_load_state_dict_hook,
+)
 from .sharded_ddp import ShardedDataParallel
 
 __all__: List[str] = []

fairscale/nn/data_parallel/fully_sharded_data_parallel.py

Lines changed: 18 additions & 2 deletions
@@ -51,7 +51,7 @@
 from fairscale.internal.params import calc_grad_norm, recursive_copy_to_device
 from fairscale.internal.reduce_scatter_bucketer import ReduceScatterBucketer
 from fairscale.internal.state_dict import replace_by_prefix_
-from fairscale.nn.misc import FlattenParamsWrapper
+from fairscale.nn.misc import FlattenParamsWrapper, _enable_pre_load_state_dict_hook
 from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap
 
 from . import fsdp_optim_utils as ou
@@ -762,6 +762,7 @@ def _shard_parameters_(self) -> None:
                 self.numel_padded_per_param.append(0)
                 continue
             p._is_sharded = True
+            # TODO (Min): broadcast from rank 0 to avoid each rank need to init with the same seed?
 
             # Replace p.data with the relevant shard.
             orig_data = p.data
@@ -2581,10 +2582,25 @@ def apply_to_tensor(obj: torch.Tensor) -> torch.Tensor:
     return state_dict
 
 
+@contextlib.contextmanager
+def no_pre_load_state_dict_hook() -> Generator:
+    """Disable the pre-load hook.
+
+    This is needed if we are loading a state_dict that was not produced by
+    a root FSDP instance.
+    """
+    global _enable_pre_load_state_dict_hook
+    bak = _enable_pre_load_state_dict_hook
+    _enable_pre_load_state_dict_hook = False
+    yield
+    _enable_pre_load_state_dict_hook = bak
+
+
 def _pre_load_state_dict_hook(
     state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], prefix: str, *args: Any
 ) -> None:
-    replace_by_prefix_(state_dict, prefix, prefix + "_fsdp_wrapped_module.")
+    if _enable_pre_load_state_dict_hook:
+        replace_by_prefix_(state_dict, prefix, prefix + "_fsdp_wrapped_module.")
 
 
 def _clean_path(path: str) -> str:
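
For reference, the rewrite that the flag now gates is the in-place key renaming done by replace_by_prefix_, imported above from fairscale.internal.state_dict. A small sketch of its effect follows; the key names and tensor values are made up for illustration.

import torch
from collections import OrderedDict

from fairscale.internal.state_dict import replace_by_prefix_

# Made-up keys; a real state_dict comes from a module's state_dict() call.
sd = OrderedDict([("layer.weight", torch.zeros(2)), ("layer.bias", torch.zeros(2))])

# Mirrors what _pre_load_state_dict_hook does when the flag is True and prefix == "":
# every matching key is pushed under the internal wrapper level, in place.
replace_by_prefix_(sd, "", "_fsdp_wrapped_module.")
print(list(sd.keys()))
# ['_fsdp_wrapped_module.layer.weight', '_fsdp_wrapped_module.layer.bias']

# Inside no_pre_load_state_dict_hook() the hook skips this call, so incoming keys
# are handed to load_state_dict unchanged.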

fairscale/nn/misc/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 # in favor of fairscale.nn.checkpoint.checkpoint_wrapper.
 from fairscale.nn.checkpoint import checkpoint_wrapper
 
-from .flatten_params_wrapper import FlattenParamsWrapper
+from .flatten_params_wrapper import FlattenParamsWrapper, _enable_pre_load_state_dict_hook
 from .param_bucket import GradBucket, ParamBucket
 
 __all__: List[str] = []

fairscale/nn/misc/flatten_params_wrapper.py

Lines changed: 5 additions & 0 deletions
@@ -49,6 +49,9 @@
 if TYPE_CHECKING:
     from collections import OrderedDict  # noqa: F401
 
+# See no_pre_load_state_dict_hook context manager function in FSDP for more details.
+_enable_pre_load_state_dict_hook = True
+
 
 class FlatParameter(nn.Parameter):
     """A parameter that is initialized from a list of parameters and can be
@@ -543,6 +546,8 @@ def _post_state_dict_hook(
 def _pre_load_state_dict_hook(
     state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"], prefix: str, *args: Any
 ) -> None:
+    if not _enable_pre_load_state_dict_hook:
+        return
     # Push everything down to ._fpw_module level.
     replace_by_prefix_(state_dict, prefix, prefix + "_fpw_module.")
     # The flat_param_* keys actually needs to move one level up.
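
Taken together, the two gated hooks rewrite checkpoint keys at two wrapper levels. The sketch below composes the prefixes that appear in this commit; the parameter name itself is made up.

# Illustrative only: "linear.weight" is a made-up parameter name.
plain_key = "linear.weight"

# FSDP's _pre_load_state_dict_hook pushes keys under its wrapper level ...
fsdp_key = "_fsdp_wrapped_module." + plain_key
# ... and FlattenParamsWrapper's hook pushes them down to the _fpw_module level.
fully_wrapped_key = "_fsdp_wrapped_module." + "_fpw_module." + plain_key

print(fully_wrapped_key)  # _fsdp_wrapped_module._fpw_module.linear.weight
# With _enable_pre_load_state_dict_hook set to False, the corresponding rewrite is
# skipped and the incoming keys are left as-is.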

0 commit comments
