Merge branch 'boxiangw/fsdp2' into 'main'
Add support for PyTorch FSDP-2. See merge request ADLR/megatron-lm!2150
Showing 23 changed files with 697 additions and 98 deletions.
@@ -0,0 +1,11 @@
## How to use PyTorch FSDP2?

Add these flags to enable Torch FSDP2.

```
--use-torch-fsdp2
--no-gradient-accumulation-fusion
--ckpt-format torch_dist
```

Note that CUDA_MAX_CONNECTIONS=1 should not be set, so that FSDP communication and computation on the primary stream can be fully parallelized.
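For context, here is a minimal sketch of the FSDP2 per-module API that `--use-torch-fsdp2` builds on. It is not the code added by this commit (the actual wrapping lives in `TorchFullyShardedDataParallel` and may differ); `shard_with_fsdp2` and the `model.layers` attribute are assumptions, the import path of `fully_shard` depends on the PyTorch version, and `torch.distributed` is assumed to be initialized.

```
import torch

try:
    # Newer PyTorch releases re-export the FSDP2 entry point here.
    from torch.distributed.fsdp import fully_shard
except ImportError:
    # Earlier FSDP2-capable releases expose it under the composable namespace.
    from torch.distributed._composable.fsdp import fully_shard


def shard_with_fsdp2(model: torch.nn.Module) -> torch.nn.Module:
    """Hypothetical helper: shard each transformer block, then the root module."""
    # `model.layers` is an assumed attribute holding the transformer blocks;
    # real models organize their submodules differently.
    for block in getattr(model, "layers", []):
        fully_shard(block)
    fully_shard(model)
    return model
```

Sharding each block separately groups parameters per block, which is what lets FSDP2's parameter all-gathers overlap with computation on the primary stream.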
@@ -1,5 +1,8 @@
```
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from packaging.version import Version

from .distributed_data_parallel import DistributedDataParallel
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .finalize_model_grads import finalize_model_grads
from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel
```
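A hedged guess at how the newly imported `Version` might be used (the actual check in this commit may differ): FSDP2 only ships with sufficiently recent PyTorch builds, so availability can be gated on the installed version.

```
# Sketch only; the constant name and the minimum version are assumptions,
# not taken from this commit.
import torch
from packaging.version import Version

HAVE_TORCH_FSDP2 = Version(torch.__version__) >= Version("2.4.0")
```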
@@ -0,0 +1,96 @@
```
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from contextlib import contextmanager

import torch

from ..transformer.module import MegatronModule
from ..transformer.transformer_config import TransformerConfig


class _BaseDataParallel(MegatronModule):
    """A template class for DistributedDataParallel implementations."""

    def __init__(self, config: TransformerConfig, module: torch.nn.Module):
        super().__init__(config=config)
        self.module = module

    def forward(self, *inputs, **kwargs):
        """
        Calls the wrapped module's forward() method.
        """
        return self.module(*inputs, **kwargs)

    @contextmanager
    def no_sync(self):
        """
        Context manager that turns off gradient synchronization.
        """
        try:
            yield
        finally:
            pass

    def start_grad_sync(self, *unused):
        """
        Initiates grad sync (all-reduce or reduce-scatter) communication operations
        for all model gradients.
        When overlap_grad_reduce is set to True, dispatches asynchronous communication
        calls. When overlap_grad_reduce is set to False, calls synchronous
        communication ops.
        """
        pass

    def scale_gradients(self, scaling_factor: float) -> None:
        """Scale all gradients inside the buffers by `scaling_factor`."""
        pass

    def finish_grad_sync(self):
        """
        Finishes grad sync (all-reduce or reduce-scatter) communication operations
        for all model gradients.
        When overlap_grad_reduce is set to True, waits for asynchronous communication
        calls to complete. When overlap_grad_reduce is set to False, calls synchronous
        communication ops.
        """
        pass

    def zero_grad_buffer(self):
        """
        Zeros out all grad buffers. Needs to be called at the beginning of each
        training iteration.
        """
        pass

    def broadcast_params(self):
        """
        Syncs parameters across all DP ranks.
        """
        pass

    def state_dict(self, prefix='', keep_vars=False):
        """
        Returns a dictionary containing references to the whole state of the
        wrapped module.
        Both parameters and persistent buffers (e.g. running averages) are included.
        Keys are corresponding parameter and buffer names. Parameters and buffers
        set to None are not included.
        """
        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """
        Returns wrapped module's state_dict for checkpoint saving.
        """
        return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        """
        Copies parameters and buffers from state_dict into the wrapped module and its
        descendants. If strict is True, then the keys of state_dict must exactly match
        the keys returned by this module's state_dict() function.
        """
        self.module.load_state_dict(state_dict, strict=strict)
```
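To make the template concrete, here is a hedged sketch of how a subclass could map `no_sync()` onto an FSDP2-wrapped module. This is not the `TorchFullyShardedDataParallel` implementation added by this commit; it assumes `self.module` has already been wrapped with `fully_shard()`, which gives it FSDP2's `set_requires_gradient_sync()` method.

```
from contextlib import contextmanager


class _SketchTorchFSDP2(_BaseDataParallel):
    """Illustrative subclass only; the real implementation may differ."""

    @contextmanager
    def no_sync(self):
        # Disable gradient reduce-scatter inside the context and restore it
        # on exit, mirroring the no_sync() contract of the base class.
        self.module.set_requires_gradient_sync(False)
        try:
            yield
        finally:
            self.module.set_requires_gradient_sync(True)
```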