diff --git a/megatron/core/dist_checkpointing/optimizer.py b/megatron/core/dist_checkpointing/optimizer.py index 2d231a24ff..b3fcc7c645 100644 --- a/megatron/core/dist_checkpointing/optimizer.py +++ b/megatron/core/dist_checkpointing/optimizer.py @@ -1,17 +1,20 @@ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. -""" Helpers for defining sharding for optimizer states based on existing sharding for model parameters. """ +""" Helpers for defining sharding for optimizer states based on existing sharding +for model parameters. +""" import logging from copy import deepcopy from dataclasses import replace -from itertools import chain -from typing import Dict, Iterable, List, Tuple, Union +from typing import Dict, Iterable, Tuple, Union logger = logging.getLogger(__name__) import torch +from megatron.core.utils import to_local_if_dtensor + from .dict_utils import nested_values from .mapping import ( LocalNonpersistentObject, @@ -24,8 +27,10 @@ def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) -> Dict[int, int]: + """Generate mapping from optimizer param to optimizer state id.""" param_mappings = {} for i, param in enumerate(optim_params_iter): + param = to_local_if_dtensor(param) if id(param) not in param_mappings: param_mappings[id(param)] = i return param_mappings @@ -37,7 +42,8 @@ def get_param_id_to_sharded_param_map( """Generate mapping from optimizer state ids to model sharded parameters. Args: - model_sharded_state_dict: sharded state dict with all model sharded tensors (can have any structure) + model_sharded_state_dict: sharded state dict with all model sharded tensors + (can have any structure) optim_params_iter: iterable which iterates over model parameters tracked by the optimizer. The iteration must be in the same order as in the optimizer parameters. @@ -48,6 +54,9 @@ def get_param_id_to_sharded_param_map( model_sharded_state_dict, _ = extract_sharded_tensors_and_factories(model_sharded_state_dict) id_to_sharded_param_map = {} param_to_id_map = get_optim_param_to_id_map(optim_params_iter) + # If using PyTorch FSDP2 the values in model_sharded_state_dict would + # have been converted to local tensors during initialization. + # See the make_(tp)_sharded_tensor_for_checkpoint functions. for ten in nested_values(model_sharded_state_dict): if id(ten.data) in param_to_id_map: id_to_sharded_param_map[param_to_id_map[id(ten.data)]] = ten @@ -76,12 +85,14 @@ def make_sharded_optimizer_tensor( Returns: Union[ShardedTensor, ShardedTensorFactory]: wrapped optimizer parameter """ + optim_param = to_local_if_dtensor(optim_param) if isinstance(model_param, ShardedTensorFactory): return replace(model_param, key=f'{prefix}.{model_param.key}', data=optim_param) - assert ( - tuple(optim_param.shape) == model_param.local_shape - ), f'Optimizer shape ({tuple(optim_param.shape)} does not match model shape ({model_param.local_shape})' + assert tuple(optim_param.shape) == model_param.local_shape, ( + f'Optimizer shape ({tuple(optim_param.shape)} does not match model shape ' + f'({model_param.local_shape})' + ) sh_ten = replace( model_param, key=f'{prefix}.{model_param.key}', data=optim_param, dtype=optim_param.dtype ) @@ -102,9 +113,11 @@ def optim_state_to_sharding_state( Args: optim_state_dict (StateDict): optimizer state dict with - state parameters under `state` key and group hyperparameters under `param_groups` -> `params` key. - id_to_sharded_param_map (Dict[int, ShardedTensor]): mapping from optimizer param ids to model sharded tensors. 
- Can be generated with `get_param_id_to_sharded_param_map` function + state parameters under `state` key and group hyperparameters under + `param_groups` -> `params` key. + id_to_sharded_param_map (Dict[int, ShardedTensor]): mapping from optimizer param ids + to model sharded tensors. Can be generated with `get_param_id_to_sharded_param_map` + function. exclude_keys (Tuple[str]): optimizer state keys to exclude from the final state dict. Returns: diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py index 01f6923ae7..d7ec055a08 100644 --- a/megatron/core/dist_checkpointing/strategies/torch.py +++ b/megatron/core/dist_checkpointing/strategies/torch.py @@ -16,7 +16,6 @@ from torch.distributed._shard.sharded_tensor import Shard from torch.distributed._shard.sharded_tensor import ShardedTensor as TorchShardedTensor from torch.distributed._shard.sharded_tensor import ShardedTensorMetadata, TensorProperties -from torch.distributed._tensor import DTensor from torch.distributed.checkpoint import ( BytesStorageMetadata, DefaultLoadPlanner, @@ -34,6 +33,7 @@ from torch.distributed.checkpoint.metadata import Metadata from torch.distributed.checkpoint.planner_helpers import _create_write_items +from ...utils import get_torch_version from ..core import CheckpointingException from ..dict_utils import nested_values from ..mapping import ( @@ -70,6 +70,13 @@ except ImportError: HAVE_TE = False +try: + from torch.distributed._tensor import DTensor + + HAVE_DTENSOR = True +except ImportError: + HAVE_DTENSOR = False + def register_default_torch_strategies(): """Register default strategies related to PyT Distributed backend.""" @@ -451,7 +458,7 @@ def __init__( ) -> None: # `dedup_replicated_tensors` was deprecated in 2.3; this check avoids warnings # during saving. - if PkgVersion(torch.__version__) <= PkgVersion("2.2"): + if get_torch_version() <= PkgVersion("2.2"): kwargs['dedup_replicated_tensors'] = dedup_replicated_tensors super().__init__(*args, **kwargs) self.nd_flattened_global_shapes = nd_flattened_global_shapes or {} @@ -466,7 +473,7 @@ def create_local_plan(self) -> SavePlan: # add those requests on all ranks. We inline a simplified version of this method below. write_items = [] for fqn, obj in self.state_dict.items(): - assert not isinstance( + assert not HAVE_DTENSOR or not isinstance( obj, DTensor ) # translation from MCore ShardedTensors shouldn't result in DTensors # Create write requests for tensor and bytes values. diff --git a/megatron/core/distributed/README.md b/megatron/core/distributed/README.md new file mode 100644 index 0000000000..c4a7528441 --- /dev/null +++ b/megatron/core/distributed/README.md @@ -0,0 +1,11 @@ +## How to use PyTorch FSDP2? + +Add these flags to enable Torch FSDP2. + +``` +--use-torch-fsdp2 +--no-gradient-accumulation-fusion +--ckpt-format torch_dist +``` + +Note that CUDA_DEVICE_MAX_CONNECTIONS should not be set to 1, so that FSDP communication and the computation on the primary stream can be fully parallelized. diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py index 3d4780d5b4..9dbf83c80d 100644 --- a/megatron/core/distributed/__init__.py +++ b/megatron/core/distributed/__init__.py @@ -1,5 +1,8 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+from packaging.version import Version + from .distributed_data_parallel import DistributedDataParallel from .distributed_data_parallel_config import DistributedDataParallelConfig from .finalize_model_grads import finalize_model_grads +from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel diff --git a/megatron/core/distributed/data_parallel_base.py b/megatron/core/distributed/data_parallel_base.py new file mode 100644 index 0000000000..aed576a7a3 --- /dev/null +++ b/megatron/core/distributed/data_parallel_base.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from contextlib import contextmanager + +import torch + +from ..transformer.module import MegatronModule +from ..transformer.transformer_config import TransformerConfig + + +class _BaseDataParallel(MegatronModule): + """A template class for DistributedDataParallel implementations.""" + + def __init__(self, config: TransformerConfig, module: torch.nn.Module): + super().__init__(config=config) + self.module = module + + def forward(self, *inputs, **kwargs): + """ + Calls the wrapped module's forward() method. + """ + return self.module(*inputs, **kwargs) + + @contextmanager + def no_sync(self): + """ + Context manager that turns off gradient synchronization. + """ + try: + yield + finally: + pass + + def start_grad_sync(self, *unused): + """ + Initiates grad sync (all-reduce or reduce-scatter) communication operations + for all model gradients. + + When overlap_grad_reduce is set to True, dispatches asynchronous communication + calls. When overlap_grad_reduce is set to False, calls synchronous + communication ops. + """ + pass + + def scale_gradients(self, scaling_factor: float) -> None: + """Scale all gradients inside the buffers by `scaling_factor`.""" + pass + + def finish_grad_sync(self): + """ + Finishes grad sync (all-reduce or reduce-scatter) communication operations + for all model gradients. + + When overlap_grad_reduce is set to True, waits for asynchronous communication + calls to complete. When overlap_grad_reduce is set to False, calls synchronous + communication ops. + """ + pass + + def zero_grad_buffer(self): + """ + Zeros out all grad buffers. Needs to be called at the beginning of each + training iteration. + """ + pass + + def broadcast_params(self): + """ + Syncs parameters across all DP ranks. + """ + pass + + def state_dict(self, prefix='', keep_vars=False): + """ + Returns a dictionary containing references to the whole state of the + wrapped module. + + Both parameters and persistent buffers (e.g. running averages) are included. + Keys are corresponding parameter and buffer names. Parameters and buffers + set to None are not included. + """ + return self.module.state_dict(prefix=prefix, keep_vars=keep_vars) + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """ + Returns wrapped module's state_dict for checkpoint saving. + """ + return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars) + + def load_state_dict(self, state_dict, strict=True): + """ + Copies parameters and buffers from state_dict into the wrapped module and its + descendants. If strict is True, then the keys of state_dict must exactly match + the keys returned by this module’s state_dict() function. 
+ """ + self.module.load_state_dict(state_dict, strict=strict) diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py index 6e5bbd96d7..5c9e1df842 100644 --- a/megatron/core/distributed/distributed_data_parallel.py +++ b/megatron/core/distributed/distributed_data_parallel.py @@ -7,16 +7,16 @@ from .. import parallel_state from ..config_logger import has_config_logger_enabled, log_config_to_disk -from ..transformer.module import MegatronModule from ..transformer.transformer_config import TransformerConfig from ..utils import is_float8tensor, log_single_rank +from .data_parallel_base import _BaseDataParallel from .distributed_data_parallel_config import DistributedDataParallelConfig from .param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets logger = logging.getLogger(__name__) -class DistributedDataParallel(MegatronModule): +class DistributedDataParallel(_BaseDataParallel): """ DDP wrapper which stores grads in contiguous buffers. Also has option of overlapping communication with backprop computation by breaking up full model's gradients into smaller @@ -41,7 +41,7 @@ def __init__( module: torch.nn.Module, disable_bucketing: bool = False, ): - super().__init__(config=config) + super().__init__(config=config, module=module) if has_config_logger_enabled(config): log_config_to_disk(config, locals(), prefix=type(self).__name__) @@ -298,12 +298,6 @@ def disable_forward_pre_hook(self): # Force synchronize parameters. self.start_param_sync(force_sync=True) - def forward(self, *inputs, **kwargs): - """ - Calls the wrapped module's forward() method. - """ - return self.module(*inputs, **kwargs) - def _make_forward_pre_hook(self): """ Create a forward pre-hook to wait on all-gather handles when necessary (i.e., @@ -458,28 +452,3 @@ def broadcast_params(self): src=torch.distributed.get_global_rank(data_parallel_group, 0), group=data_parallel_group, ) - - def state_dict(self, prefix='', keep_vars=False): - """ - Returns a dictionary containing references to the whole state of the - wrapped module. - - Both parameters and persistent buffers (e.g. running averages) are included. - Keys are corresponding parameter and buffer names. Parameters and buffers - set to None are not included. - """ - return self.module.state_dict(prefix=prefix, keep_vars=keep_vars) - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - """ - Returns wrapped module's state_dict for checkpoint saving. - """ - return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars) - - def load_state_dict(self, state_dict, strict=True): - """ - Copies parameters and buffers from state_dict into the wrapped module and its - descendants. If strict is True, then the keys of state_dict must exactly match - the keys returned by this module’s state_dict() function. - """ - self.module.load_state_dict(state_dict, strict=strict) diff --git a/megatron/core/distributed/finalize_model_grads.py b/megatron/core/distributed/finalize_model_grads.py index 2cbcf84a7b..199366c80b 100644 --- a/megatron/core/distributed/finalize_model_grads.py +++ b/megatron/core/distributed/finalize_model_grads.py @@ -1,15 +1,69 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
-from typing import List, Optional +from typing import List, Optional, Union import torch from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +try: + from torch.distributed._tensor import DTensor, distribute_tensor + + HAVE_DTENSOR = True +except ImportError: + HAVE_DTENSOR = False + from .. import parallel_state from ..transformer.transformer_config import TransformerConfig from ..utils import get_attr_wrapped_model, get_model_config +def _unshard_if_dtensor(tensor: Union[torch.Tensor, "DTensor"]) -> torch.Tensor: + """ + Unshards the input tensor if it is a DTensor and otherwise returns the + tensor unmodified. + + Args: + tensor (Union[torch.Tensor, DTensor]): The tensor to potentially unshard. + + Returns: + An unsharded version of the input tensor if it is a DTensor, or the + input tensor unmodified if it is not a DTensor. + """ + if HAVE_DTENSOR and isinstance(tensor, DTensor): + unsharded_tensor = tensor.full_tensor() + for k, v in vars(tensor).items(): + setattr(unsharded_tensor, k, v) + return unsharded_tensor + return tensor + + +def _reshard_if_dtensor( + tensor_to_shard: torch.Tensor, reference_tensor: Union[torch.Tensor, "DTensor"] +) -> Union[torch.Tensor, "DTensor"]: + """ + Reshards the input tensor to match the sharding configuration of the + reference tensor if the reference tensor is a DTensor. Otherwise, returns + the reference tensor unmodified. + + Args: + tensor_to_shard (torch.Tensor): The tensor to be potentially sharded. + reference_tensor (Union[torch.Tensor, DTensor]): The reference tensor + for the sharding configuration. + + Returns: + Union[torch.Tensor, DTensor]: The sharded tensor matching the reference tensor's + configuration, or the reference tensor itself if it is not a DTensor. + """ + if HAVE_DTENSOR and isinstance(reference_tensor, DTensor): + sharded_tensor = distribute_tensor( + tensor_to_shard, + device_mesh=reference_tensor.device_mesh, + placements=reference_tensor.placements, + ) + for k, v in vars(reference_tensor).items(): + setattr(sharded_tensor, k, v) + return sharded_tensor + return reference_tensor def _allreduce_conditional_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig): """ All-reduce conditional embedding grads. 
@@ -73,8 +127,11 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: Transf model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True) if model_module.share_embeddings_and_output_weights: weight = model_module.shared_embedding_or_output_weight() - grad = weight.main_grad + grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad" + orig_grad = getattr(weight, grad_attr) + grad = _unshard_if_dtensor(orig_grad) torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group()) + setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad)) def _allreduce_position_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig): @@ -95,8 +152,12 @@ def _allreduce_position_embedding_grads(model: List[torch.nn.Module], config: Tr model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True) assert hasattr(model_module, 'position_embeddings') - grad = model_module.position_embeddings.weight.main_grad + weight = model_module.position_embeddings.weight + grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad" + orig_grad = getattr(weight, grad_attr) + grad = _unshard_if_dtensor(orig_grad) torch.distributed.all_reduce(grad, group=parallel_state.get_position_embedding_group()) + setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad)) def _allreduce_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig): @@ -117,6 +178,7 @@ def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: Transformer if parallel_state.get_tensor_model_parallel_world_size() > 1 and ( config.sequence_parallel or config.qk_layernorm ): + params = [] grads = [] for model_chunk in model: for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')(): @@ -126,15 +188,23 @@ def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: Transformer or 'q_layernorm' in name or 'k_layernorm' in name ): - grad = param.main_grad + params.append(param) + grad_attr = "main_grad" if hasattr(param, "main_grad") else "grad" + grad = getattr(param, grad_attr) + grad = _unshard_if_dtensor(grad) grads.append(grad.data) if grads: coalesced = _flatten_dense_tensors(grads) torch.distributed.all_reduce( coalesced, group=parallel_state.get_tensor_model_parallel_group() ) - for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + for param, buf, synced in zip( + params, grads, _unflatten_dense_tensors(coalesced, grads) + ): buf.copy_(synced) + grad_attr = "main_grad" if hasattr(param, "main_grad") else "grad" + orig_grad = getattr(param, grad_attr) + setattr(param, grad_attr, _reshard_if_dtensor(buf, orig_grad)) def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None): diff --git a/megatron/core/distributed/torch_fully_sharded_data_parallel.py b/megatron/core/distributed/torch_fully_sharded_data_parallel.py new file mode 100644 index 0000000000..6d2e84e77b --- /dev/null +++ b/megatron/core/distributed/torch_fully_sharded_data_parallel.py @@ -0,0 +1,115 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from typing import List + +import torch + +try: + from torch.distributed import DeviceMesh + from torch.distributed._composable.fsdp import fully_shard + + HAVE_FSDP = True +except ImportError: + HAVE_FSDP = False + +from .. 
import parallel_state, tensor_parallel +from ..models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from ..models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from ..transformer.transformer_config import TransformerConfig +from ..transformer.transformer_layer import TransformerLayer +from .data_parallel_base import _BaseDataParallel + + +class TorchFullyShardedDataParallel(_BaseDataParallel): + """ + Enables fully sharded data parallelism by wrapping the given model with + the PyTorch FSDP2 API: + https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md + To utilize this class, PyTorch version >= 2.4.0 is required. + + Args: + config: Transformer config object. + module: Underlying model. + sub_modules_to_wrap: List of sub_modules to shard with FSDP. + Parameters within each sub_module will be all-gathered just-in-time. + The default list includes the following submodules derived from the + GPT model architecture: + TransformerLayer (all Transformer layers) + LanguageModelEmbedding (initial embedding layer) + RotaryEmbedding (initial RoPE layer) + tensor_parallel.ColumnParallelLinear (final output layer) + """ + + def __init__( + self, + config: TransformerConfig, + module: torch.nn.Module, + sub_modules_to_wrap: List[torch.nn.Module] = [ + TransformerLayer, + LanguageModelEmbedding, + RotaryEmbedding, + tensor_parallel.ColumnParallelLinear, + ], + **kwargs + ): + + assert ( + HAVE_FSDP + ), 'TorchFullyShardedDataParallel requires PyTorch >= 2.4.0 with FSDP 2 support.' + + super().__init__(config=config, module=module) + self.data_parallel_group = parallel_state.get_data_parallel_group( + with_context_parallel=True + ) + + mesh = DeviceMesh.from_group(self.data_parallel_group, "cuda") + + kwargs = {"mesh": mesh} + + def save_custom_attrs(module): + custom_attrs = {} + for name, param in module.named_parameters(): + attrs = vars(param) + custom_attrs[name] = {k: v for k, v in attrs.items()} + return custom_attrs + + def restore_custom_attrs(module, custom_attrs): + for name, param in module.named_parameters(): + if name in custom_attrs: + for attr_name, attr_value in custom_attrs[name].items(): + setattr(param, attr_name, attr_value) + + # Save the custom attributes on Parameters before FSDP overwrites them. + # See https://github.com/pytorch/pytorch/issues/136929. + attrs = save_custom_attrs(self.module) + + prev_module = None + for sub_module in self.module.modules(): + # Wrap individual submodules to fetch parameters just-in-time rather than + # conservatively fetching all parameters at the start of each iteration. + # See https://github.com/pytorch/pytorch/issues/114299. + if any( + isinstance(sub_module, sub_module_to_wrap) + for sub_module_to_wrap in sub_modules_to_wrap + ): + fully_shard(sub_module, **kwargs) + + # Explicitly set the FSDP backward prefetch schedule to prevent activation + # recomputation from disrupting the automatically generated default schedule. + if config.recompute_granularity is not None: + sub_module.set_modules_to_backward_prefetch( + [prev_module] if prev_module else [] + ) + prev_module = sub_module + + # Wrap the root module as required by the FSDP API. + # See https://github.com/pytorch/pytorch/issues/114299. 
+ fully_shard(self.module, **kwargs) + + restore_custom_attrs(self.module, attrs) + + def load_state_dict(self, state_dict, strict=True): + """ + No-op because tensors are already loaded in-place by + `_load_base_checkpoint` with FSDP2.""" + pass diff --git a/megatron/core/optimizer/clip_grads.py b/megatron/core/optimizer/clip_grads.py index 5308b5412f..5c3a6578f4 100644 --- a/megatron/core/optimizer/clip_grads.py +++ b/megatron/core/optimizer/clip_grads.py @@ -45,6 +45,7 @@ from ..tensor_parallel import param_is_not_tensor_parallel_duplicate from ..transformer.module import param_is_not_shared +from ..utils import get_data_parallel_group_if_dtensor, to_local_if_dtensor def get_grad_norm_fp32( @@ -73,6 +74,12 @@ def get_grad_norm_fp32( if isinstance(grads_for_norm, torch.Tensor): grads_for_norm = [grads_for_norm] + data_parallel_group = None + for grad in grads_for_norm: + data_parallel_group = get_data_parallel_group_if_dtensor(grad, data_parallel_group) + + grads_for_norm = [to_local_if_dtensor(grad) for grad in grads_for_norm] + # Norm parameters. norm_type = float(norm_type) total_norm = 0.0 @@ -81,7 +88,11 @@ def get_grad_norm_fp32( if norm_type == inf: total_norm = max(grad.abs().max() for grad in grads_for_norm) total_norm_cuda = torch.tensor([float(total_norm)], dtype=torch.float, device='cuda') - # Take max across all model-parallel GPUs. + # Take max across all data-parallel GPUs if using FSDP and then all model-parallel GPUs. + if data_parallel_group: + torch.distributed.all_reduce( + total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=data_parallel_group + ) torch.distributed.all_reduce( total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=grad_stats_parallel_group ) @@ -111,7 +122,11 @@ def get_grad_norm_fp32( grad_norm = torch.norm(grad, norm_type) total_norm += grad_norm**norm_type - # Sum across all model-parallel GPUs. + # Sum across all data-parallel GPUs if using FSDP and then all model-parallel GPUs. + if data_parallel_group: + torch.distributed.all_reduce( + total_norm, op=torch.distributed.ReduceOp.SUM, group=data_parallel_group + ) torch.distributed.all_reduce( total_norm, op=torch.distributed.ReduceOp.SUM, group=grad_stats_parallel_group ) @@ -136,11 +151,13 @@ def clip_grad_by_total_norm_fp32( total_norm (float): total norm of the gradients. """ # Grads. + params = [] grads = [] for param in parameters: if param.grad is not None: assert param.grad.type() == 'torch.cuda.FloatTensor' - grads.append(param.grad.detach()) + params.append(param) + grads.append(to_local_if_dtensor(param.grad).detach()) # Scale. clip_coeff = max_norm / (total_norm + 1.0e-6) @@ -175,15 +192,24 @@ def count_zeros_fp32( # - parameter should not be shared # - should not be a replica due to tensor model parallelism total_num_zeros = torch.tensor([0.0], dtype=torch.float, device='cuda') + data_parallel_group = None for param in parameters: grad_not_none = param.grad is not None is_not_shared = param_is_not_shared(param) is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) if grad_not_none and is_not_shared and is_not_tp_duplicate: - grad = param.grad.detach() + data_parallel_group = get_data_parallel_group_if_dtensor( + param.grad, data_parallel_group + ) + grad = to_local_if_dtensor(param.grad).detach() num_zeros = grad.numel() - torch.count_nonzero(grad) total_num_zeros = num_zeros + total_num_zeros + # Sum across all data-parallel GPUs if using FSDP. 
+ if data_parallel_group: + torch.distributed.all_reduce( + total_num_zeros, op=torch.distributed.ReduceOp.SUM, group=data_parallel_group + ) # Sum across all model-parallel GPUs. torch.distributed.all_reduce( total_num_zeros, op=torch.distributed.ReduceOp.SUM, group=grad_stats_parallel_group diff --git a/megatron/core/optimizer/optimizer.py b/megatron/core/optimizer/optimizer.py index 668cefe453..af9861396e 100644 --- a/megatron/core/optimizer/optimizer.py +++ b/megatron/core/optimizer/optimizer.py @@ -757,7 +757,8 @@ def prepare_grads(self) -> bool: ) for param_group in self.optimizer.param_groups: for param in param_group['params']: - param.grad = param.main_grad + if hasattr(param, 'main_grad'): + param.grad = param.main_grad if timers is not None: timers('optimizer-copy-to-main-grad').stop() diff --git a/megatron/core/transformer/mlp.py b/megatron/core/transformer/mlp.py index e82d6ecd20..cead6d466a 100644 --- a/megatron/core/transformer/mlp.py +++ b/megatron/core/transformer/mlp.py @@ -1,13 +1,12 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Optional, Union import numpy as np import torch import torch.nn.functional as F -from megatron.core import parallel_state from megatron.core.dist_checkpointing import ShardedTensor from megatron.core.dist_checkpointing.mapping import ( ReplicaId, @@ -20,7 +19,6 @@ from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint @dataclass @@ -59,7 +57,8 @@ def __init__( self.input_size = input_size if input_size != None else self.config.hidden_size - # If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf + # If this is a gated linear unit we double the output width + # see https://arxiv.org/pdf/2002.05202.pdf ffn_hidden_size = self.config.ffn_hidden_size if self.config.gated_linear_unit: ffn_hidden_size *= 2 @@ -93,7 +92,7 @@ def __init__( ) def forward(self, hidden_states): - + """Perform the forward pass through the MLP block.""" # [s, b, 4 * h/p] intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states) @@ -149,19 +148,26 @@ def apply_swiglu_sharded_factory(original_sh_ten, sharded_offsets): # We must split the tensor into 2 parts, each sharded separately. 
# This requires a ShardedTensorFactory which `chunk`s during saving # and `cat`s during loading - tp_rank = parallel_state.get_tensor_model_parallel_rank() - tp_size = parallel_state.get_tensor_model_parallel_world_size() + swiglu_shard_axis = 0 prepend_axis_num = len(sharded_offsets) original_shape = original_sh_ten.local_shape original_numel = int(np.prod(original_shape)) + local_axis_size = original_shape[swiglu_shard_axis] + assert ( + original_sh_ten.global_offset[swiglu_shard_axis + prepend_axis_num] % local_axis_size == 0 + ) + rank_offset = ( + original_sh_ten.global_offset[swiglu_shard_axis + prepend_axis_num] // local_axis_size + ) + axis_frag = original_sh_ten.axis_fragmentations[swiglu_shard_axis + prepend_axis_num] @torch.no_grad() def sh_ten_build_fn( key: str, t: torch.Tensor, replica_id: ReplicaId, flattened_range: Optional[slice] ): - offset_w = (swiglu_shard_axis + prepend_axis_num, tp_rank, tp_size * 2) - offset_v = (swiglu_shard_axis + prepend_axis_num, tp_size + tp_rank, tp_size * 2) + offset_w = (swiglu_shard_axis + prepend_axis_num, rank_offset, axis_frag * 2) + offset_v = (swiglu_shard_axis + prepend_axis_num, rank_offset + axis_frag, axis_frag * 2) if flattened_range is None: tensor_w, tensor_v = torch.chunk(t, 2, dim=swiglu_shard_axis) return [ diff --git a/megatron/core/utils.py b/megatron/core/utils.py index 6f9b24d39c..6b1bbe7d5f 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -22,6 +22,13 @@ import torch from packaging.version import Version as PkgVersion +try: + from torch.distributed._tensor import DTensor + + HAVE_DTENSOR = True +except ImportError: + HAVE_DTENSOR = False + from megatron.core import parallel_state from megatron.core.dist_checkpointing.mapping import ShardedTensor @@ -36,6 +43,23 @@ _te_version = None +def get_torch_version(): + """Get pytorch version from __version__; if not available use pip's. Use caching.""" + + def get_torch_version_str(): + import torch + + if hasattr(torch, '__version__'): + return str(torch.__version__) + else: + return version("torch") + + global _torch_version + if _torch_version is None: + _torch_version = PkgVersion(get_torch_version_str()) + return _torch_version + + def get_te_version(): """Get TE version from __version__; if not available use pip's. Use caching.""" @@ -368,21 +392,39 @@ def make_tp_sharded_tensor_for_checkpoint( Optionally, can provide offsets which prepend new dimensions to the tensor. 
""" - prepend_axis_num = len(prepend_offsets) + new_offsets = [] + tp_rank = parallel_state.get_tensor_model_parallel_rank() + dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) + tp_size = parallel_state.get_tensor_model_parallel_world_size() + dp_size = parallel_state.get_data_parallel_world_size(with_context_parallel=True) + dp_replica_id = parallel_state.get_data_parallel_rank(with_context_parallel=True) + + new_offsets.append((tp_axis + prepend_axis_num, tp_rank, tp_size)) + + if HAVE_DTENSOR and isinstance(tensor, DTensor): + # TP + FSDP2 sharding + dp_replica_id = 0 + tensor = tensor._local_tensor + + if tp_axis == 0: + # both FSDP2 and TP shards axis 0 + # default MCore uses tp-cp-ep-dp-pp + # FSDP2 is compatibile with TP, CP + new_offsets[0] = (prepend_axis_num, tp_rank * dp_size + dp_rank, tp_size * dp_size) + else: + # FSDP2 shards axis 0 and TP shards some other axis + new_offsets.append((prepend_axis_num, dp_rank, dp_size)) + if replica_id is None: - replica_id = (0, 0, parallel_state.get_data_parallel_rank(with_context_parallel=True)) + replica_id = (0, 0, dp_replica_id) return ShardedTensor.from_rank_offsets( key, tensor, *prepend_offsets, - ( - tp_axis + prepend_axis_num, - parallel_state.get_tensor_model_parallel_rank(), - parallel_state.get_tensor_model_parallel_world_size(), - ), + *new_offsets, replica_id=replica_id, prepend_axis_num=prepend_axis_num, **kwargs, @@ -397,23 +439,48 @@ def make_sharded_tensor_for_checkpoint(tensor, key, prepend_offsets=(), replica_ prepend_axis_num = len(prepend_offsets) + new_offsets = [] + dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) + dp_size = parallel_state.get_data_parallel_world_size(with_context_parallel=True) + dp_replica_id = parallel_state.get_data_parallel_rank(with_context_parallel=True) + + if HAVE_DTENSOR and isinstance(tensor, DTensor): + # FSDP2 sharding + dp_replica_id = 0 + tensor = tensor._local_tensor + new_offsets.append((prepend_axis_num, dp_rank, dp_size)) + if replica_id is None: - replica_id = ( - 0, - parallel_state.get_tensor_model_parallel_rank(), - parallel_state.get_data_parallel_rank(with_context_parallel=True), - ) + replica_id = (0, parallel_state.get_tensor_model_parallel_rank(), dp_replica_id) return ShardedTensor.from_rank_offsets( key, tensor, *prepend_offsets, + *new_offsets, replica_id=replica_id, prepend_axis_num=prepend_axis_num, **kwargs, ) +def to_local_if_dtensor(tensor: Union[torch.Tensor, "DTensor"]) -> torch.Tensor: + """Returns the local shard of the given tensor if it is a DTensor.""" + with torch.no_grad(): + return tensor.to_local() if HAVE_DTENSOR and isinstance(tensor, DTensor) else tensor + + +def get_data_parallel_group_if_dtensor( + tensor: Union[torch.Tensor, "DTensor"], data_parallel_group: "ProcessGroup" = None +) -> Optional["ProcessGroup"]: + """Gets the data parallel group of the given tensor if it is a DTensor.""" + if HAVE_DTENSOR and isinstance(tensor, DTensor): + current_group = tensor.device_mesh.get_group() + assert data_parallel_group is None or current_group == data_parallel_group + return current_group + return None + + def prepare_input_tensors_for_wgrad_compute(grad_output, all_gathered_input): """Ensure grad_output is stored in a contiguous buffer.""" # Doing gather + slicing during the NeMo forward pass can make this tensor diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index e034a32153..5791aecb04 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ 
-9,6 +9,8 @@ import os import torch import types +import warnings +from packaging.version import Version as PkgVersion import torch.nn.functional as F @@ -214,9 +216,6 @@ def validate_args(args, defaults={}): args.pipeline_model_parallel_size -= args.encoder_pipeline_model_parallel_size assert args.pipeline_model_parallel_size > 0 - if args.tp_comm_overlap: - assert args.sequence_parallel == True, 'Tensor parallel communication/GEMM overlap can happen only when sequence parallelism is enabled' - # Deprecated arguments assert args.batch_size is None, '--batch-size argument is no longer ' \ 'valid, use --micro-batch-size instead' @@ -304,6 +303,24 @@ def validate_args(args, defaults={}): 'Must use --overlap-param-gather with --overlap-grad-reduce' assert not args.use_legacy_models, \ '--overlap-param-gather only supported with MCore models' + + if getattr(args, "use_torch_fsdp2", False): + assert get_torch_version() >= PkgVersion("2.4"), \ 'FSDP2 requires PyTorch >= 2.4.0 with FSDP 2 support.' + assert args.pipeline_model_parallel_size == 1, \ '--use-torch-fsdp2 is not supported with pipeline parallelism' + assert args.expert_model_parallel_size == 1, \ '--use-torch-fsdp2 is not supported with expert parallelism' + assert not args.use_distributed_optimizer, \ "--use-torch-fsdp2 is not supported with MCore's distributed optimizer" + assert not args.gradient_accumulation_fusion, \ '--use-torch-fsdp2 is not supported with gradient accumulation fusion' + assert args.ckpt_format == 'torch_dist', \ '--use-torch-fsdp2 requires --ckpt-format torch_dist' + assert args.untie_embeddings_and_output_weights, \ '--use-torch-fsdp2 requires --untie-embeddings-and-output-weights' + assert not args.fp16, \ '--use-torch-fsdp2 is not supported with fp16 yet' if args.overlap_param_gather_with_optimizer_step: assert args.use_distributed_optimizer, \ @@ -500,12 +517,24 @@ def validate_args(args, defaults={}): # to avoid change in numerics when # sequence_parallelism is enabled. if args.tensor_model_parallel_size == 1: + if args.sequence_parallel: + warnings.warn("Disabling sequence parallelism because tensor model parallelism is disabled") args.sequence_parallel = False + if args.tp_comm_overlap: + assert args.sequence_parallel == True, 'Tensor parallel communication/GEMM overlap can happen only when sequence parallelism is enabled' + # disable async_tensor_model_parallel_allreduce when # model parallel memory optimization is enabled if args.sequence_parallel: args.async_tensor_model_parallel_allreduce = False + if getattr(args, "use_torch_fsdp2", False): + warnings.warn( + "Using sequence parallelism and FSDP2 together is not recommended because " + "they require different CUDA_DEVICE_MAX_CONNECTIONS settings for best " + "performance: sequence parallelism requires setting the environment variable " + "CUDA_DEVICE_MAX_CONNECTIONS to 1, while FSDP2 performs better when " + "CUDA_DEVICE_MAX_CONNECTIONS is not set to 1.") if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1": if args.sequence_parallel: @@ -1143,6 +1172,10 @@ def _add_training_args(parser): dest='use_pytorch_profiler') group.add_argument('--profile-ranks', nargs='+', type=int, default=[0], help='Global ranks to profile.') + group.add_argument('--record-memory-history', action="store_true", default=False, + help='Record memory history on the last rank.') + group.add_argument('--memory-snapshot-path', type=str, default="snapshot.pickle", + help='Specifies where to dump the memory history pickle.') group.add_argument('--tp-comm-overlap', action='store_true', help='Enables the ' ' overlap of Tensor parallel communication and GEMM kernels.') group.add_argument('--tp-comm-overlap-cfg', type=str, default=None, @@ -1605,6 +1638,9 @@ def _add_distributed_args(parser): 'affects the encoder embedding.)') group.add_argument('--use-distributed-optimizer', action='store_true', help='Use distributed optimizer.') + group.add_argument('--use-torch-fsdp2', action='store_true', + help="Use the torch FSDP2 implementation. FSDP2 does not currently work with pipeline parallelism. " + "It is not yet in a stable release stage and may therefore contain bugs or other potential issues.") group.add_argument('--context-parallel-size', type=int, default=1, help='Degree of context parallelism.') group.add_argument('--nccl-communicator-config-path', type=str, default=None, diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index efe98e94e9..1bf86672c3 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -992,11 +992,15 @@ def fix_fp8_params_lose_precision_when_loading_dist_ckpt(state_dict): def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True, - ft_client=None, checkpointing_context=None): + ft_client=None, checkpointing_context=None, skip_load_to_model_and_opt=False): """Load a model checkpoint and return the iteration. strict (bool): whether to strictly enforce that the keys in :attr:`state_dict` of the checkpoint match the names of parameters and buffers in model. + skip_load_to_model_and_opt (bool): whether to call `load_state_dict` + for :attr:`model` and :attr:`optimizer`. In case of running FSDP2 + or other torch features that use DTensor in the state dict, the tensors + are already loaded in-place by `_load_base_checkpoint`. """ args = get_args() load_dir = getattr(args, load_arg) @@ -1164,12 +1168,13 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri # Model. strict = False if args.retro_add_retriever else strict - if len(model) == 1: - model[0].load_state_dict(state_dict['model'], strict=strict) - else: - for i in range(len(model)): - mpu.set_virtual_pipeline_model_parallel_rank(i) - model[i].load_state_dict(state_dict['model%d' % i], strict=strict) + if not skip_load_to_model_and_opt: + if len(model) == 1: + model[0].load_state_dict(state_dict['model'], strict=strict) + else: + for i in range(len(model)): + mpu.set_virtual_pipeline_model_parallel_rank(i) + model[i].load_state_dict(state_dict['model%d' % i], strict=strict) # Fix up query/key/value matrix ordering if needed.
checkpoint_version = get_checkpoint_version() @@ -1180,7 +1185,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri if not release and not args.finetune and not args.no_load_optim: try: # Load state dict. - if optimizer is not None: + if not skip_load_to_model_and_opt and optimizer is not None: optimizer.load_state_dict(state_dict['optimizer']) # Load distributed optimizer's custom parameter state. diff --git a/megatron/training/training.py b/megatron/training/training.py index 0984ee376f..400450782d 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -32,6 +32,13 @@ from megatron.legacy.model import Float16Module from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.distributed import DistributedDataParallel as DDP +try: + from megatron.core.distributed import TorchFullyShardedDataParallel as torch_FSDP + + HAVE_FSDP2 = True +except ImportError: + HAVE_FSDP2 = False + from megatron.core.distributed import finalize_model_grads from megatron.core.enums import ModelType from megatron.core.optimizer import get_megatron_optimizer, OptimizerConfig @@ -541,6 +548,12 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap fp8_meta.amax_history[0][fp8_meta_index] = 0 if wrap_with_ddp: + if getattr(args, "use_torch_fsdp2", False): + assert HAVE_FSDP2, "Torch FSDP2 requires torch>=2.4.0" + DP = torch_FSDP + else: + DP = DDP + config = get_model_config(model[0]) kwargs = {} @@ -554,9 +567,9 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap ddp_config = DistributedDataParallelConfig(**kwargs) overlap_param_gather_with_optimizer_step = getattr(args, 'overlap_param_gather_with_optimizer_step', False) - model = [DDP(config, - ddp_config, - model_chunk, + model = [DP(config=config, + ddp_config=ddp_config, + module=model_chunk, # Turn off bucketing for model_chunk 2 onwards, since communication for these # model chunks is overlapped with compute anyway. 
disable_bucketing=(model_chunk_idx > 0) or overlap_param_gather_with_optimizer_step) @@ -687,7 +700,8 @@ def setup_model_and_optimizer(model_provider_func, args.iteration, args.num_floating_point_operations_so_far = load_checkpoint( model, optimizer, opt_param_scheduler, - ft_client=ft_integration.get_rank_monitor_client(), checkpointing_context=checkpointing_context) + ft_client=ft_integration.get_rank_monitor_client(), checkpointing_context=checkpointing_context, + skip_load_to_model_and_opt=HAVE_FSDP2 and getattr(args, "use_torch_fsdp2", False)) timers('load-checkpoint').stop(barrier=True) timers.log(['load-checkpoint']) one_logger and one_logger.log_metrics({ @@ -885,6 +899,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r timers.write(timers_to_log, writer, iteration, normalizer=total_iterations) if writer and (iteration % args.tensorboard_log_interval == 0): + if args.record_memory_history and is_last_rank(): + snapshot = torch.cuda.memory._snapshot() + from pickle import dump + with open(args.memory_snapshot_path , 'wb') as f: + dump(snapshot, f) + if wandb_writer: wandb_writer.log({'samples vs steps': args.consumed_train_samples}, iteration) diff --git a/megatron/training/utils.py b/megatron/training/utils.py index 1950584a00..60480bf6b4 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -37,11 +37,15 @@ from megatron.core import DistributedDataParallel as DDP from megatron.core import mpu from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate +from megatron.core.utils import get_data_parallel_group_if_dtensor, to_local_if_dtensor from megatron.legacy.model import Float16Module from megatron.legacy.model.module import param_is_not_shared - -ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module) +try: + from megatron.core.distributed import TorchFullyShardedDataParallel as torch_FSDP + ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, torch_FSDP, Float16Module) +except ImportError: + ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module) def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES): @@ -66,17 +70,23 @@ def calc_params_l2_norm(model): model = [model] # Remove duplicate params. params_data = [] - for model_ in model: - for param in model_.parameters(): + data_parallel_group = None + + for model_chunk in model: + for i, param in enumerate(model_chunk.parameters()): + data_parallel_group = get_data_parallel_group_if_dtensor(param, data_parallel_group) is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) if not (param.requires_grad and is_not_tp_duplicate): continue + assert is_not_tp_duplicate if mpu.get_expert_model_parallel_rank() > 0: if not getattr(param, 'allreduce', True): assert param_is_not_shared(param) + param = to_local_if_dtensor(param) params_data.append(param.data.float() if args.bf16 else param.data) else: if param_is_not_shared(param): + param = to_local_if_dtensor(param) params_data.append(param.data.float() if args.bf16 else param.data) # Calculate norm @@ -88,6 +98,12 @@ def calc_params_l2_norm(model): False # no per-parameter norm ) norm_2 = norm * norm + + if data_parallel_group is not None: + torch.distributed.all_reduce(norm_2, + op=torch.distributed.ReduceOp.SUM, + group=data_parallel_group) + if mpu.get_expert_model_parallel_world_size() == 1: # Sum across all model-parallel GPUs(tensor + pipeline). 
torch.distributed.all_reduce(norm_2, diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 3b7f8db012..4fc4a79809 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -53,6 +53,14 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat args = get_args() use_te = args.transformer_impl == "transformer_engine" + if args.record_memory_history: + torch.cuda.memory._record_memory_history(True, + # keep 100,000 alloc/free events from before the snapshot + trace_alloc_max_entries=100000, + + # record stack information for the trace events + trace_alloc_record_context=True) + print_rank_0('building GPT model ...') # Experimental loading arguments from yaml if args.yaml_cfg is not None: diff --git a/tests/functional_tests/jet_recipes/gpt.yaml b/tests/functional_tests/jet_recipes/gpt.yaml index bd79f05759..2d722adeef 100644 --- a/tests/functional_tests/jet_recipes/gpt.yaml +++ b/tests/functional_tests/jet_recipes/gpt.yaml @@ -66,6 +66,7 @@ products: - gpt3_mr_mcore_te_tp1_pp4_vp1_resume_torch_dist_dist_optimizer_overlap_grad_reduce_dgx_a100_1N8G - gpt3_mr_mcore_te_tp1_pp4_vp1_resume_torch_dist_dist_optimizer_overlap_grad_reduce_param_gather_dgx_a100_1N8G - gpt3_mr_mcore_te_tp1_pp4_vp1_resume_torch_dist_dist_optimizer_overlap_grad_reduce_untied_dgx_a100_1N8G + # - gpt3_mr_mcore_te_tp2_pp1_fsdp2_resume_torch_dist_dgx_a100_1N8G # torch >= 2.4.0 - gpt3_mr_mcore_te_tp1_pp4_vp1_resume_torch_dist_tunable_overlap_dgx_a100_1N8G - gpt3_mr_mcore_te_tp2_pp1_resume_torch_dist_te_8experts2parallel_dist_optimizer_dgx_a100_1N8G - gpt3_mr_mcore_te_tp2_pp1_resume_torch_dist_te_8experts2parallel_groupedGEMM_dgx_a100_1N8G @@ -113,6 +114,7 @@ products: n_repeat: [5] test_case: - gpt3_345m_nightly_dgx_a100_1N8G_mcore_tp1_pp1_dist_optimizer_overlap_grad_reduce_param_gather + # - gpt3_345m_nightly_dgx_a100_1N8G_mcore_tp1_pp1_fsdp2_resume_torch_dist_te # torch >= 2.4.0 - gpt3_345m_nightly_dgx_a100_1N8G_mcore_tp1_pp1_resume_torch_dist_dist_optimizer_overlap_grad_reduce_param_gather - gpt3_345m_nightly_dgx_a100_1N8G_mcore_tp1_pp2 - gpt3_345m_nightly_dgx_a100_1N8G_mcore_tp1_pp2_resume_torch_dist diff --git a/tests/functional_tests/test_cases/gpt/gpt3_345m_nightly_dgx_a100_1N8G_mcore_tp1_pp1_fsdp2_resume_torch_dist_te/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_345m_nightly_dgx_a100_1N8G_mcore_tp1_pp1_fsdp2_resume_torch_dist_te/model_config.yaml new file mode 100644 index 0000000000..da4f2c131d --- /dev/null +++ b/tests/functional_tests/test_cases/gpt/gpt3_345m_nightly_dgx_a100_1N8G_mcore_tp1_pp1_fsdp2_resume_torch_dist_te/model_config.yaml @@ -0,0 +1,52 @@ +ENV_VARS: + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0 + NCCL_ALGO: Tree + CUBLAS_WORKSPACE_CONFIG: :4096:8 + N_REPEATS: 10 +MODEL_ARGS: + --num-layers: 12 + --hidden-size: 512 + --num-attention-heads: 8 + --log-params-norm: true + --log-num-zeros-in-grad: true + --log-validation-ppl-to-tensorboard: true + --log-timers-to-tensorboard: true + --tensorboard-dir: ${TENSORBOARD_PATH} + --micro-batch-size: 4 + --global-batch-size: 32 + --seq-length: 1024 + --max-position-embeddings: 1024 + --train-iters: 100 + --timing-log-level: 2 + --lr-decay-iters: 320000 + --save: ${CHECKPOINT_PATH} + --load: ${CHECKPOINT_PATH} + --data-path: ${DATA_PATH}/my-gpt3_00_text_document + --vocab-file: ${DATA_PATH}/bpe/vocab.json + --merge-file: ${DATA_PATH}/bpe/merges.txt + --split: 949,50,1 + --distributed-backend: nccl + --lr: 0.00015 + --lr-decay-style: cosine + --min-lr: 1.0e-5 + --weight-decay: 1e-2 + --clip-grad: 1.0 + --lr-warmup-fraction: .01 + 
--log-interval: 1 + --save-interval: 50 + --eval-interval: 1000 + --eval-iters: 10 + --transformer-impl: transformer_engine + --tensor-model-parallel-size: 1 + --pipeline-model-parallel-size: 1 + --use-torch-fsdp2: true + --deterministic-mode: true + --no-gradient-accumulation-fusion: true + --no-async-tensor-model-parallel-allreduce: true + --use-checkpoint-opt_param-scheduler: true + --use-mcore-models: true + --ckpt-format: torch_dist + --data-cache-path: ${DATA_CACHE_PATH} + --fp16: true + --apply-query-key-layer-scaling: true +TEST_TYPE: ckpt-resume \ No newline at end of file diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp2_pp1_fsdp2_resume_torch_dist_dgx_a100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp2_pp1_fsdp2_resume_torch_dist_dgx_a100_1N8G/model_config.yaml new file mode 100644 index 0000000000..912b9bb533 --- /dev/null +++ b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp2_pp1_fsdp2_resume_torch_dist_dgx_a100_1N8G/model_config.yaml @@ -0,0 +1,52 @@ +ENV_VARS: + CUDA_DEVICE_MAX_CONNECTIONS: 1 + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0 + NCCL_ALGO: Tree + CUBLAS_WORKSPACE_CONFIG: :4096:8 +MODEL_ARGS: + --num-layers: 12 + --hidden-size: 512 + --num-attention-heads: 8 + --log-params-norm: true + --log-num-zeros-in-grad: true + --log-validation-ppl-to-tensorboard: true + --log-timers-to-tensorboard: true + --tensorboard-dir: ${TENSORBOARD_PATH} + --micro-batch-size: 4 + --global-batch-size: 32 + --seq-length: 1024 + --max-position-embeddings: 1024 + --train-iters: 100 + --timing-log-level: 2 + --lr-decay-iters: 320000 + --save: ${CHECKPOINT_PATH} + --load: ${CHECKPOINT_PATH} + --data-path: ${DATA_PATH}/my-gpt3_00_text_document + --vocab-file: ${DATA_PATH}/bpe/vocab.json + --merge-file: ${DATA_PATH}/bpe/merges.txt + --split: 949,50,1 + --distributed-backend: nccl + --lr: 0.00015 + --lr-decay-style: cosine + --min-lr: 1.0e-5 + --weight-decay: 1e-2 + --clip-grad: 1.0 + --lr-warmup-fraction: .01 + --log-interval: 1 + --save-interval: 50 + --eval-interval: 1000 + --eval-iters: 10 + --transformer-impl: transformer_engine + --tensor-model-parallel-size: 2 + --pipeline-model-parallel-size: 1 + --use-torch-fsdp2: true + --deterministic-mode: true + --no-gradient-accumulation-fusion: true + --no-async-tensor-model-parallel-allreduce: true + --attention-softmax-in-fp32: true + --use-checkpoint-opt_param-scheduler: true + --use-mcore-models: true + --ckpt-format: torch_dist + --data-cache-path: ${DATA_CACHE_PATH} + --bf16: true +TEST_TYPE: ckpt-resume \ No newline at end of file diff --git a/tests/unit_tests/dist_checkpointing/test_local.py b/tests/unit_tests/dist_checkpointing/test_local.py index e4dfc6f8e8..69919fedae 100644 --- a/tests/unit_tests/dist_checkpointing/test_local.py +++ b/tests/unit_tests/dist_checkpointing/test_local.py @@ -61,7 +61,8 @@ def teardown_method(self, method): Utils.destroy_model_parallel() @pytest.mark.parametrize(('tp,pp'), [(2, 4)]) - def test_sharded_tensors(self, tp, pp): + @pytest.mark.parametrize(('use_torch_fsdp2'), [True, False]) + def test_sharded_tensors(self, tp, pp, use_torch_fsdp2): Utils.initialize_model_parallel(tp, pp) num_floating_point_operations_so_far = 0 model, optimizer = setup_model_and_optimizer(1, tp, pp) @@ -73,6 +74,7 @@ def test_sharded_tensors(self, tp, pp): mock_args = SimpleNamespace() mock_args.no_save_optim = False mock_args.no_save_rng = True + mock_args.use_torch_fsdp2 = use_torch_fsdp2 # Test save_local state_dict = generate_state_dict( mock_args, diff --git 
a/tests/unit_tests/dist_checkpointing/test_serialization.py b/tests/unit_tests/dist_checkpointing/test_serialization.py index 8ad6bd95e7..63d2c68725 100644 --- a/tests/unit_tests/dist_checkpointing/test_serialization.py +++ b/tests/unit_tests/dist_checkpointing/test_serialization.py @@ -8,6 +8,14 @@ import torch from torch.distributed.checkpoint import CheckpointException as PyTCheckpointingException +try: + from torch.distributed import DeviceMesh + from torch.distributed._tensor import DTensor + + HAVE_DTENSOR = True +except ImportError: + HAVE_DTENSOR = False + from megatron.core import parallel_state from megatron.core.dist_checkpointing import ShardedTensor, load, save from megatron.core.dist_checkpointing.core import CheckpointingException, maybe_load_config @@ -42,6 +50,16 @@ def test_single_process_save_load(self, tmp_path_dist_ckpt): ), } + if HAVE_DTENSOR: + mesh = DeviceMesh.from_group( + parallel_state.get_data_parallel_group(with_context_parallel=True), "cuda" + ) + sharded_state_dict['sd_keyD'] = ShardedTensor.from_rank_offsets( + 'keyD', + DTensor.from_local(torch.ones(3, 5, 7), mesh)._local_tensor, + replica_id=Utils.rank, + ) + # sync=True to make sure other ranks wait for rank 0 to finish creating directory. with TempNamedDir( tmp_path_dist_ckpt / 'test_single_process_save_load', sync=True @@ -56,6 +74,9 @@ def test_single_process_save_load(self, tmp_path_dist_ckpt): assert not (ckpt_dir / 'keyC').exists() assert not (ckpt_dir / 'sd_keyA').is_dir() + if HAVE_DTENSOR: + assert (ckpt_dir / 'keyD').is_dir() + load_ssd = { 'load_sd_keyA': ShardedTensor.from_rank_offsets( 'keyA', torch.ones(2, 4), replica_id=Utils.rank diff --git a/tests/unit_tests/dist_checkpointing/utils.py b/tests/unit_tests/dist_checkpointing/utils.py index edd3039604..50677f0958 100644 --- a/tests/unit_tests/dist_checkpointing/utils.py +++ b/tests/unit_tests/dist_checkpointing/utils.py @@ -116,6 +116,7 @@ def init_basic_mock_args(args, tp, pp, bf16=True): args.encoder_tensor_model_parallel_size = 0 args.encoder_pipeline_model_parallel_size = 0 args.enable_ft_package = False + args.use_torch_fsdp2 = False return args
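
Outside the patch itself, the FSDP2 wrapping pattern used by `TorchFullyShardedDataParallel` (shard selected submodules first so their parameters are all-gathered just-in-time, then wrap the root module) can be illustrated with a minimal standalone sketch. The toy model, the `init_device_mesh` setup, and the `torchrun` launch are illustrative assumptions, not Megatron code; torch >= 2.4 is assumed.

```python
# Minimal sketch of the per-submodule FSDP2 wrapping pattern, assuming
# torch >= 2.4; launch with `torchrun --nproc_per_node=<N> this_file.py`.
import torch
import torch.nn as nn
from torch.distributed import get_rank, get_world_size, init_process_group
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh


class ToyBlock(nn.Module):
    """Stand-in for a TransformerLayer-like submodule."""

    def __init__(self, hidden: int):
        super().__init__()
        self.linear = nn.Linear(hidden, hidden)

    def forward(self, x):
        return torch.relu(self.linear(x))


def wrap_with_fsdp2(model: nn.Module, dp_size: int) -> nn.Module:
    # One-dimensional data-parallel mesh over all ranks (the patch instead
    # builds its mesh from the Megatron data-parallel group).
    mesh = init_device_mesh("cuda", (dp_size,))
    # Shard each block individually so its parameters are all-gathered
    # just-in-time, mirroring the sub_modules_to_wrap logic in the patch.
    for sub_module in model.modules():
        if isinstance(sub_module, ToyBlock):
            fully_shard(sub_module, mesh=mesh)
    # The root module must also be wrapped, as required by the FSDP2 API.
    fully_shard(model, mesh=mesh)
    return model


if __name__ == "__main__":
    init_process_group(backend="nccl")
    torch.cuda.set_device(get_rank() % torch.cuda.device_count())
    model = nn.Sequential(*[ToyBlock(128) for _ in range(4)]).cuda()
    model = wrap_with_fsdp2(model, dp_size=get_world_size())
    out = model(torch.randn(8, 128, device="cuda"))
    out.sum().backward()
```

Wrapping each block separately is what lets FSDP2 overlap per-layer all-gathers with computation instead of gathering the full model up front, which is also why the patch warns against forcing CUDA_DEVICE_MAX_CONNECTIONS=1.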