
Commit

Merge branch 'boxiangw/fsdp2' into 'main'
Add support for PyTorch FSDP-2

See merge request ADLR/megatron-lm!2150
deepakn94 committed Nov 14, 2024
2 parents ae9c141 + e1993fa commit 4c4215f
Showing 23 changed files with 697 additions and 98 deletions.
33 changes: 23 additions & 10 deletions megatron/core/dist_checkpointing/optimizer.py
@@ -1,17 +1,20 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" Helpers for defining sharding for optimizer states based on existing sharding for model parameters. """
""" Helpers for defining sharding for optimizer states based on existing sharding
for model parameters.
"""

import logging
from copy import deepcopy
from dataclasses import replace
from itertools import chain
from typing import Dict, Iterable, List, Tuple, Union
from typing import Dict, Iterable, Tuple, Union

logger = logging.getLogger(__name__)

import torch

from megatron.core.utils import to_local_if_dtensor

from .dict_utils import nested_values
from .mapping import (
LocalNonpersistentObject,
@@ -24,8 +27,10 @@


def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) -> Dict[int, int]:
"""Generate mapping from optimizer param to optimizer state id."""
param_mappings = {}
for i, param in enumerate(optim_params_iter):
param = to_local_if_dtensor(param)
if id(param) not in param_mappings:
param_mappings[id(param)] = i
return param_mappings
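
The `to_local_if_dtensor` helper imported above is central to the FSDP2 path: parameters wrapped by PyTorch FSDP2 arrive as `DTensor`s and must be unwrapped before their `id()` or shape is inspected. A minimal sketch of such a helper, assuming only the public `DTensor.to_local()` API (not Megatron's actual `megatron.core.utils.to_local_if_dtensor`):

```python
# Minimal sketch of a to_local_if_dtensor-style helper, assuming only the
# public DTensor API; Megatron's real implementation may differ.
import torch

try:
    from torch.distributed._tensor import DTensor

    HAVE_DTENSOR = True
except ImportError:
    HAVE_DTENSOR = False


def to_local_if_dtensor_sketch(tensor: torch.Tensor) -> torch.Tensor:
    """Return the local shard of a DTensor; pass plain tensors through unchanged."""
    if HAVE_DTENSOR and isinstance(tensor, DTensor):
        return tensor.to_local()
    return tensor
```
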
@@ -37,7 +42,8 @@ def get_param_id_to_sharded_param_map(
"""Generate mapping from optimizer state ids to model sharded parameters.
Args:
model_sharded_state_dict: sharded state dict with all model sharded tensors (can have any structure)
model_sharded_state_dict: sharded state dict with all model sharded tensors
(can have any structure)
optim_params_iter: iterable which iterates over model parameters tracked by the optimizer.
The iteration must be in the same order as in the optimizer parameters.
@@ -48,6 +54,9 @@ def get_param_id_to_sharded_param_map(
model_sharded_state_dict, _ = extract_sharded_tensors_and_factories(model_sharded_state_dict)
id_to_sharded_param_map = {}
param_to_id_map = get_optim_param_to_id_map(optim_params_iter)
# If using PyTorch FSDP2 the values in model_sharded_state_dict would
# have been converted to local tensors during initialization.
# See the make_(tp)_sharded_tensor_for_checkpoint functions.
for ten in nested_values(model_sharded_state_dict):
if id(ten.data) in param_to_id_map:
id_to_sharded_param_map[param_to_id_map[id(ten.data)]] = ten
@@ -76,12 +85,14 @@ def make_sharded_optimizer_tensor(
Returns:
Union[ShardedTensor, ShardedTensorFactory]: wrapped optimizer parameter
"""
optim_param = to_local_if_dtensor(optim_param)
if isinstance(model_param, ShardedTensorFactory):
return replace(model_param, key=f'{prefix}.{model_param.key}', data=optim_param)

assert (
tuple(optim_param.shape) == model_param.local_shape
), f'Optimizer shape ({tuple(optim_param.shape)} does not match model shape ({model_param.local_shape})'
assert tuple(optim_param.shape) == model_param.local_shape, (
f'Optimizer shape ({tuple(optim_param.shape)} does not match model shape '
f'({model_param.local_shape})'
)
sh_ten = replace(
model_param, key=f'{prefix}.{model_param.key}', data=optim_param, dtype=optim_param.dtype
)
@@ -102,9 +113,11 @@ def optim_state_to_sharding_state(
Args:
optim_state_dict (StateDict): optimizer state dict with
state parameters under `state` key and group hyperparameters under `param_groups` -> `params` key.
id_to_sharded_param_map (Dict[int, ShardedTensor]): mapping from optimizer param ids to model sharded tensors.
Can be generated with `get_param_id_to_sharded_param_map` function
state parameters under `state` key and group hyperparameters under
`param_groups` -> `params` key.
id_to_sharded_param_map (Dict[int, ShardedTensor]): mapping from optimizer param ids
to model sharded tensors. Can be generated with `get_param_id_to_sharded_param_map`
function.
exclude_keys (Tuple[str]): optimizer state keys to exclude from the final state dict.
Returns:
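
Taken together, these helpers are typically combined when building a sharded optimizer state dict for dist-checkpointing. A hedged usage sketch (the generator over `param_groups` and the in-place mutation follow the docstrings above; this is not a copy of Megatron's actual call sites):

```python
# Hedged sketch of how the helpers above fit together; the optimizer iteration
# and in-place update are based on the docstrings, not Megatron's call sites.
from megatron.core.dist_checkpointing.optimizer import (
    get_param_id_to_sharded_param_map,
    optim_state_to_sharding_state,
)


def build_optimizer_sharded_state(model_sharded_state_dict, optimizer):
    # Map optimizer param ids to the model's ShardedTensors; the parameter
    # iteration order must match the order the optimizer tracks them in.
    id_to_sharded = get_param_id_to_sharded_param_map(
        model_sharded_state_dict,
        (p for group in optimizer.param_groups for p in group['params']),
    )
    # Replace plain optimizer state tensors with ShardedTensors (in place).
    state_dict = optimizer.state_dict()
    optim_state_to_sharding_state(state_dict, id_to_sharded)
    return state_dict
```
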
13 changes: 10 additions & 3 deletions megatron/core/dist_checkpointing/strategies/torch.py
@@ -16,7 +16,6 @@
from torch.distributed._shard.sharded_tensor import Shard
from torch.distributed._shard.sharded_tensor import ShardedTensor as TorchShardedTensor
from torch.distributed._shard.sharded_tensor import ShardedTensorMetadata, TensorProperties
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint import (
BytesStorageMetadata,
DefaultLoadPlanner,
@@ -34,6 +33,7 @@
from torch.distributed.checkpoint.metadata import Metadata
from torch.distributed.checkpoint.planner_helpers import _create_write_items

from ...utils import get_torch_version
from ..core import CheckpointingException
from ..dict_utils import nested_values
from ..mapping import (
@@ -70,6 +70,13 @@
except ImportError:
HAVE_TE = False

try:
from torch.distributed._tensor import DTensor

HAVE_DTENSOR = True
except ImportError:
HAVE_DTENSOR = False


def register_default_torch_strategies():
"""Register default strategies related to PyT Distributed backend."""
@@ -451,7 +458,7 @@ def __init__(
) -> None:
# `dedup_replicated_tensors` was deprecated in 2.3; this check avoids warnings
# during saving.
if PkgVersion(torch.__version__) <= PkgVersion("2.2"):
if get_torch_version() <= PkgVersion("2.2"):
kwargs['dedup_replicated_tensors'] = dedup_replicated_tensors
super().__init__(*args, **kwargs)
self.nd_flattened_global_shapes = nd_flattened_global_shapes or {}
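
The switch from `PkgVersion(torch.__version__)` to `get_torch_version()` presumably normalizes dev/nightly version strings before comparison. A hedged sketch of that idea (the helper name and behavior below are assumptions, not Megatron's actual `megatron.core.utils.get_torch_version`):

```python
# Hedged sketch: torch.__version__ may carry pre-release/local suffixes
# (e.g. "2.3.0a0+git..."), so comparisons are more robust on the base release.
# Illustrative only; not Megatron's implementation.
import torch
from packaging.version import Version as PkgVersion


def torch_base_version() -> PkgVersion:
    # base_version drops pre-release, post-release, dev, and local segments.
    return PkgVersion(PkgVersion(torch.__version__).base_version)
```
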
@@ -466,7 +473,7 @@ def create_local_plan(self) -> SavePlan:
# add those requests on all ranks. We inline a simplified version of this method below.
write_items = []
for fqn, obj in self.state_dict.items():
assert not isinstance(
assert not HAVE_DTENSOR or not isinstance(
obj, DTensor
) # translation from MCore ShardedTensors shouldn't result in DTensors
# Create write requests for tensor and bytes values.
11 changes: 11 additions & 0 deletions megatron/core/distributed/README.md
@@ -0,0 +1,11 @@
## How to use PyTorch FSDP2?

Add these flags to enable PyTorch FSDP2:

```
--use-torch-fsdp2
--no-gradient-accumulation-fusion
--ckpt-format torch_dist
```

Note that `CUDA_MAX_CONNECTIONS=1` should not be set, so that FSDP communication can be fully parallelized with computation on the primary stream.
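
For context, the PyTorch primitive behind `--use-torch-fsdp2` is the `fully_shard` API from `torch.distributed._composable.fsdp`, which turns module parameters into `DTensor`s sharded over a device mesh. A minimal standalone sketch of that primitive (not Megatron's `TorchFullyShardedDataParallel` wrapper; the mesh layout and per-block sharding policy are illustrative assumptions):

```python
# Minimal sketch of the FSDP2 primitive integrated by this MR; this is the raw
# PyTorch API, not Megatron's TorchFullyShardedDataParallel wrapper, and the
# mesh layout / per-block sharding policy here are illustrative assumptions.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard


def shard_with_fsdp2(model: torch.nn.Module, world_size: int) -> torch.nn.Module:
    # 1-D data-parallel mesh over all ranks.
    mesh = init_device_mesh("cuda", (world_size,))
    # Shard each top-level child first, then the root, so parameters become
    # DTensors and unsharded weights are freed again after each block's forward.
    for child in model.children():
        fully_shard(child, mesh=mesh)
    fully_shard(model, mesh=mesh)
    return model
```
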
3 changes: 3 additions & 0 deletions megatron/core/distributed/__init__.py
@@ -1,5 +1,8 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from packaging.version import Version

from .distributed_data_parallel import DistributedDataParallel
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .finalize_model_grads import finalize_model_grads
from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel
96 changes: 96 additions & 0 deletions megatron/core/distributed/data_parallel_base.py
@@ -0,0 +1,96 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from contextlib import contextmanager

import torch

from ..transformer.module import MegatronModule
from ..transformer.transformer_config import TransformerConfig


class _BaseDataParallel(MegatronModule):
"""A template class for DistributedDataParallel implementations."""

def __init__(self, config: TransformerConfig, module: torch.nn.Module):
super().__init__(config=config)
self.module = module

def forward(self, *inputs, **kwargs):
"""
Calls the wrapped module's forward() method.
"""
return self.module(*inputs, **kwargs)

@contextmanager
def no_sync(self):
"""
Context manager that turns off gradient synchronization.
"""
try:
yield
finally:
pass

def start_grad_sync(self, *unused):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, dispatches asynchronous communication
calls. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
pass

def scale_gradients(self, scaling_factor: float) -> None:
"""Scale all gradients inside the buffers by `scaling_factor`."""
pass

def finish_grad_sync(self):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, waits for asynchronous communication
calls to complete. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
pass

def zero_grad_buffer(self):
"""
Zeros out all grad buffers. Needs to be called at the beginning of each
training iteration.
"""
pass

def broadcast_params(self):
"""
Syncs parameters across all DP ranks.
"""
pass

def state_dict(self, prefix='', keep_vars=False):
"""
Returns a dictionary containing references to the whole state of the
wrapped module.
Both parameters and persistent buffers (e.g. running averages) are included.
Keys are corresponding parameter and buffer names. Parameters and buffers
set to None are not included.
"""
return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)

def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""
Returns wrapped module's state_dict for checkpoint saving.
"""
return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)

def load_state_dict(self, state_dict, strict=True):
"""
Copies parameters and buffers from state_dict into the wrapped module and its
descendants. If strict is True, then the keys of state_dict must exactly match
the keys returned by this module’s state_dict() function.
"""
self.module.load_state_dict(state_dict, strict=strict)
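
The template above defines the hooks every data-parallel wrapper shares. As an illustration of how callers interact with this interface, a hedged gradient-accumulation sketch (the model, optimizer, and micro-batch handling are placeholders, not Megatron's actual training loop, which drives these hooks through its own schedules):

```python
# Illustrative gradient-accumulation loop against the _BaseDataParallel
# interface above; model/optimizer/micro-batch handling are placeholders.
def train_step(model, optimizer, micro_batches):
    model.zero_grad_buffer()
    # Suppress grad sync for all but the last micro-batch.
    for batch in micro_batches[:-1]:
        with model.no_sync():
            model(batch).sum().backward()
    model(micro_batches[-1]).sum().backward()
    # Launch and wait for the grad reduction before stepping.
    model.start_grad_sync()
    model.finish_grad_sync()
    optimizer.step()
```
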
37 changes: 3 additions & 34 deletions megatron/core/distributed/distributed_data_parallel.py
@@ -7,16 +7,16 @@

from .. import parallel_state
from ..config_logger import has_config_logger_enabled, log_config_to_disk
from ..transformer.module import MegatronModule
from ..transformer.transformer_config import TransformerConfig
from ..utils import is_float8tensor, log_single_rank
from .data_parallel_base import _BaseDataParallel
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets

logger = logging.getLogger(__name__)


class DistributedDataParallel(MegatronModule):
class DistributedDataParallel(_BaseDataParallel):
"""
DDP wrapper which stores grads in contiguous buffers. Also has option of overlapping
communication with backprop computation by breaking up full model's gradients into smaller
@@ -41,7 +41,7 @@ def __init__(
module: torch.nn.Module,
disable_bucketing: bool = False,
):
super().__init__(config=config)
super().__init__(config=config, module=module)
if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)

@@ -298,12 +298,6 @@ def disable_forward_pre_hook(self):
# Force synchronize parameters.
self.start_param_sync(force_sync=True)

def forward(self, *inputs, **kwargs):
"""
Calls the wrapped module's forward() method.
"""
return self.module(*inputs, **kwargs)

def _make_forward_pre_hook(self):
"""
Create a forward pre-hook to wait on all-gather handles when necessary (i.e.,
@@ -458,28 +452,3 @@ def broadcast_params(self):
src=torch.distributed.get_global_rank(data_parallel_group, 0),
group=data_parallel_group,
)

def state_dict(self, prefix='', keep_vars=False):
"""
Returns a dictionary containing references to the whole state of the
wrapped module.
Both parameters and persistent buffers (e.g. running averages) are included.
Keys are corresponding parameter and buffer names. Parameters and buffers
set to None are not included.
"""
return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)

def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""
Returns wrapped module's state_dict for checkpoint saving.
"""
return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)

def load_state_dict(self, state_dict, strict=True):
"""
Copies parameters and buffers from state_dict into the wrapped module and its
descendants. If strict is True, then the keys of state_dict must exactly match
the keys returned by this module’s state_dict() function.
"""
self.module.load_state_dict(state_dict, strict=strict)
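
With `forward()` and the state-dict methods now inherited from `_BaseDataParallel`, a custom data-parallel wrapper only needs to override the synchronization hooks it actually uses. A toy sketch (the synchronous all-reduce below is illustrative and not how Megatron's bucketed DDP or the FSDP2 wrapper work):

```python
# Toy wrapper built on the new _BaseDataParallel template; the synchronous
# all-reduce here is illustrative, not Megatron's bucketed implementation.
import torch
from megatron.core.distributed.data_parallel_base import _BaseDataParallel


class NaiveDataParallel(_BaseDataParallel):
    """Averages gradients across the default process group in finish_grad_sync()."""

    def finish_grad_sync(self):
        world_size = torch.distributed.get_world_size()
        for param in self.module.parameters():
            if param.grad is not None:
                # Sum gradients across ranks, then average.
                torch.distributed.all_reduce(param.grad)
                param.grad /= world_size
```
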