[FSDP] Zero 3 Optimization Support #4903
@@ -35,6 +35,7 @@
 import parlai.scripts.train_model as single_train
 from parlai.core.script import ParlaiScript
 import parlai.utils.distributed as distributed_utils
+import parlai.utils.fsdp as fsdp_utils


 def setup_args():

@@ -51,8 +52,9 @@ def setup_args(cls):

     def run(self):
         with distributed_utils.slurm_distributed_context(self.opt) as opt:
-            self.train_loop = single_train.TrainLoop(opt)
-            return self.train_loop.train()
+            self.train_loop = fsdp_utils.JoinableTrainLoop(opt)
+            with fsdp_utils.fsdp_join(self.train_loop):
Comment: do you need to do this for distributed_eval too?

Reply: yes, forgot about that. also multiprocessing eval
+                return self.train_loop.train()


 if __name__ == '__main__':
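For background on the fsdp_join usage above (and why the reviewer asks about distributed_eval as well): fsdp_join is built on PyTorch's generic Join context manager, which lets ranks that exhaust their data early shadow the collective communications of ranks still working, so uneven batch counts across workers do not hang the job. A minimal, hedged sketch of that pattern with a plain DistributedDataParallel model; model, loader, and optimizer are placeholders, not ParlAI code:

# Illustrative sketch of torch.distributed.algorithms.join.Join with DDP.
# Assumes torch.distributed is already initialized; not part of this PR.
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP

def train_uneven(model: DDP, loader, optimizer):
    with Join([model]):              # DDP is itself a Joinable
        for batch in loader:         # per-rank batch counts may differ
            loss = model(batch).sum()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()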
@@ -7,22 +7,46 @@
 """
 Utility functions for FullyShardedDataParallel.
 """

 import contextlib
+import functools
+import torch
+import torch.distributed
+from torch.distributed.algorithms.join import Join, Joinable, JoinHook
 import torch.nn

+from parlai.scripts.train_model import TrainLoop
 from parlai.utils.distributed import is_distributed, get_dist_group

 try:
-    from fairscale.nn.wrap.auto_wrap import wrap
-    from fairscale.nn.wrap.auto_wrap import enable_wrap as fairscale_enable_wrap
-    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
+    import torch
+    import torch.distributed
+    import torch.distributed.fsdp
+    from torch.distributed.fsdp.wrap import (
+        wrap,
+        enable_wrap,
+        transformer_auto_wrap_policy,
+    )
+    from torch.distributed.fsdp.fully_sharded_data_parallel import (
+        FullyShardedDataParallel as FSDP,
+        ShardingStrategy,
+        MixedPrecision,
+        BackwardPrefetch,
+    )

+    PYTORCH_FSDP_AVAILABLE = True
     FSDP_AVAILABLE = True
 except ImportError:
-    FSDP_AVAILABLE = False
+    PYTORCH_FSDP_AVAILABLE = False
+    try:
+        from fairscale.nn.wrap.auto_wrap import wrap, enable_wrap
+        from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
Comment: why keep around fairscale support?

Reply: so getting rid of fairscale would force anyone who wants to use distributed training to be on pytorch >=1.12. Is that a reasonable ask, you think?

Comment: That's your call but pytorch 1.13 is out (and I'm using it successfully)

Reply: i've removed fairscale FSDP
-    def wrap(module, **kwargs):
-        return module
+        FSDP_AVAILABLE = True
+    except ImportError:
+        FSDP_AVAILABLE = False
+
+        def wrap(module, **kwargs):
+            return module


 DEFAULT_DDP_BACKEND = "ddp"
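Following up on the fairscale discussion above: if the fairscale fallback were dropped, the availability flag could alternatively come from an explicit version gate rather than an ImportError. A hedged sketch, not part of the PR:

# Hypothetical alternative to the try/except import guard above;
# native FSDP (and transformer_auto_wrap_policy) needs torch >= 1.12.
import torch

def _native_fsdp_available() -> bool:
    major, minor = (int(x) for x in torch.__version__.split('.')[:2])
    return (major, minor) >= (1, 12)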
@@ -53,26 +77,92 @@ def maybe_fsdp_wrap(opt):
         yield
         return

-    # zero3 not supported at this time. Throw an exception
-    if opt['ddp_backend'] == 'zero3':
-        raise NotImplementedError(
-            '--ddp-backend zero3 is not supported at this time. For details, see '
-            'https://github.com/facebookresearch/ParlAI/issues/3753.'
-        )
-    reshard_after_forward = opt['ddp_backend'] == 'zero3'
-    compute_dtype = torch.float16 if opt['fp16'] else torch.float32
-    mixed_precision = opt['fp16'] and opt['fp16_impl'] == 'safe'
-    fsdp_args = dict(
-        reshard_after_forward=reshard_after_forward,
-        mixed_precision=mixed_precision,
-        compute_dtype=compute_dtype,
-        state_dict_device=torch.device('cpu'),
-        flatten_parameters=True,
-        process_group=get_dist_group(),
-    )
-    with fairscale_enable_wrap(wrapper_cls=FSDP, **fsdp_args):
-        yield
+    mixed_precision = opt['fp16'] and opt['fp16_impl'] == 'safe'
+
+    if PYTORCH_FSDP_AVAILABLE:
+        # settings as of pytorch 1.13
+        # There is a warning in pytorch 1.13 for FSDP that is unavoidable;
+        # at the risk of suppressing valid warnings, just going to suppress that one.
+        import warnings
+
+        warnings.filterwarnings("ignore")
+
+        # sharding strategy determines zero2 or zero3
+        sharding_strategy = (
+            ShardingStrategy.FULL_SHARD
+            if opt['ddp_backend'] == 'zero3'
+            else ShardingStrategy.SHARD_GRAD_OP
+        )
+
+        # mp determines how to mix precision
+        if mixed_precision:
+            mp_strategy = MixedPrecision(
+                reduce_dtype=torch.float16,
+                param_dtype=torch.float16,
+                buffer_dtype=torch.float16,
+            )
+        else:
+            mp_strategy = None
+
+        # autowrap policy.
+        auto_wrap_policy = None
+        ignored_modules = None
+        if opt['model'] in ['bart', 'transformer/generator']:
+            from parlai.agents.transformer.modules.encoder import (
+                TransformerEncoderLayer,
+            )
+            from parlai.agents.transformer.modules.decoder import (
+                TransformerDecoderLayer,
+            )
+
+            auto_wrap_policy = functools.partial(
+                transformer_auto_wrap_policy,
+                transformer_layer_cls={
+                    TransformerEncoderLayer,
+                    TransformerDecoderLayer,
+                },
+            )
+
+        # backward prefetch; determines when to fetch the parameters during backward pass
+        # set to BACKWARD_PRE to increase throughput, at the cost of memory
+        backward_prefetch = BackwardPrefetch.BACKWARD_POST
+
+        # CPU offloading; this can offload parameters to the CPU
+        cpu_offload = None
+
+        fsdp_args = dict(
+            process_group=get_dist_group(),
+            sharding_strategy=sharding_strategy,
+            cpu_offload=cpu_offload,
+            auto_wrap_policy=auto_wrap_policy,
+            backward_prefetch=backward_prefetch,
+            mixed_precision=mp_strategy,
+            ignored_modules=ignored_modules,
+            param_init_fn=None,
+            device_id=opt['gpu'],
+            sync_module_states=False,  # need this for syncing the first call; specify False because we do it manually after cuda
+            forward_prefetch=False,  # specify true for CPU-heavy workload
+            limit_all_gathers=False,  # specifying the default here
+        )
+        with enable_wrap(wrapper_cls=FSDP, **fsdp_args):
+            yield
+    else:
+        if opt['ddp_backend'] == 'zero3':
+            raise NotImplementedError(
+                '--ddp-backend zero3 is only supported on later versions of Pytorch (>= 1.12)'
+            )
+
+        compute_dtype = torch.float16 if opt['fp16'] else torch.float32
+        fsdp_args = dict(
+            reshard_after_forward=False,  # hard code False; only use fairscale for backwards compatibility.
+            mixed_precision=mixed_precision,
+            compute_dtype=compute_dtype,
+            state_dict_device=torch.device('cpu'),
+            flatten_parameters=True,
+            process_group=get_dist_group(),
+        )
+        with enable_wrap(wrapper_cls=FSDP, **fsdp_args):
+            yield


 def delay_halving(opt):
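Since the hunk above leans on the native FSDP wrapping API, here is a minimal, hedged sketch of how the enable_wrap/wrap pair from torch.distributed.fsdp.wrap behaves on its own, assuming torch >= 1.12 and an initialized default process group; the module sizes and names are invented for illustration and are not part of the PR:

# Illustrative only: not ParlAI code.
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import enable_wrap, wrap

def build_sharded_model():
    # fp16 for params, gradients, and buffers, mirroring the MixedPrecision settings above
    mp = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    )
    with enable_wrap(wrapper_cls=FSDP, mixed_precision=mp):
        # wrap() returns an FSDP instance only inside an enable_wrap context;
        # outside of it, the module is returned unchanged.
        encoder = wrap(nn.Linear(512, 512))
        decoder = wrap(nn.Linear(512, 512))
        return wrap(nn.Sequential(encoder, decoder))

That no-op behavior of wrap() outside enable_wrap is what lets maybe_fsdp_wrap above yield a plain context when FSDP is not in use.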
@@ -109,3 +199,83 @@ def fsdp_wrap(module):
     Helper function for wrapping the outermost root module.
     """
     return wrap(module)
+
+
+def get_state_dict(model):
+    """
+    Get the state dict from the model.
+
+    When using Pytorch FSDP, we can offload to CPU.
+    """
+    if PYTORCH_FSDP_AVAILABLE:
+        from torch.distributed.fsdp.fully_sharded_data_parallel import (
+            FullStateDictConfig,
+            StateDictType,
+        )
+
+        save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
+            state = model.state_dict()
+    else:
+        state = model.state_dict()
+
+    return state
+
+
+@contextlib.contextmanager
+def fsdp_join(*args):
+    with Join([*args]):
+        yield
+
+
+class JoinableTrainLoop(TrainLoop, Joinable):
+    """
+    Joinable train loop.
+    """
+
+    def __init__(self, opt):
+        import parlai.utils.distributed as dist_utils
+
+        super().__init__(opt)
+        self.__device = opt['gpu']
+        self.__group = dist_utils.get_dist_group()
+
+    def __call__(self):
+        """
+        Join caller.
+
+        For now, don't do anything.
+        """
+        Join.notify_join_context(self)
+
+    def join_hook(self, **kwargs) -> JoinHook:
+        """
+        Return our fake join hook.
+        """
+        return TrainLoopJoinHook(self)
+
+    @property
+    def join_device(self) -> torch.device:
+        return self.__device
+
+    @property
+    def join_process_group(self):
+        return self.__group
+
+
+class TrainLoopJoinHook(JoinHook):
+    """
+    Join hook for train loop.
+
+    Adapted from https://pytorch.org/tutorials/advanced/generic_join.html
+    """
+
+    def __init__(self, train_loop: JoinableTrainLoop):
+        self.train_loop = train_loop
+
+    def main_hook(self):
+        pass
+
+    def post_hook(self, is_last_joiner: bool):
+        pass
Comment: how is this used?

Reply: ahh it's not, that's an artifact. will delete
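Stepping back to the get_state_dict helper added in the hunk above: a hedged usage sketch of how a caller might gather the full, CPU-offloaded state dict and write it from rank 0 only. The function and path names here are placeholders, not part of the PR:

# Illustrative only; assumes the get_state_dict helper added above.
import torch
import torch.distributed as dist
from parlai.utils.fsdp import get_state_dict

def save_full_checkpoint(model, path='model.checkpoint'):
    # Under PyTorch FSDP the full state dict is gathered to CPU and
    # materialized only on rank 0, so only rank 0 writes it to disk.
    state = get_state_dict(model)
    if not dist.is_initialized() or dist.get_rank() == 0:
        torch.save(state, path)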