
Commit e1993fa

BoxiangW, deepakn94, ko3n1g, yueshen2016, and ksivaman committed
ADLR/megatron-lm!2150 - Add support for PyTorch FSDP-2
Co-authored-by: Deepak Narayanan <[email protected]>
Co-authored-by: Oliver Koenig <[email protected]>
Co-authored-by: James Shen <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Keshav Santhanam <[email protected]>
Co-authored-by: jasonwan <[email protected]>
1 parent 64cbae5 commit e1993fa

File tree: 23 files changed (+697, -99 lines)


megatron/core/dist_checkpointing/optimizer.py

Lines changed: 23 additions & 10 deletions
@@ -1,17 +1,20 @@
 # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 
-""" Helpers for defining sharding for optimizer states based on existing sharding for model parameters. """
+""" Helpers for defining sharding for optimizer states based on existing sharding
+for model parameters.
+"""
 
 import logging
 from copy import deepcopy
 from dataclasses import replace
-from itertools import chain
-from typing import Dict, Iterable, List, Tuple, Union
+from typing import Dict, Iterable, Tuple, Union
 
 logger = logging.getLogger(__name__)
 
 import torch
 
+from megatron.core.utils import to_local_if_dtensor
+
 from .dict_utils import nested_values
 from .mapping import (
     LocalNonpersistentObject,
@@ -24,8 +27,10 @@
 
 
 def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) -> Dict[int, int]:
+    """Generate mapping from optimizer param to optimizer state id."""
     param_mappings = {}
     for i, param in enumerate(optim_params_iter):
+        param = to_local_if_dtensor(param)
         if id(param) not in param_mappings:
             param_mappings[id(param)] = i
     return param_mappings
@@ -37,7 +42,8 @@ def get_param_id_to_sharded_param_map(
     """Generate mapping from optimizer state ids to model sharded parameters.
 
     Args:
-        model_sharded_state_dict: sharded state dict with all model sharded tensors (can have any structure)
+        model_sharded_state_dict: sharded state dict with all model sharded tensors
+            (can have any structure)
         optim_params_iter: iterable which iterates over model parameters tracked by the optimizer.
             The iteration must be in the same order as in the optimizer parameters.
 
@@ -48,6 +54,9 @@ def get_param_id_to_sharded_param_map(
     model_sharded_state_dict, _ = extract_sharded_tensors_and_factories(model_sharded_state_dict)
     id_to_sharded_param_map = {}
     param_to_id_map = get_optim_param_to_id_map(optim_params_iter)
+    # If using PyTorch FSDP2 the values in model_sharded_state_dict would
+    # have been converted to local tensors during initialization.
+    # See the make_(tp)_sharded_tensor_for_checkpoint functions.
     for ten in nested_values(model_sharded_state_dict):
         if id(ten.data) in param_to_id_map:
             id_to_sharded_param_map[param_to_id_map[id(ten.data)]] = ten
@@ -76,12 +85,14 @@ def make_sharded_optimizer_tensor(
     Returns:
         Union[ShardedTensor, ShardedTensorFactory]: wrapped optimizer parameter
     """
+    optim_param = to_local_if_dtensor(optim_param)
     if isinstance(model_param, ShardedTensorFactory):
         return replace(model_param, key=f'{prefix}.{model_param.key}', data=optim_param)
 
-    assert (
-        tuple(optim_param.shape) == model_param.local_shape
-    ), f'Optimizer shape ({tuple(optim_param.shape)} does not match model shape ({model_param.local_shape})'
+    assert tuple(optim_param.shape) == model_param.local_shape, (
+        f'Optimizer shape ({tuple(optim_param.shape)} does not match model shape '
+        f'({model_param.local_shape})'
+    )
     sh_ten = replace(
         model_param, key=f'{prefix}.{model_param.key}', data=optim_param, dtype=optim_param.dtype
     )
@@ -102,9 +113,11 @@ def optim_state_to_sharding_state(
 
     Args:
         optim_state_dict (StateDict): optimizer state dict with
-            state parameters under `state` key and group hyperparameters under `param_groups` -> `params` key.
-        id_to_sharded_param_map (Dict[int, ShardedTensor]): mapping from optimizer param ids to model sharded tensors.
-            Can be generated with `get_param_id_to_sharded_param_map` function
+            state parameters under `state` key and group hyperparameters under
+            `param_groups` -> `params` key.
+        id_to_sharded_param_map (Dict[int, ShardedTensor]): mapping from optimizer param ids
+            to model sharded tensors. Can be generated with `get_param_id_to_sharded_param_map`
+            function.
         exclude_keys (Tuple[str]): optimizer state keys to exclude from the final state dict.
 
     Returns:

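The new calls to `to_local_if_dtensor` are what make the id-based parameter mapping work under FSDP2: both the optimizer's parameters and the tensors in `model_sharded_state_dict` must resolve to the same local `torch.Tensor` objects before `id()` comparisons (and shape checks) are meaningful. The helper is imported from `megatron.core.utils`; the Python sketch below is an assumption about its behavior for illustration, not the verbatim implementation.

```
# Assumed sketch of a to_local_if_dtensor-style helper (not the verbatim
# megatron.core.utils implementation).
from typing import Union

import torch

try:
    from torch.distributed._tensor import DTensor

    HAVE_DTENSOR = True
except ImportError:
    HAVE_DTENSOR = False


def to_local_if_dtensor(tensor: Union[torch.Tensor, "DTensor"]) -> torch.Tensor:
    """Return the local shard of a DTensor; pass plain tensors through unchanged."""
    if HAVE_DTENSOR and isinstance(tensor, DTensor):
        return tensor.to_local()
    return tensor
```

With that normalization in place, `get_optim_param_to_id_map` and `make_sharded_optimizer_tensor` see consistent local shapes and object identities whether or not the model is wrapped with PyTorch FSDP2.
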
megatron/core/dist_checkpointing/strategies/torch.py

Lines changed: 10 additions & 3 deletions
@@ -16,7 +16,6 @@
 from torch.distributed._shard.sharded_tensor import Shard
 from torch.distributed._shard.sharded_tensor import ShardedTensor as TorchShardedTensor
 from torch.distributed._shard.sharded_tensor import ShardedTensorMetadata, TensorProperties
-from torch.distributed._tensor import DTensor
 from torch.distributed.checkpoint import (
     BytesStorageMetadata,
     DefaultLoadPlanner,
@@ -34,6 +33,7 @@
 from torch.distributed.checkpoint.metadata import Metadata
 from torch.distributed.checkpoint.planner_helpers import _create_write_items
 
+from ...utils import get_torch_version
 from ..core import CheckpointingException
 from ..dict_utils import nested_values
 from ..mapping import (
@@ -70,6 +70,13 @@
 except ImportError:
     HAVE_TE = False
 
+try:
+    from torch.distributed._tensor import DTensor
+
+    HAVE_DTENSOR = True
+except ImportError:
+    HAVE_DTENSOR = False
+
 
 def register_default_torch_strategies():
     """Register default strategies related to PyT Distributed backend."""
@@ -451,7 +458,7 @@ def __init__(
     ) -> None:
         # `dedup_replicated_tensors` was deprecated in 2.3; this check avoids warnings
         # during saving.
-        if PkgVersion(torch.__version__) <= PkgVersion("2.2"):
+        if get_torch_version() <= PkgVersion("2.2"):
             kwargs['dedup_replicated_tensors'] = dedup_replicated_tensors
         super().__init__(*args, **kwargs)
         self.nd_flattened_global_shapes = nd_flattened_global_shapes or {}
@@ -466,7 +473,7 @@ def create_local_plan(self) -> SavePlan:
         # add those requests on all ranks. We inline a simplified version of this method below.
         write_items = []
         for fqn, obj in self.state_dict.items():
-            assert not isinstance(
+            assert not HAVE_DTENSOR or not isinstance(
                 obj, DTensor
             )  # translation from MCore ShardedTensors shouldn't result in DTensors
             # Create write requests for tensor and bytes values.

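The planner's version gate now goes through `get_torch_version()` instead of parsing `torch.__version__` directly. A plausible motivation is that development builds report strings such as `2.3.0a0+git1234`, which should compare like their base release; the sketch below is an assumption about the helper's behavior, not the actual `megatron.core.utils.get_torch_version`.

```
# Assumed sketch of a get_torch_version-style helper (not the verbatim
# megatron.core.utils implementation): drop pre-release/local segments so a
# nightly such as "2.3.0a0+git1234" compares like "2.3.0".
import torch
from packaging.version import Version as PkgVersion


def get_torch_version() -> PkgVersion:
    """Return the installed torch version truncated to its base release."""
    return PkgVersion(PkgVersion(torch.__version__).base_version)
```

Under this assumption, `get_torch_version() <= PkgVersion("2.2")` is False for any 2.3 build, pre-release or not.
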
megatron/core/distributed/README.md

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+## How to use PyTorch FSDP2?
+
+Add these flags to enable Torch FSDP2:
+
+```
+--use-torch-fsdp2
+--no-gradient-accumulation-fusion
+--ckpt-format torch_dist
+```
+
+Note that CUDA_MAX_CONNECTIONS=1 should not be set, so that FSDP communication can fully overlap with computation on the primary stream.

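For context, a minimal launch sketch showing how these flags might be combined. Only the last three flags come from the README above; the torchrun invocation, script name, and `$OTHER_MEGATRON_ARGS` are placeholders, not part of this commit.

```
# Hypothetical launch sketch; only the last three flags come from the README above.
unset CUDA_MAX_CONNECTIONS            # per the note above, do not set this to 1
torchrun --nproc_per_node=8 pretrain_gpt.py \
    $OTHER_MEGATRON_ARGS \
    --use-torch-fsdp2 \
    --no-gradient-accumulation-fusion \
    --ckpt-format torch_dist
```
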
megatron/core/distributed/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -1,5 +1,8 @@
 # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 
+from packaging.version import Version
+
 from .distributed_data_parallel import DistributedDataParallel
 from .distributed_data_parallel_config import DistributedDataParallelConfig
 from .finalize_model_grads import finalize_model_grads
+from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel
megatron/core/distributed/data_parallel_base.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+from contextlib import contextmanager
+
+import torch
+
+from ..transformer.module import MegatronModule
+from ..transformer.transformer_config import TransformerConfig
+
+
+class _BaseDataParallel(MegatronModule):
+    """A template class for DistributedDataParallel implementations."""
+
+    def __init__(self, config: TransformerConfig, module: torch.nn.Module):
+        super().__init__(config=config)
+        self.module = module
+
+    def forward(self, *inputs, **kwargs):
+        """
+        Calls the wrapped module's forward() method.
+        """
+        return self.module(*inputs, **kwargs)
+
+    @contextmanager
+    def no_sync(self):
+        """
+        Context manager that turns off gradient synchronization.
+        """
+        try:
+            yield
+        finally:
+            pass
+
+    def start_grad_sync(self, *unused):
+        """
+        Initiates grad sync (all-reduce or reduce-scatter) communication operations
+        for all model gradients.
+
+        When overlap_grad_reduce is set to True, dispatches asynchronous communication
+        calls. When overlap_grad_reduce is set to False, calls synchronous
+        communication ops.
+        """
+        pass
+
+    def scale_gradients(self, scaling_factor: float) -> None:
+        """Scale all gradients inside the buffers by `scaling_factor`."""
+        pass
+
+    def finish_grad_sync(self):
+        """
+        Finishes grad sync (all-reduce or reduce-scatter) communication operations
+        for all model gradients.
+
+        When overlap_grad_reduce is set to True, waits for asynchronous communication
+        calls to complete. When overlap_grad_reduce is set to False, calls synchronous
+        communication ops.
+        """
+        pass
+
+    def zero_grad_buffer(self):
+        """
+        Zeros out all grad buffers. Needs to be called at the beginning of each
+        training iteration.
+        """
+        pass
+
+    def broadcast_params(self):
+        """
+        Syncs parameters across all DP ranks.
+        """
+        pass
+
+    def state_dict(self, prefix='', keep_vars=False):
+        """
+        Returns a dictionary containing references to the whole state of the
+        wrapped module.
+
+        Both parameters and persistent buffers (e.g. running averages) are included.
+        Keys are corresponding parameter and buffer names. Parameters and buffers
+        set to None are not included.
+        """
+        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
+
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
+        """
+        Returns wrapped module's state_dict for checkpoint saving.
+        """
+        return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)
+
+    def load_state_dict(self, state_dict, strict=True):
+        """
+        Copies parameters and buffers from state_dict into the wrapped module and its
+        descendants. If strict is True, then the keys of state_dict must exactly match
+        the keys returned by this module's state_dict() function.
+        """
+        self.module.load_state_dict(state_dict, strict=strict)

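To make the `_BaseDataParallel` contract concrete, here is a hypothetical training-step sketch showing how a caller is expected to drive this interface. `model`, `optimizer`, and `microbatches` are placeholders, and the exact call sequence in Megatron-LM's training loop may differ.

```
# Hypothetical usage sketch (not code from this commit): `model` is any
# _BaseDataParallel subclass wrapping a MegatronModule whose forward returns a loss.
def train_step(model, optimizer, microbatches):
    model.zero_grad_buffer()                # reset grad buffers for this iteration

    with model.no_sync():                   # skip grad sync on all but the last micro-batch
        for batch in microbatches[:-1]:
            model(**batch).backward()

    model(**microbatches[-1]).backward()    # last micro-batch participates in grad sync

    model.start_grad_sync()                 # issue (or no-op) reduce-scatter / all-reduce
    model.finish_grad_sync()                # wait for communication to complete
    optimizer.step()
```

Because every method on the base class is a no-op or a plain delegation to the wrapped module, subclasses such as `DistributedDataParallel` (below) or `TorchFullyShardedDataParallel` only override the hooks they actually need.
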
megatron/core/distributed/distributed_data_parallel.py

Lines changed: 3 additions & 34 deletions
@@ -7,16 +7,16 @@
 
 from .. import parallel_state
 from ..config_logger import has_config_logger_enabled, log_config_to_disk
-from ..transformer.module import MegatronModule
 from ..transformer.transformer_config import TransformerConfig
 from ..utils import is_float8tensor, log_single_rank
+from .data_parallel_base import _BaseDataParallel
 from .distributed_data_parallel_config import DistributedDataParallelConfig
 from .param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets
 
 logger = logging.getLogger(__name__)
 
 
-class DistributedDataParallel(MegatronModule):
+class DistributedDataParallel(_BaseDataParallel):
     """
     DDP wrapper which stores grads in contiguous buffers. Also has option of overlapping
     communication with backprop computation by breaking up full model's gradients into smaller
@@ -41,7 +41,7 @@ def __init__(
         module: torch.nn.Module,
         disable_bucketing: bool = False,
     ):
-        super().__init__(config=config)
+        super().__init__(config=config, module=module)
         if has_config_logger_enabled(config):
             log_config_to_disk(config, locals(), prefix=type(self).__name__)
 
@@ -298,12 +298,6 @@ def disable_forward_pre_hook(self):
         # Force synchronize parameters.
         self.start_param_sync(force_sync=True)
 
-    def forward(self, *inputs, **kwargs):
-        """
-        Calls the wrapped module's forward() method.
-        """
-        return self.module(*inputs, **kwargs)
-
     def _make_forward_pre_hook(self):
         """
         Create a forward pre-hook to wait on all-gather handles when necessary (i.e.,
@@ -458,28 +452,3 @@ def broadcast_params(self):
                 src=torch.distributed.get_global_rank(data_parallel_group, 0),
                 group=data_parallel_group,
             )
-
-    def state_dict(self, prefix='', keep_vars=False):
-        """
-        Returns a dictionary containing references to the whole state of the
-        wrapped module.
-
-        Both parameters and persistent buffers (e.g. running averages) are included.
-        Keys are corresponding parameter and buffer names. Parameters and buffers
-        set to None are not included.
-        """
-        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
-
-    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
-        """
-        Returns wrapped module's state_dict for checkpoint saving.
-        """
-        return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)
-
-    def load_state_dict(self, state_dict, strict=True):
-        """
-        Copies parameters and buffers from state_dict into the wrapped module and its
-        descendants. If strict is True, then the keys of state_dict must exactly match
-        the keys returned by this module's state_dict() function.
-        """
-        self.module.load_state_dict(state_dict, strict=strict)
