diff --git a/bionemo-recipes/models/llama3/README.md b/bionemo-recipes/models/llama3/README.md index ca0417c977..ef62f13b7e 100644 --- a/bionemo-recipes/models/llama3/README.md +++ b/bionemo-recipes/models/llama3/README.md @@ -19,7 +19,7 @@ The Llama-3 implementation natively supports the following TransformerEngine-pro | **Export to HuggingFace checkpoints** | ✅ Supported | | **KV-cache inference** | ✅ Supported (including beam search) | | **Context Parallelism** | ✅ Supported | -| **Tensor Parallelism** | 🚧 Under development | +| **Tensor Parallelism** | ✅ Supported | Refer to [BioNeMo Recipes](../../recipes/llama3_native_te/README.md) for more details on how to use these features to accelerate model training and inference with native PyTorch training loops. diff --git a/bionemo-recipes/models/llama3/modeling_llama_te.py b/bionemo-recipes/models/llama3/modeling_llama_te.py index c88a8f9154..9bbedf32aa 100644 --- a/bionemo-recipes/models/llama3/modeling_llama_te.py +++ b/bionemo-recipes/models/llama3/modeling_llama_te.py @@ -25,6 +25,8 @@ import transformer_engine.common.recipe import transformer_engine.pytorch import transformers +from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module +from torch.distributed.tensor.placement_types import Replicate from transformer_engine.pytorch.attention import InferenceParams from transformer_engine.pytorch.attention.inference import PagedKVCacheManager from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding @@ -52,6 +54,9 @@ class NVLlamaConfig(LlamaConfig): # "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format) attn_input_format: str = "thd" self_attn_mask_type: str = "padding_causal" + tensor_parallel: bool = False + sequence_parallel: bool = False + tp_size: int = 1 def __init__( self, @@ -142,6 +147,8 @@ def __init__( config: LlamaConfig, fp8_recipe: transformer_engine.common.recipe.Recipe | None = None, fp4_recipe: 
transformer_engine.common.recipe.Recipe | None = None, + nvte_tp_mesh: torch.distributed.DeviceMesh | None = None, + nvte_weight_mesh: torch.distributed.DeviceMesh | None = None, ): """Initialize the NVLlama model. @@ -149,6 +156,8 @@ def __init__( config: The configuration of the model. fp8_recipe: The FP8 recipe for the model. fp4_recipe: The FP4 recipe for the model. + nvte_tp_mesh: TP DeviceMesh for the model. + nvte_weight_mesh: Weight-sharding DeviceMesh for the model. """ super().__init__(config) self.config = config @@ -156,6 +165,8 @@ def __init__( self.vocab_size = config.vocab_size self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe + self.tp_mesh = nvte_tp_mesh + self.weight_mesh = nvte_weight_mesh if self.config.layer_precision is None: if fp8_recipe is not None and fp4_recipe is not None: @@ -174,6 +185,27 @@ def __init__( self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype) + # Tensor-parallelize torch.nn.Embedding. Combines DTensor-based TP with TE-based TP. + if config.tensor_parallel: + assert self.tp_mesh is not None, "[NVLlamaModel] Tensor parallelism requires a NVLlamaConfig.tp_mesh." + assert self.tp_mesh.size() == config.tp_size, ( + f"[NVLlamaModel] DeviceMesh TP size ({self.tp_mesh.size()}) " + f"does not match configured TP size ({config.tp_size}).", + ) + # NOTE(@cspades): Because the TELinear head is weight-tied to torch.nn.Embedding + # during HuggingFace post-init, this will automatically convert the TELinear head + # weight into a DTensor with the correct sharding placements prior to FSDP2 + # fully_shard(), and no need to call TELinear.set_device_mesh(). + parallelize_module( + self.embed_tokens, + self.tp_mesh, + # Un-sharded output activations for compatible input to TETransformer. 
+ # NOTE(@cspades): ColwiseParallel -> torch.nn.Embedding -> Shard(dim=1) + # RowwiseParallel doesn't support output_layouts=Replicate() with + # torch.compile: https://github.com/pytorch/torchtitan/issues/534 + ColwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()), + ) + def _init_method(x): torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range) @@ -201,6 +233,11 @@ def _init_method(x): device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", init_method=_init_method, output_layer_init_method=_init_method, + set_parallel_mode=config.tensor_parallel, + sequence_parallel=config.sequence_parallel, + tp_size=config.tp_size, + tp_mesh=self.tp_mesh, + weight_mesh=self.weight_mesh, ) ] @@ -211,6 +248,8 @@ def _init_method(x): dtype=config.dtype, device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", ) + # Norm modules are non-Base TransformerEngine modules that require a manual call for TP. + self.norm.set_device_mesh(tp_mesh=self.tp_mesh, weight_mesh=self.weight_mesh) # We use TE's RotaryPositionEmbedding, but we ensure that we use the same inv_freq as the original # LlamaRotaryEmbedding. @@ -387,6 +426,8 @@ def __init__( config, fp8_recipe: transformer_engine.common.recipe.Recipe | None = None, fp4_recipe: transformer_engine.common.recipe.Recipe | None = None, + nvte_tp_mesh: torch.distributed.DeviceMesh | None = None, + nvte_weight_mesh: torch.distributed.DeviceMesh | None = None, ): """Initialize the NVLlamaForCausalLM model. @@ -394,10 +435,21 @@ def __init__( config: The configuration of the model. fp8_recipe: The FP8 recipe for the model. fp4_recipe: The FP4 recipe for the model. + nvte_tp_mesh: TP DeviceMesh for the model. + nvte_weight_mesh: Weight-sharding DeviceMesh for the model. 
""" super().__init__(config) - self.model = NVLlamaModel(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) + self.model = NVLlamaModel( + config, + fp8_recipe=fp8_recipe, + fp4_recipe=fp4_recipe, + nvte_tp_mesh=nvte_tp_mesh, + nvte_weight_mesh=nvte_weight_mesh, + ) + self.config = config self.vocab_size = config.vocab_size + self.tp_mesh = nvte_tp_mesh + self.weight_mesh = nvte_weight_mesh with transformer_engine.pytorch.quantized_model_init(enabled=False): self.lm_head = transformer_engine.pytorch.Linear( config.hidden_size, @@ -406,9 +458,25 @@ def __init__( params_dtype=config.dtype, device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + parallel_mode="row" if config.tensor_parallel else None, + # This scatters your output, not ever needed for final layer. + # Will all-reduce the output instead, as required. + sequence_parallel=False, + tp_size=config.tp_size, ) - - # Initialize weights and apply final processing + if config.tensor_parallel: + if config.tie_word_embeddings: + # Head weights have already been tied to the embedding weights. + # Just set the tensor parallel group for TE. + # No parameter quantization either, so no need for weight_mesh. + self.lm_head.set_tensor_parallel_group(self.tp_mesh.get_group()) + else: + # Head weights are not tied to the embedding weights. Need to + # wrap the LM head weight as a DTensor with TE. + # No parameter quantization either, so no need for weight_mesh. + self.lm_head.set_device_mesh(tp_mesh=self.tp_mesh) + + # Initialize weights and apply final processing. Ties weights. 
self.post_init() def forward( @@ -461,6 +529,14 @@ def forward( # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + if self.config.tensor_parallel: + # If using TP, shard your activation across the TP group, + # to support row-wise tensor parallelism in the LM head. + # Use ... to support both BSHD (3D) and THD (2D) hidden states. + tp_rank = self.tp_mesh.get_local_rank() + tp_stride = hidden_states.shape[-1] // self.config.tp_size + hidden_states = hidden_states[..., tp_rank * tp_stride : (tp_rank + 1) * tp_stride] + with transformer_engine.pytorch.autocast(enabled=False): if hidden_states.ndim == 3: logits = self.lm_head(hidden_states[:, slice_indices, :]) diff --git a/bionemo-recipes/models/llama3/requirements.txt b/bionemo-recipes/models/llama3/requirements.txt index ec6a547cb8..a16bb00438 100644 --- a/bionemo-recipes/models/llama3/requirements.txt +++ b/bionemo-recipes/models/llama3/requirements.txt @@ -1,5 +1,5 @@ lm-eval # For testing torch torchao!=0.14.0 -transformer_engine[pytorch] +transformer_engine[pytorch] @ git+https://github.com/cspades/TransformerEngine.git@7e0d3a9b9cca4243bcf9f233f79fcb0c09795e3f transformers diff --git a/bionemo-recipes/recipes/esm2_native_te/tests/test_train_two_gpu.py b/bionemo-recipes/recipes/esm2_native_te/tests/test_train_two_gpu.py index 4b98cd36a3..9085803916 100644 --- a/bionemo-recipes/recipes/esm2_native_te/tests/test_train_two_gpu.py +++ b/bionemo-recipes/recipes/esm2_native_te/tests/test_train_two_gpu.py @@ -135,7 +135,7 @@ def test_multi_gpu_train_te_fsdp2_cp(tmp_path, recipe_path): [ "torchrun", "--nproc_per_node=2", - "train_fsdp2_cp.py", + "train_fsdp2_nd_parallel.py", "--config-name", "L0_sanity_cp", "num_train_steps=4", diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py index 
12d4313941..0ce3a0260d 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py @@ -87,7 +87,7 @@ def main(args: DictConfig) -> float | None: ) if args.use_fp32_master_weights: - raise ValueError("FP32 master weights are not supported with DDP+CP. Use train_fsdp2_cp.py instead.") + raise ValueError("FP32 master weights are not supported with DDP+CP. Use train_fsdp2_nd_parallel.py instead.") # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D". # Note: token_dropout is set to False because it's not compatible with context parallelism. diff --git a/bionemo-recipes/recipes/llama3_native_te/README.md b/bionemo-recipes/recipes/llama3_native_te/README.md index d99a3001c0..e3e6b55d09 100644 --- a/bionemo-recipes/recipes/llama3_native_te/README.md +++ b/bionemo-recipes/recipes/llama3_native_te/README.md @@ -18,7 +18,7 @@ bionemo-framework repository. You can download a zipped directory of this folder | Model | BF16 | FP8[1] | THD Input Format | FP8 with THD Input Format | MXFP8[2] | Context Parallelism | Tensor Parallelism | | ---------------------------------------- | ---- | ----------------- | ---------------- | ------------------------- | ------------------- | ------------------- | ------------------ | -| [Llama 3](../../models/llama3/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | +| [Llama 3](../../models/llama3/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅: Supported
🚧: Under development
@@ -94,7 +94,7 @@ This recipe supports distributed training using DDP, FSDP2, and FSDP2 with Conte - [Distributed Data Parallel (DDP)](https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html), shown in `train_ddp.py` - [Fully Sharded Data Parallel 2 (FSDP2)](https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html), shown in `train_fsdp2.py` -- FSDP2 with Context Parallelism, shown in `train_fsdp2_cp.py` +- FSDP2 with Context Parallelism and Tensor Parallelism, shown in `train_fsdp2_nd_parallel.py` ## Commands to Launch Training @@ -193,13 +193,30 @@ python train_fsdp2.py --config-name L0_sanity \ ### Context Parallel Training Context parallelism splits each sequence across multiple GPUs along the sequence dimension, enabling training with very -long sequences. Use `train_fsdp2_cp.py` with the `L0_sanity_cp` configuration and set `cp_size` to the number of context +long sequences. Use `train_fsdp2_nd_parallel.py` with the `L0_sanity_cp` configuration and set `cp_size` to the number of context parallelism ranks. Works with both BSHD (no padding) and THD (padding) input formats. Only TE models are supported. ```bash -torchrun --nproc_per_node=4 train_fsdp2_cp.py --config-name L0_sanity_cp cp_size=2 +torchrun --nproc_per_node=4 train_fsdp2_nd_parallel.py --config-name L0_sanity_cp cp_size=2 ``` +### Tensor Parallel Training + +Tensor parallelism shards model activations and weights along the hidden dimension across multiple GPUs, and is compatible with +context parallelism and FSDP2 in TransformerEngine. Use `train_fsdp2_nd_parallel.py` with the `L2_sanity_nd` config and optionally +set `tp_size` and `cp_size` such that the product of these sizes is less than or equal to the number of GPUs. When the FSDP and +TP sharding dimensions coincide, FSDP2 will automatically create `_StridedShard` placements to ensure that TP shards are the +result of the all-gather during training.
Only TE models are supported, and refer to the `set_device_mesh` API on TE modules for +more information about how `DeviceMesh` and `DTensor` interact with TransformerEngine. + +```bash +torchrun --nproc_per_node=4 train_fsdp2_nd_parallel.py --config-name L2_sanity_nd cp_size=2 tp_size=2 +``` + +Note that TransformerEngine TP is compatible with DTensor-based TP. For instance, the `torch.nn.Embedding` is weight-tied to the +Llama LM Head, but the embedding layer is tensor-parallelized with `parallelize_module` while the head is tensor-parallelized with +TransformerEngine. + ## Downloading Pre-Training Data For Offline Training This recipe is configured to use genomic sequences. The default configuration uses a local test file diff --git a/bionemo-recipes/recipes/llama3_native_te/checkpoint.py b/bionemo-recipes/recipes/llama3_native_te/checkpoint.py index 2dc5d10dcf..b75fd00549 100644 --- a/bionemo-recipes/recipes/llama3_native_te/checkpoint.py +++ b/bionemo-recipes/recipes/llama3_native_te/checkpoint.py @@ -233,9 +233,28 @@ class AppState(Stateful): epoch: int = 0 def state_dict(self): - """Get the state dict for the model, optimizer, scheduler, and step.""" + """Get the state dict for the model, optimizer, scheduler, and step. + + This factory both retrieves the model state dictionary when saving + checkpoints and initializes a destination for the state read from + DCP checkpoint files when loading checkpoints. + """ model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer) - model_state_dict = {k: v for k, v in model_state_dict.items() if not k.endswith("_extra_state")} + for fqn in list(model_state_dict.keys()): + # Get the model parameter. + model_param = model_state_dict[fqn] + if isinstance(model_param, DTensor): + model_param = model_param.to_local() + if model_param.numel() == 0 and fqn in optimizer_state_dict["state"]: + # Empty model parameter. 
Clear the associated optimizer state + # when initializing the optimizer state upon DCP load, because + # empty optimizer state DTensors are not checkpointed with DCP, + # yet get_state_dict / _init_optim_state produce empty Tensors. + # TransformerEngine uses empty Tensors for dummy Parameters. + optimizer_state_dict["state"][fqn] = {} + if fqn.endswith("_extra_state"): + # Evict `_extra_state` quantization data from model checkpoint. + model_state_dict.pop(fqn) return { "model": model_state_dict, "optim": optimizer_state_dict, @@ -245,12 +264,18 @@ def state_dict(self): } def load_state_dict(self, state_dict: dict): - """Load the state dict for the model, optimizer, scheduler, and step.""" + """Load the state dict for the model, optimizer, scheduler, and step. + + Given the checkpoint-loaded state_dict, set the state of the model, + optimizer, scheduler, step, and epoch to the values in state_dict. + """ set_state_dict( self.model, self.optimizer, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"], + # Non-strict checkpoint loading ignores empty optimizer states, + # skips loading non-FP8 checkpoint weights (e.g. _extra_state). options=StateDictOptions(strict=False), ) self.scheduler.load_state_dict(state_dict["scheduler"]) diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_sanity_tp.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_sanity_tp.yaml new file mode 100644 index 0000000000..601fdf097c --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_sanity_tp.yaml @@ -0,0 +1,15 @@ +defaults: + - L0_sanity + - _self_ + +tp_size: 2 # Tensor Parallel sharding factor +cp_size: 1 + +use_sequence_packing: false + +config_kwargs: + attn_input_format: "bshd" # Alternatively "thd" on datacenter hardware. + self_attn_mask_type: "causal" # Alternatively "padding_causal" for THD inputs. 
+ tensor_parallel: true # Tensor Parallelism for TE + sequence_parallel: false # Sequence parallelism for LayerNorm on TP ranks. + tp_size: ${tp_size} # Tensor Parallel Size diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_sanity_nd.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_sanity_nd.yaml new file mode 100644 index 0000000000..b31d8c9500 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_sanity_nd.yaml @@ -0,0 +1,27 @@ +defaults: + - L0_sanity + - _self_ + +tp_size: 2 # TP Sharding +cp_size: 2 # FSDP-CP Sharding + +dataset: + # CP2 * (8 for FP8 Activations, 16 for FP8 Parameters) + pad_sequences_to_be_divisible_by: 32 + +fp8_config: + enabled: true + fp8_recipe: transformer_engine.common.recipe.DelayedScaling + fp8_format: "HYBRID" + fp8_recipe_kwargs: {} + +checkpoint: + ckpt_dir: ./fsdp_nd_ckpts + save_final_model: true + +config_kwargs: + attn_input_format: "bshd" # Alternatively "thd" on datacenter hardware. + self_attn_mask_type: "causal" # Alternatively "padding_causal" for THD inputs. + tensor_parallel: true # Tensor Parallelism for TE + sequence_parallel: true # Sequence parallelism for LayerNorm on TP ranks. + tp_size: ${tp_size} # Tensor Parallel Size diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml index 9302a0758d..fdd4f0f0b8 100644 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml @@ -12,6 +12,10 @@ use_meta_device: true # We leave this off by default since we don't see much of a performance improvement with TE layers. use_torch_compile: false +# Default parallelism sizes. 
+tp_size: 1 +cp_size: 1 + use_sequence_packing: false dataset: diff --git a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py index 713a26b8be..0a7cde0037 100644 --- a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py +++ b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py @@ -31,6 +31,8 @@ import transformer_engine.common.recipe import transformer_engine.pytorch import transformers +from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module +from torch.distributed.tensor.placement_types import Replicate from transformer_engine.pytorch.attention import InferenceParams from transformer_engine.pytorch.attention.inference import PagedKVCacheManager from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding @@ -58,6 +60,9 @@ class NVLlamaConfig(LlamaConfig): # "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format) attn_input_format: str = "thd" self_attn_mask_type: str = "padding_causal" + tensor_parallel: bool = False + sequence_parallel: bool = False + tp_size: int = 1 def __init__( self, @@ -148,6 +153,8 @@ def __init__( config: LlamaConfig, fp8_recipe: transformer_engine.common.recipe.Recipe | None = None, fp4_recipe: transformer_engine.common.recipe.Recipe | None = None, + nvte_tp_mesh: torch.distributed.DeviceMesh | None = None, + nvte_weight_mesh: torch.distributed.DeviceMesh | None = None, ): """Initialize the NVLlama model. @@ -155,6 +162,8 @@ def __init__( config: The configuration of the model. fp8_recipe: The FP8 recipe for the model. fp4_recipe: The FP4 recipe for the model. + nvte_tp_mesh: TP DeviceMesh for the model. + nvte_weight_mesh: Weight-sharding DeviceMesh for the model. 
""" super().__init__(config) self.config = config @@ -162,6 +171,8 @@ def __init__( self.vocab_size = config.vocab_size self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe + self.tp_mesh = nvte_tp_mesh + self.weight_mesh = nvte_weight_mesh if self.config.layer_precision is None: if fp8_recipe is not None and fp4_recipe is not None: @@ -180,6 +191,27 @@ def __init__( self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype) + # Tensor-parallelize torch.nn.Embedding. Combines DTensor-based TP with TE-based TP. + if config.tensor_parallel: + assert self.tp_mesh is not None, "[NVLlamaModel] Tensor parallelism requires a NVLlamaConfig.tp_mesh." + assert self.tp_mesh.size() == config.tp_size, ( + f"[NVLlamaModel] DeviceMesh TP size ({self.tp_mesh.size()}) " + f"does not match configured TP size ({config.tp_size}).", + ) + # NOTE(@cspades): Because the TELinear head is weight-tied to torch.nn.Embedding + # during HuggingFace post-init, this will automatically convert the TELinear head + # weight into a DTensor with the correct sharding placements prior to FSDP2 + # fully_shard(), and no need to call TELinear.set_device_mesh(). + parallelize_module( + self.embed_tokens, + self.tp_mesh, + # Un-sharded output activations for compatible input to TETransformer. 
+ # NOTE(@cspades): ColwiseParallel -> torch.nn.Embedding -> Shard(dim=1) + # RowwiseParallel doesn't support output_layouts=Replicate() with + # torch.compile: https://github.com/pytorch/torchtitan/issues/534 + ColwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()), + ) + def _init_method(x): torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range) @@ -207,6 +239,11 @@ def _init_method(x): device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", init_method=_init_method, output_layer_init_method=_init_method, + set_parallel_mode=config.tensor_parallel, + sequence_parallel=config.sequence_parallel, + tp_size=config.tp_size, + tp_mesh=self.tp_mesh, + weight_mesh=self.weight_mesh, ) ] @@ -217,6 +254,8 @@ def _init_method(x): dtype=config.dtype, device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", ) + # Norm modules are non-Base TransformerEngine modules that require a manual call for TP. + self.norm.set_device_mesh(tp_mesh=self.tp_mesh, weight_mesh=self.weight_mesh) # We use TE's RotaryPositionEmbedding, but we ensure that we use the same inv_freq as the original # LlamaRotaryEmbedding. @@ -393,6 +432,8 @@ def __init__( config, fp8_recipe: transformer_engine.common.recipe.Recipe | None = None, fp4_recipe: transformer_engine.common.recipe.Recipe | None = None, + nvte_tp_mesh: torch.distributed.DeviceMesh | None = None, + nvte_weight_mesh: torch.distributed.DeviceMesh | None = None, ): """Initialize the NVLlamaForCausalLM model. @@ -400,10 +441,21 @@ def __init__( config: The configuration of the model. fp8_recipe: The FP8 recipe for the model. fp4_recipe: The FP4 recipe for the model. + nvte_tp_mesh: TP DeviceMesh for the model. + nvte_weight_mesh: Weight-sharding DeviceMesh for the model. 
""" super().__init__(config) - self.model = NVLlamaModel(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) + self.model = NVLlamaModel( + config, + fp8_recipe=fp8_recipe, + fp4_recipe=fp4_recipe, + nvte_tp_mesh=nvte_tp_mesh, + nvte_weight_mesh=nvte_weight_mesh, + ) + self.config = config self.vocab_size = config.vocab_size + self.tp_mesh = nvte_tp_mesh + self.weight_mesh = nvte_weight_mesh with transformer_engine.pytorch.quantized_model_init(enabled=False): self.lm_head = transformer_engine.pytorch.Linear( config.hidden_size, @@ -412,9 +464,25 @@ def __init__( params_dtype=config.dtype, device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + parallel_mode="row" if config.tensor_parallel else None, + # This scatters your output, not ever needed for final layer. + # Will all-reduce the output instead, as required. + sequence_parallel=False, + tp_size=config.tp_size, ) - - # Initialize weights and apply final processing + if config.tensor_parallel: + if config.tie_word_embeddings: + # Head weights have already been tied to the embedding weights. + # Just set the tensor parallel group for TE. + # No parameter quantization either, so no need for weight_mesh. + self.lm_head.set_tensor_parallel_group(self.tp_mesh.get_group()) + else: + # Head weights are not tied to the embedding weights. Need to + # wrap the LM head weight as a DTensor with TE. + # No parameter quantization either, so no need for weight_mesh. + self.lm_head.set_device_mesh(tp_mesh=self.tp_mesh) + + # Initialize weights and apply final processing. Ties weights. 
self.post_init() def forward( @@ -467,6 +535,14 @@ def forward( # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + if self.config.tensor_parallel: + # If using TP, shard your activation across the TP group, + # to support row-wise tensor parallelism in the LM head. + # Use ... to support both BSHD (3D) and THD (2D) hidden states. + tp_rank = self.tp_mesh.get_local_rank() + tp_stride = hidden_states.shape[-1] // self.config.tp_size + hidden_states = hidden_states[..., tp_rank * tp_stride : (tp_rank + 1) * tp_stride] + with transformer_engine.pytorch.autocast(enabled=False): if hidden_states.ndim == 3: logits = self.lm_head(hidden_states[:, slice_indices, :]) diff --git a/bionemo-recipes/recipes/llama3_native_te/requirements.txt b/bionemo-recipes/recipes/llama3_native_te/requirements.txt index 073d9b39e3..ad7e7fb061 100644 --- a/bionemo-recipes/recipes/llama3_native_te/requirements.txt +++ b/bionemo-recipes/recipes/llama3_native_te/requirements.txt @@ -5,7 +5,7 @@ torchao!=0.14.0 torchdata torchmetrics tqdm -transformer_engine[pytorch] +transformer_engine[pytorch] @ git+https://github.com/cspades/TransformerEngine.git@7e0d3a9b9cca4243bcf9f233f79fcb0c09795e3f transformers wandb zstandard diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py index 0cbaa64673..dda13749a1 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py @@ -706,7 +706,7 @@ def test_cp_dataloader(tokenizer_path): cp_mesh = device_mesh["cp"] - # Create the context-parallel dataloader directly following the pattern in train_fsdp2_cp.py + # Create the context-parallel dataloader directly following the pattern in train_fsdp2_nd_parallel.py if cp_mesh.get_local_rank() == 0: 
train_dataloader, _ = create_thd_dataloader( distributed_config=dist_config, @@ -799,7 +799,7 @@ def test_cp_dataloader_multi_gpu(recipe_path, dataset_path): cp_mesh = device_mesh["cp"] - # Create the context-parallel dataloader directly following the pattern in train_fsdp2_cp.py + # Create the context-parallel dataloader directly following the pattern in train_fsdp2_nd_parallel.py if cp_mesh.get_local_rank() == 0: train_dataloader, _ = create_thd_dataloader( distributed_config=dist_config, diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py index 889f57345b..8f8b22f5f5 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py @@ -40,7 +40,7 @@ from train_ddp import main as main_ddp from train_fsdp2 import main as main_fsdp2 -from train_fsdp2_cp import main as main_fsdp2_cp +from train_fsdp2_nd_parallel import main as main_fsdp2_cp os.environ["WANDB_DISABLED"] = "true" @@ -329,7 +329,7 @@ def test_checkpoint_save_and_load_two_processes_fsdp2_with_context_parallelism(r _run_multi_process_checkpoint_test( recipe_path, tmp_path, - "train_fsdp2_cp.py", + "train_fsdp2_nd_parallel.py", ckpt_subdir_name="train_fsdp2", extra_overrides=["checkpoint.async_save=false", "cp_size=2"], is_ddp=False, diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py index 89e85068de..5a3e1994d0 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py @@ -24,7 +24,7 @@ from train_ddp import main as main_ddp from train_fsdp2 import main as main_fsdp2 -from train_fsdp2_cp import main as main_fsdp2_cp +from train_fsdp2_nd_parallel import main as main_fsdp2_cp # TODO(@jomitchell): Delete 
once https://nvbugspro.nvidia.com/bug/5458694 is fixed. @@ -480,6 +480,61 @@ def test_sanity_ddp_fp8_stats_logging(tmp_path, recipe_path): assert stats_log.stat().st_size > 0, "Statistics log file is empty" +def test_sanity_nd_parallel_tp1_bshd(tmp_path, recipe_path): + """Test ND-parallel training with tensor_parallel=True and tp_size=1 (trivial TP group), BSHD. + + This test validates that all TP code paths in NVLlamaModel and NVLlamaForCausalLM execute + correctly with a single-rank TP mesh: + - parallelize_module on embed_tokens (ColwiseParallel) + - TransformerLayer TP mode flags + - lm_head row-parallel mode and set_tensor_parallel_group + - Hidden-state activation slicing in NVLlamaForCausalLM.forward + """ + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity_tp", + overrides=[ + f"+wandb.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "num_train_steps=10", + "tp_size=1", + "checkpoint.resume_from_checkpoint=false", + ], + ) + + final_loss = main_fsdp2_cp(sanity_config) + gc.collect() + torch.cuda.empty_cache() + + assert torch.isfinite(torch.tensor(final_loss)), f"Final loss {final_loss} is not finite" + + +def test_sanity_nd_parallel_tp1_sequence_parallel_bshd(tmp_path, recipe_path): + """Test ND-parallel training with tensor_parallel=True, sequence_parallel=True, tp_size=1, BSHD. + + Validates that the sequence-parallel RMSNorm (set_device_mesh on the final norm) does not + break forward/backward even when the TP group is a single rank. 
+ """ + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity_tp", + overrides=[ + f"+wandb.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "num_train_steps=10", + "tp_size=1", + "config_kwargs.sequence_parallel=true", + "checkpoint.resume_from_checkpoint=false", + ], + ) + + final_loss = main_fsdp2_cp(sanity_config) + gc.collect() + torch.cuda.empty_cache() + + assert torch.isfinite(torch.tensor(final_loss)), f"Final loss {final_loss} is not finite" + + @requires_fp8 def test_sanity_fsdp2_fp8_stats_logging(tmp_path, recipe_path): """Test that FP8 stats logging works with FSDP2.""" diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py index 0993f5fb46..cac6b11e13 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py @@ -199,7 +199,7 @@ def test_multi_gpu_train_te_fsdp2_cp_bshd(tmp_path, recipe_path): "torchrun", "--standalone", "--nproc_per_node=2", - "train_fsdp2_cp.py", + "train_fsdp2_nd_parallel.py", "--config-name", "L0_sanity_cp", "num_train_steps=10", @@ -223,7 +223,7 @@ def test_multi_gpu_train_te_fsdp2_cp_thd(tmp_path, recipe_path): "torchrun", "--standalone", "--nproc_per_node=2", - "train_fsdp2_cp.py", + "train_fsdp2_nd_parallel.py", "--config-name", "L0_sanity_cp", "num_train_steps=10", @@ -238,6 +238,111 @@ def test_multi_gpu_train_te_fsdp2_cp_thd(tmp_path, recipe_path): ) +@requires_multi_gpu +def test_multi_gpu_train_te_fsdp2_tp_bshd(tmp_path, recipe_path): + """Test FSDP2 with tensor parallelism on 2 GPUs using BSHD input format. 
+ + Validates: + - The 1-D TP device mesh (dp=1, cp=1, tp=2) is created and used correctly + - Embedding weights are ColwiseParallel-sharded across 2 TP ranks + - TransformerLayer TP mode shards QKV/FFN weights across ranks + - Row-wise parallel LM head with hidden-state slicing before forward + """ + run_train_cmd( + [ + "torchrun", + "--standalone", + "--nproc_per_node=2", + "train_fsdp2_nd_parallel.py", + "--config-name", + "L0_sanity_tp", + "num_train_steps=10", + f"checkpoint.ckpt_dir={tmp_path}", + ], + recipe_path, + ) + + +@requires_multi_gpu +@requires_datacenter_hardware +def test_multi_gpu_train_te_fsdp2_tp_thd(tmp_path, recipe_path): + """Test FSDP2 with tensor parallelism on 2 GPUs using THD (sequence-packed) input format. + + Validates: + - TP=2, CP=1 with sequence-packing / THD attention format + - _unpad_input / _pad_input round-trip works alongside TP activation sharding + - padding_causal mask type is compatible with row-wise parallel LM head + """ + run_train_cmd( + [ + "torchrun", + "--standalone", + "--nproc_per_node=2", + "train_fsdp2_nd_parallel.py", + "--config-name", + "L0_sanity_tp", + "num_train_steps=10", + f"checkpoint.ckpt_dir={tmp_path}", + "use_sequence_packing=true", + "config_kwargs.attn_input_format=thd", + "config_kwargs.self_attn_mask_type=padding_causal", + ], + recipe_path, + ) + + +@requires_multi_gpu +def test_multi_gpu_train_te_fsdp2_tp_sequence_parallel_bshd(tmp_path, recipe_path): + """Test FSDP2 with tensor parallelism + sequence parallelism on 2 GPUs, BSHD. + + Validates that sequence parallelism (LayerNorm activations sharded across TP ranks) + works alongside standard tensor parallelism without errors. 
+ """ + run_train_cmd( + [ + "torchrun", + "--standalone", + "--nproc_per_node=2", + "train_fsdp2_nd_parallel.py", + "--config-name", + "L0_sanity_tp", + "num_train_steps=10", + f"checkpoint.ckpt_dir={tmp_path}", + "config_kwargs.sequence_parallel=true", + ], + recipe_path, + ) + + +@requires_multi_gpu +def test_multi_gpu_train_te_fsdp2_tp_bshd_with_checkpointing(tmp_path, recipe_path): + """Test FSDP2 TP training on 2 GPUs with checkpoint saving. + + Validates: + - Sharded FSDP2 checkpoints are written correctly while TP is active + - The expected checkpoint directory structure is present after training + """ + run_train_cmd( + [ + "torchrun", + "--standalone", + "--nproc_per_node=2", + "train_fsdp2_nd_parallel.py", + "--config-name", + "L0_sanity_tp", + "num_train_steps=10", + f"checkpoint.ckpt_dir={tmp_path}", + "checkpoint.save_every_n_steps=5", + "checkpoint.resume_from_checkpoint=false", + ], + recipe_path, + ) + + ckpt_dir = tmp_path / "train_fsdp2" + assert ckpt_dir.exists(), f"Checkpoint directory not created: {ckpt_dir}" + assert (ckpt_dir / "step_5").exists(), "Checkpoint at step 5 not found" + + nsys_available = subprocess.run(["which", "nsys"], check=False, capture_output=True).returncode == 0 diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py index da19daa2a7..ade90cef1b 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py @@ -19,7 +19,7 @@ the memory of a single GPU. Supports both TE-accelerated (NVLlamaForCausalLM) and standard HuggingFace (LlamaForCausalLM) models. -For very long sequences, use ``train_fsdp2_cp.py`` which adds Context Parallelism on top of FSDP2. +For very long sequences, use ``train_fsdp2_nd_parallel.py`` which adds Context Parallelism on top of FSDP2. 
""" import gc diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_nd_parallel.py similarity index 81% rename from bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py rename to bionemo-recipes/recipes/llama3_native_te/train_fsdp2_nd_parallel.py index eaf1a1b39f..23cebe6318 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_nd_parallel.py @@ -13,14 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""FSDP2 with Context Parallelism training script for Llama 3 with TransformerEngine. +"""FSDP2 with Tensor & Context Parallelism training script for Llama 3 with TransformerEngine. -Combines Fully Sharded Data Parallel v2 with Context Parallelism (CP), where each sequence is -split across multiple GPUs along the sequence dimension. This is useful for training with very long -sequences that do not fit into a single GPU's memory even with FSDP2 alone. Only supports -TE-accelerated models (NVLlamaForCausalLM). +Combines Fully Sharded Data Parallel v2 with Tensor Parallelism (TP) and Context Parallelism (CP). +In Context Parallelism, each sequence is split across multiple GPUs along the sequence dimension, +which is useful for training on extremely long sequences that exhaust activation memory. +In Tensor Parallelism, weights and activations are sharded on the hidden dim across multiple GPUs, +which is useful for sharding model weights and activations unlike FSDP which only shards weights. +Only supports TE-accelerated models (NVLlamaForCausalLM). -For standard FSDP2 training without context parallelism, use ``train_fsdp2.py`` instead. +For standard FSDP2 training without N-D parallelism, use ``train_fsdp2.py`` instead. 
""" import gc @@ -59,7 +61,7 @@ @hydra.main(config_path="hydra_config", config_name="L0_sanity_cp", version_base="1.2") def main(args: DictConfig) -> float | None: - """Train Llama3 with TE layers using FSDP2 with Context Parallelism. + """Train Llama3 with TE layers using FSDP2, CP, and TP. Returns: float: The loss value for the final batch. @@ -73,8 +75,8 @@ def main(args: DictConfig) -> float | None: device_mesh = init_device_mesh( "cuda", - mesh_shape=(dist_config.world_size // args.cp_size, args.cp_size), - mesh_dim_names=("dp", "cp"), + mesh_shape=(dist_config.world_size // (args.cp_size * args.tp_size), args.cp_size, args.tp_size), + mesh_dim_names=("dp", "cp", "tp"), ) logger.info("Created device mesh: %s", device_mesh) @@ -94,11 +96,22 @@ def main(args: DictConfig) -> float | None: config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs) with torch.device("meta") if args.use_meta_device else nullcontext(): - model = NVLlamaForCausalLM(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) + model = NVLlamaForCausalLM( + config, + fp8_recipe=fp8_recipe, + fp4_recipe=fp4_recipe, + # Only pass DeviceMesh to TransformerEngine if using Tensor Parallelism + # or if your DeviceMesh has multiple weight-sharding dimensions. + nvte_tp_mesh=device_mesh["tp"] if config.tensor_parallel else None, + # nvte_weight_mesh is only required for Float8CurrentScaling parameters. + nvte_weight_mesh=device_mesh["dp", "cp", "tp"]._flatten("weight_mesh") if config.tensor_parallel else None, + ) logger.info("Initialized Model:\n%s", model) # --- Distributed Wrapping (FSDP2 + CP) --- + + # Create a flattened mesh for FSDP2-CP sharding. This will shard the model across both the DP and CP ranks. cp_dp_mesh = device_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_shard_cp") # Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers. 
@@ -107,7 +120,7 @@ def main(args: DictConfig) -> float | None: fully_shard(layer, mesh=cp_dp_mesh) fully_shard(model, mesh=cp_dp_mesh) - # Attach the CP group to the model. + # Attach the CP ProcessGroup to the TransformerEngine model. for layer in model.model.layers: layer.set_context_parallel_group( device_mesh["cp"].get_group(), @@ -136,9 +149,12 @@ def main(args: DictConfig) -> float | None: logger.info("pad_sequences_to_be_divisible_by is not provided, using cp_mesh.size() * 2") OmegaConf.update(args, "dataset.pad_sequences_to_be_divisible_by", device_mesh["cp"].size() * 2) - # We only create the dataloader on rank 0, which is responsible for loading data for all CP (and eventually TP) - # ranks. This ensures that the data remains synchronized, even if we're using a non-deterministic data pipeline. - if device_mesh["cp"].get_local_rank() == 0: + # We only create the dataloader on rank 0, which is responsible for loading data for all CP (and TP) ranks. + # This ensures that the data remains synchronized, even if we're using a non-deterministic data pipeline. + cp_tp_mesh = device_mesh["cp", "tp"]._flatten(mesh_dim_name="cp_tp") + if cp_tp_mesh.get_local_rank() == 0: + # We only create the dataloader on CP-TP Rank 0 and pass it to a ContextParallelDataLoaderWrapper + # that will shard, replicate, and distribute the data across the flattened CP and TP group. if args.use_sequence_packing: train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset) else: @@ -155,8 +171,8 @@ def main(args: DictConfig) -> float | None: train_dataloader = None dataset_or_sampler = None - # On all ranks, we create a ContextParallelDataLoaderWrapper that broadcasts the data from cp rank 0. - train_dataloader = ContextParallelDataLoaderWrapper(train_dataloader, device_mesh["cp"]) + # Deliver CP-sharded replicates to a flattened CP-TP mesh. 
+ train_dataloader = ContextParallelDataLoaderWrapper(train_dataloader, cp_tp_mesh) # --- Checkpoint Resume --- ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_fsdp2" if args.checkpoint.ckpt_dir else None @@ -169,7 +185,6 @@ def main(args: DictConfig) -> float | None: ckpt_path=ckpt_path, dist_config=dist_config, dataloader=train_dataloader, - process_group=cp_dp_mesh.get_group(), ) logger.info("Checkpoint loaded, resuming from step %s, epoch %s", start_step, epoch) else: @@ -234,7 +249,6 @@ def main(args: DictConfig) -> float | None: epoch=epoch, dist_config=dist_config, dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None, - process_group=cp_dp_mesh.get_group(), max_checkpoints=args.checkpoint.max_checkpoints, async_save=args.checkpoint.async_save, ) diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py index 713a26b8be..0a7cde0037 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py @@ -31,6 +31,8 @@ import transformer_engine.common.recipe import transformer_engine.pytorch import transformers +from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module +from torch.distributed.tensor.placement_types import Replicate from transformer_engine.pytorch.attention import InferenceParams from transformer_engine.pytorch.attention.inference import PagedKVCacheManager from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding @@ -58,6 +60,9 @@ class NVLlamaConfig(LlamaConfig): # "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format) attn_input_format: str = "thd" self_attn_mask_type: str = "padding_causal" + tensor_parallel: bool = False + sequence_parallel: bool = False + tp_size: int = 1 def __init__( self, @@ -148,6 +153,8 @@ def __init__( config: LlamaConfig, 
fp8_recipe: transformer_engine.common.recipe.Recipe | None = None, fp4_recipe: transformer_engine.common.recipe.Recipe | None = None, + nvte_tp_mesh: torch.distributed.DeviceMesh | None = None, + nvte_weight_mesh: torch.distributed.DeviceMesh | None = None, ): """Initialize the NVLlama model. @@ -155,6 +162,8 @@ def __init__( config: The configuration of the model. fp8_recipe: The FP8 recipe for the model. fp4_recipe: The FP4 recipe for the model. + nvte_tp_mesh: TP DeviceMesh for the model. + nvte_weight_mesh: Weight-sharding DeviceMesh for the model. """ super().__init__(config) self.config = config @@ -162,6 +171,8 @@ def __init__( self.vocab_size = config.vocab_size self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe + self.tp_mesh = nvte_tp_mesh + self.weight_mesh = nvte_weight_mesh if self.config.layer_precision is None: if fp8_recipe is not None and fp4_recipe is not None: @@ -180,6 +191,27 @@ def __init__( self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype) + # Tensor-parallelize torch.nn.Embedding. Combines DTensor-based TP with TE-based TP. + if config.tensor_parallel: + assert self.tp_mesh is not None, "[NVLlamaModel] Tensor parallelism requires a NVLlamaConfig.tp_mesh." + assert self.tp_mesh.size() == config.tp_size, ( + f"[NVLlamaModel] DeviceMesh TP size ({self.tp_mesh.size()}) " + f"does not match configured TP size ({config.tp_size}).", + ) + # NOTE(@cspades): Because the TELinear head is weight-tied to torch.nn.Embedding + # during HuggingFace post-init, this will automatically convert the TELinear head + # weight into a DTensor with the correct sharding placements prior to FSDP2 + # fully_shard(), and no need to call TELinear.set_device_mesh(). + parallelize_module( + self.embed_tokens, + self.tp_mesh, + # Un-sharded output activations for compatible input to TETransformer. 
+ # NOTE(@cspades): ColwiseParallel -> torch.nn.Embedding -> Shard(dim=1) + # RowwiseParallel doesn't support output_layouts=Replicate() with + # torch.compile: https://github.com/pytorch/torchtitan/issues/534 + ColwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()), + ) + def _init_method(x): torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range) @@ -207,6 +239,11 @@ def _init_method(x): device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", init_method=_init_method, output_layer_init_method=_init_method, + set_parallel_mode=config.tensor_parallel, + sequence_parallel=config.sequence_parallel, + tp_size=config.tp_size, + tp_mesh=self.tp_mesh, + weight_mesh=self.weight_mesh, ) ] @@ -217,6 +254,8 @@ def _init_method(x): dtype=config.dtype, device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", ) + # Norm modules are non-Base TransformerEngine modules that require a manual call for TP. + self.norm.set_device_mesh(tp_mesh=self.tp_mesh, weight_mesh=self.weight_mesh) # We use TE's RotaryPositionEmbedding, but we ensure that we use the same inv_freq as the original # LlamaRotaryEmbedding. @@ -393,6 +432,8 @@ def __init__( config, fp8_recipe: transformer_engine.common.recipe.Recipe | None = None, fp4_recipe: transformer_engine.common.recipe.Recipe | None = None, + nvte_tp_mesh: torch.distributed.DeviceMesh | None = None, + nvte_weight_mesh: torch.distributed.DeviceMesh | None = None, ): """Initialize the NVLlamaForCausalLM model. @@ -400,10 +441,21 @@ def __init__( config: The configuration of the model. fp8_recipe: The FP8 recipe for the model. fp4_recipe: The FP4 recipe for the model. + nvte_tp_mesh: TP DeviceMesh for the model. + nvte_weight_mesh: Weight-sharding DeviceMesh for the model. 
""" super().__init__(config) - self.model = NVLlamaModel(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) + self.model = NVLlamaModel( + config, + fp8_recipe=fp8_recipe, + fp4_recipe=fp4_recipe, + nvte_tp_mesh=nvte_tp_mesh, + nvte_weight_mesh=nvte_weight_mesh, + ) + self.config = config self.vocab_size = config.vocab_size + self.tp_mesh = nvte_tp_mesh + self.weight_mesh = nvte_weight_mesh with transformer_engine.pytorch.quantized_model_init(enabled=False): self.lm_head = transformer_engine.pytorch.Linear( config.hidden_size, @@ -412,9 +464,25 @@ def __init__( params_dtype=config.dtype, device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + parallel_mode="row" if config.tensor_parallel else None, + # This scatters your output, not ever needed for final layer. + # Will all-reduce the output instead, as required. + sequence_parallel=False, + tp_size=config.tp_size, ) - - # Initialize weights and apply final processing + if config.tensor_parallel: + if config.tie_word_embeddings: + # Head weights have already been tied to the embedding weights. + # Just set the tensor parallel group for TE. + # No parameter quantization either, so no need for weight_mesh. + self.lm_head.set_tensor_parallel_group(self.tp_mesh.get_group()) + else: + # Head weights are not tied to the embedding weights. Need to + # wrap the LM head weight as a DTensor with TE. + # No parameter quantization either, so no need for weight_mesh. + self.lm_head.set_device_mesh(tp_mesh=self.tp_mesh) + + # Initialize weights and apply final processing. Ties weights. 
self.post_init() def forward( @@ -467,6 +535,14 @@ def forward( # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + if self.config.tensor_parallel: + # If using TP, shard your activation across the TP group, + # to support row-wise tensor parallelism in the LM head. + # Use ... to support both BSHD (3D) and THD (2D) hidden states. + tp_rank = self.tp_mesh.get_local_rank() + tp_stride = hidden_states.shape[-1] // self.config.tp_size + hidden_states = hidden_states[..., tp_rank * tp_stride : (tp_rank + 1) * tp_stride] + with transformer_engine.pytorch.autocast(enabled=False): if hidden_states.ndim == 3: logits = self.lm_head(hidden_states[:, slice_indices, :]) diff --git a/bionemo-recipes/recipes/vit/train.py b/bionemo-recipes/recipes/vit/train.py index db7c410638..5df1348523 100644 --- a/bionemo-recipes/recipes/vit/train.py +++ b/bionemo-recipes/recipes/vit/train.py @@ -22,7 +22,7 @@ import torch import wandb from hydra.core.hydra_config import HydraConfig -from megatron_fsdp import fully_shard +from megatron_fsdp.fully_shard import fully_shard from omegaconf import OmegaConf from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR diff --git a/ci/lepton/model_convergence/configs/recipes/llama3_native_te.yaml b/ci/lepton/model_convergence/configs/recipes/llama3_native_te.yaml index 82b9a2df3d..8178316790 100644 --- a/ci/lepton/model_convergence/configs/recipes/llama3_native_te.yaml +++ b/ci/lepton/model_convergence/configs/recipes/llama3_native_te.yaml @@ -87,7 +87,7 @@ products: job_name: "llama3-native-L0-sanity" # L0 sanity test with context parallelism - config: L0_sanity_cp - task_cmd: train_fsdp2_cp + task_cmd: train_fsdp2_nd_parallel cp_enabled: true wandb_name: "llama3_native__L0_sanity_cp__${now:%Y%m%d-%H%M%S}__${gitsha:}" job_name: "llama3-native-L0-sanity-cp"