This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[FSDP] Zero 3 Optimization Support #4903

Merged 10 commits on Dec 5, 2022.
Changes shown are from 2 commits.
6 changes: 4 additions & 2 deletions parlai/agents/hugging_face/t5.py
@@ -27,7 +27,7 @@
from parlai.core.params import ParlaiParser
from parlai.core.torch_agent import Batch, TorchAgent
from parlai.core.torch_generator_agent import TorchGeneratorAgent, TorchGeneratorModel
from parlai.utils.fsdp import is_fsdp
from parlai.utils.fsdp import is_fsdp, delay_halving


def check_hf_version(v: Tuple[int, int]) -> bool:
@@ -41,7 +41,9 @@ def check_hf_version(v: Tuple[int, int]) -> bool:
def build_t5(opt: Opt) -> T5ForConditionalGeneration:
if not check_hf_version(HF_VERSION):
raise RuntimeError('Must use transformers package >= 4.3 to use t5')
torch_dtype = torch.float16 if opt['fp16'] else torch.float32
torch_dtype = (
torch.float16 if (opt['fp16'] and not delay_halving(opt)) else torch.float32
)
try:
return T5ForConditionalGeneration.from_pretrained(
opt['t5_model_arch'],
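For context, the new dtype guard defers to `delay_halving`; a minimal sketch of the check that helper performs (an approximation inferred from its use in this PR: fp16 on, an FSDP backend, and the 'safe' fp16 implementation), so the T5 weights stay in fp32 and FSDP's own mixed precision does the casting:

```python
# Sketch only; the real helper lives in parlai/utils/fsdp.py.
def delay_halving(opt) -> bool:
    # Keep parameters in fp32 up front and let FSDP's mixed precision
    # handle the fp16 cast, instead of calling .half() on the model.
    return (
        opt['fp16']
        and opt.get('ddp_backend') in ('zero2', 'zero3')
        and opt['fp16_impl'] == 'safe'
    )
```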
3 changes: 1 addition & 2 deletions parlai/core/params.py
@@ -774,8 +774,7 @@ def add_distributed_training_args(self):
)
grp.add_argument(
'--ddp-backend',
# TODO: add in zero3. https://github.com/facebookresearch/ParlAI/issues/3753
choices=['ddp', 'zero2'],
choices=['ddp', 'zero2', 'zero3'],
default='ddp',
help=(
'Distributed backend. Zero2 can be faster but is more experimental. '
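With `zero3` exposed as a choice, a hypothetical invocation through the ParlAI script API (kwargs mirror the flags registered above; the task, batch size, and model values are illustrative, and multiple GPUs are assumed):

```python
from parlai.scripts.multiprocessing_train import MultiProcessTrain

MultiProcessTrain.main(
    task='convai2',                 # illustrative task
    model='transformer/generator',
    ddp_backend='zero3',            # the newly added choice
    fp16=True,
    fp16_impl='safe',               # pairs with FSDP mixed precision in this PR
    batchsize=8,
)
```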
14 changes: 12 additions & 2 deletions parlai/core/torch_agent.py
@@ -36,7 +36,14 @@
from parlai.utils.distributed import is_distributed
from parlai.utils.misc import AttrDict, warn_once
from parlai.utils.io import PathManager
from parlai.utils.fsdp import should_sync_gradnorm, is_fsdp, DEFAULT_DDP_BACKEND
from parlai.utils.fsdp import (
should_sync_gradnorm,
is_fsdp,
DEFAULT_DDP_BACKEND,
PYTORCH_FSDP_AVAILABLE,
get_state_dict,
should_use_fsdp,
)
from parlai.utils.fp16 import (
SafeFP16Optimizer,
MemoryEfficientFP16Optimizer,
@@ -1981,8 +1988,11 @@ def state_dict(self):
if hasattr(self.model, 'module') and not is_fsdp(self.model):
# did we wrap in a DistributedDataParallel or DataParallel
states['model'] = self.model.module.state_dict()
elif is_fsdp(self.model) and PYTORCH_FSDP_AVAILABLE:
# PyTorch FSDP: gather the full, unsharded state dict (see get_state_dict)
states['model'] = get_state_dict(self.model)
else:
# regular model or FSDP
# regular model or non-Pytorch FSDP
states['model'] = self.model.state_dict()

if hasattr(self, 'optimizer'):
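A condensed sketch of the save path this hunk changes, using the helpers imported above (agent plumbing elided):

```python
from parlai.utils.fsdp import is_fsdp, get_state_dict, PYTORCH_FSDP_AVAILABLE

def model_state_for_saving(model):
    if hasattr(model, 'module') and not is_fsdp(model):
        return model.module.state_dict()   # DDP / DataParallel wrapper
    if is_fsdp(model) and PYTORCH_FSDP_AVAILABLE:
        return get_state_dict(model)       # gather the full, unsharded weights
    return model.state_dict()              # plain model or non-PyTorch FSDP

# e.g. torch.save({'model': model_state_for_saving(agent.model)}, path)
```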
12 changes: 11 additions & 1 deletion parlai/core/torch_generator_agent.py
@@ -20,7 +20,7 @@
from typing_extensions import TypedDict
from parlai.core.params import ParlaiParser
from abc import ABC, abstractmethod
from typing import TypeVar, List, Dict, Optional, Tuple, Set, Iterable
from typing import TypeVar, List, Dict, Optional, Tuple, Set, Iterable, Any
import math
from operator import attrgetter

@@ -516,6 +516,16 @@ def __init__(self, opt: Opt, shared=None):
else:
# this is not a shared instance of this class, so do full init
self.criterion = self.build_criterion()

def load_init_model() -> Dict[str, Any]:
Contributor: how is this used?

Contributor Author: ahh it's not, that's an artifact. will delete

if init_model is not None:
# load model parameters if available
logging.info(f'Loading existing model params from {init_model}')
states = self.load(init_model)
else:
states = {}
return states

with fsdp_utils.maybe_fsdp_wrap(opt):
self.model = fsdp_utils.fsdp_wrap(self.build_model())
if self.fp16 and not fsdp_utils.delay_halving(opt):
Expand Down
6 changes: 4 additions & 2 deletions parlai/scripts/distributed_train.py
@@ -35,6 +35,7 @@
import parlai.scripts.train_model as single_train
from parlai.core.script import ParlaiScript
import parlai.utils.distributed as distributed_utils
import parlai.utils.fsdp as fsdp_utils


def setup_args():
@@ -51,8 +52,9 @@ def setup_args(cls):

def run(self):
with distributed_utils.slurm_distributed_context(self.opt) as opt:
self.train_loop = single_train.TrainLoop(opt)
return self.train_loop.train()
self.train_loop = fsdp_utils.JoinableTrainLoop(opt)
with fsdp_utils.fsdp_join(self.train_loop):
Contributor: do you need to do this for distributed_eval too?

Contributor Author: yes, forgot about that. also multiprocessing eval

return self.train_loop.train()


if __name__ == '__main__':
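The join wrapper exists because ranks can exhaust their batches at different times; the join context lets ranks that finish early shadow the collective calls of the others instead of deadlocking. A minimal standalone sketch of the same pattern with plain DDP (process-group setup elided; PyTorch >= 1.10):

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.join import Join

def train_uneven(rank_batches):
    # assumes torch.distributed.init_process_group has already been called
    ddp_model = DDP(nn.Linear(8, 8))
    with Join([ddp_model]):            # DDP is itself a Joinable
        for batch in rank_batches:     # ranks may see different batch counts
            loss = ddp_model(batch).sum()
            loss.backward()            # joined ranks shadow the allreduce
```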
5 changes: 4 additions & 1 deletion parlai/scripts/multiprocessing_train.py
@@ -29,6 +29,7 @@
import traceback
import parlai.scripts.train_model as single_train
import parlai.utils.distributed as distributed_utils
import parlai.utils.fsdp as fsdp_utils
from parlai.core.script import ParlaiScript, register_script


@@ -41,8 +42,10 @@ def multiprocess_train(
) as opt:
# Run the actual training
opt['multiprocessing'] = True
loop = fsdp_utils.JoinableTrainLoop(opt)
try:
return single_train.TrainLoop(opt).train()
with fsdp_utils.fsdp_join(loop):
return loop.train()
except Exception:
import parlai.utils.logging as logging

220 changes: 195 additions & 25 deletions parlai/utils/fsdp.py
@@ -7,22 +7,46 @@
"""
Utility functions for FullyShardedDataParallel.
"""

import contextlib
import functools
import torch
import torch.distributed
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
import torch.nn

from parlai.scripts.train_model import TrainLoop
from parlai.utils.distributed import is_distributed, get_dist_group

try:
from fairscale.nn.wrap.auto_wrap import wrap
from fairscale.nn.wrap.auto_wrap import enable_wrap as fairscale_enable_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
import torch
import torch.distributed
import torch.distributed.fsdp
from torch.distributed.fsdp.wrap import (
wrap,
enable_wrap,
transformer_auto_wrap_policy,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullyShardedDataParallel as FSDP,
ShardingStrategy,
MixedPrecision,
BackwardPrefetch,
)

PYTORCH_FSDP_AVAILABLE = True
FSDP_AVAILABLE = True
except ImportError:
FSDP_AVAILABLE = False
PYTORCH_FSDP_AVAILABLE = False
try:
from fairscale.nn.wrap.auto_wrap import wrap, enable_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
Contributor: why keep around fairscale support?

Contributor Author: so getting rid of fairscale would force anyone who wants to use distributed training to be on pytorch >=1.12. Is that a reasonable ask, you think?

Contributor: That's your call, but pytorch 1.13 is out (and I'm using it successfully).

Contributor Author: i've removed fairscale FSDP


def wrap(module, **kwargs):
return module
FSDP_AVAILABLE = True
except ImportError:
FSDP_AVAILABLE = False

def wrap(module, **kwargs):
return module


DEFAULT_DDP_BACKEND = "ddp"
@@ -53,26 +77,92 @@ def maybe_fsdp_wrap(opt):
yield
return

# zero3 not supported at this time. Throw an exception
if opt['ddp_backend'] == 'zero3':
raise NotImplementedError(
'--ddp-backend zero3 is not supported at this time. For details, see '
'https://github.com/facebookresearch/ParlAI/issues/3753.'
mixed_precision = opt['fp16'] and opt['fp16_impl'] == 'safe'

if PYTORCH_FSDP_AVAILABLE:
# settings as of pytorch 1.13
# pytorch 1.13's FSDP emits an unavoidable warning here; at the risk of
# hiding valid warnings, filter them out (note this silences all warnings).
import warnings

warnings.filterwarnings("ignore")

# sharding strategy determines zero2 or zero3
sharding_strategy = (
ShardingStrategy.FULL_SHARD
if opt['ddp_backend'] == 'zero3'
else ShardingStrategy.SHARD_GRAD_OP
)

reshard_after_forward = opt['ddp_backend'] == 'zero3'
compute_dtype = torch.float16 if opt['fp16'] else torch.float32
mixed_precision = opt['fp16'] and opt['fp16_impl'] == 'safe'
fsdp_args = dict(
reshard_after_forward=reshard_after_forward,
mixed_precision=mixed_precision,
compute_dtype=compute_dtype,
state_dict_device=torch.device('cpu'),
flatten_parameters=True,
process_group=get_dist_group(),
)
with fairscale_enable_wrap(wrapper_cls=FSDP, **fsdp_args):
yield
# mp determines how to mix precision
if mixed_precision:
mp_strategy = MixedPrecision(
reduce_dtype=torch.float16,
param_dtype=torch.float16,
buffer_dtype=torch.float16,
)
else:
mp_strategy = None

# autowrap policy.
auto_wrap_policy = None
ignored_modules = None
if opt['model'] in ['bart', 'transformer/generator']:
from parlai.agents.transformer.modules.encoder import (
TransformerEncoderLayer,
)
from parlai.agents.transformer.modules.decoder import (
TransformerDecoderLayer,
)

auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
TransformerEncoderLayer,
TransformerDecoderLayer,
},
)

# backward prefetch; determines when to fetch the parameters during backward pass
# set to BACKWARD_PRE to increase throughput, at the cost of memory
backward_prefetch = BackwardPrefetch.BACKWARD_POST

# CPU offloading; this can offload parameters to the CPU
cpu_offload = None

fsdp_args = dict(
process_group=get_dist_group(),
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mp_strategy,
ignored_modules=ignored_modules,
param_init_fn=None,
device_id=opt['gpu'],
sync_module_states=False, # need this for syncing the first call; specify False because we do it manually after cuda
forward_prefetch=False, # specify true for CPU-heavy workload
limit_all_gathers=False, # specifying the default here
)
with enable_wrap(wrapper_cls=FSDP, **fsdp_args):
yield
else:
if opt['ddp_backend'] == 'zero3':
raise NotImplementedError(
'--ddp-backend zero3 is only supported on later versions of Pytorch (>= 1.12)'
)

compute_dtype = torch.float16 if opt['fp16'] else torch.float32
fsdp_args = dict(
reshard_after_forward=False, # hard code False; only use fairscale for backwards compatibility.
mixed_precision=mixed_precision,
compute_dtype=compute_dtype,
state_dict_device=torch.device('cpu'),
flatten_parameters=True,
process_group=get_dist_group(),
)
with enable_wrap(wrapper_cls=FSDP, **fsdp_args):
yield
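To make the PyTorch FSDP argument block above concrete, a standalone sketch that applies the same knobs to a toy encoder (PyTorch >= 1.12; assumes an initialized nccl process group with one GPU per rank; the module and sizes are illustrative):

```python
import functools
import torch
import torch.nn as nn
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    MixedPrecision,
    BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=256, nhead=4), num_layers=2
)
wrapped = FSDP(
    encoder,
    sharding_strategy=ShardingStrategy.FULL_SHARD,   # zero3; SHARD_GRAD_OP == zero2
    mixed_precision=MixedPrecision(                  # only when fp16 + 'safe' impl
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    ),
    auto_wrap_policy=functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={nn.TransformerEncoderLayer},
    ),
    backward_prefetch=BackwardPrefetch.BACKWARD_POST,
    device_id=torch.cuda.current_device(),
)
```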


def delay_halving(opt):
@@ -109,3 +199,83 @@ def fsdp_wrap(module):
Helper function for wrapping the outermost root module.
"""
return wrap(module)
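How the two helpers compose in an agent, mirroring the torch_generator_agent change earlier in this diff (`opt` and `build_model` belong to the agent):

```python
import parlai.utils.fsdp as fsdp_utils

with fsdp_utils.maybe_fsdp_wrap(opt):
    model = fsdp_utils.fsdp_wrap(build_model())   # wrap the root module
    if opt['fp16'] and not fsdp_utils.delay_halving(opt):
        # only halve manually when FSDP mixed precision is not doing it for us
        model = model.half()
```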


def get_state_dict(model):
"""
Get the state dict from the model.

When using Pytorch FSDP, we can offload to CPU.
"""

if PYTORCH_FSDP_AVAILABLE:
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullStateDictConfig,
StateDictType,
)

save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
state = model.state_dict()
else:
state = model.state_dict()

return state


@contextlib.contextmanager
def fsdp_join(*args):
with Join([*args]):
yield


class JoinableTrainLoop(TrainLoop, Joinable):
"""
Joinable train loop.
"""

def __init__(self, opt):
import parlai.utils.distributed as dist_utils

super().__init__(opt)
self.__device = opt['gpu']
self.__group = dist_utils.get_dist_group()

def __call__(self):
"""
Join caller.

For now, just notify the join context.
"""
Join.notify_join_context(self)

def join_hook(self, **kwargs) -> JoinHook:
"""
Return our fake join hook.
"""
return TrainLoopJoinHook(self)

@property
def join_device(self) -> torch.device:
return self.__device

@property
def join_process_group(self):
return self.__group


class TrainLoopJoinHook(JoinHook):
"""
Join hook for train loop.

Adapted from https://pytorch.org/tutorials/advanced/generic_join.html
"""

def __init__(self, train_loop: JoinableTrainLoop):
self.train_loop = train_loop

def main_hook(self):
pass

def post_hook(self, is_last_joiner: bool):
pass
9 changes: 9 additions & 0 deletions parlai/utils/testing.py
@@ -130,6 +130,15 @@ def skipUnlessFairseq(testfn, reason='fairseq not installed'):
return unittest.skipUnless(FAIRSEQ_AVAILABLE, reason)(testfn)


def skipUnlessPytorchFSDP(testfn, reason='pytorch fsdp unavailable'):
"""
Decorate a test to skip unless PyTorch FSDP is available.
"""
from parlai.utils.fsdp import PYTORCH_FSDP_AVAILABLE

return unittest.skipUnless(PYTORCH_FSDP_AVAILABLE, reason)(testfn)


class retry(object):
"""
Decorator for flaky tests. Test is run up to ntries times, retrying on failure.
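A sketch of how the new decorator would be used in a test module (the test name and body are hypothetical):

```python
import unittest
import parlai.utils.testing as testing_utils

class TestZero3(unittest.TestCase):
    @testing_utils.skipUnlessPytorchFSDP
    def test_zero3_train(self):
        # hypothetical: exercise --ddp-backend zero3 end to end
        ...
```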